乐趣区

关于人工智能:适配PyTorch-FXOneFlow让量化感知训练更简单

作者 | 刘耀辉
审稿 | BBuf、许啸宇

1

背景

近年来,量化感知训练是一个较为热点的问题,能够大大优化量化后训练造成精度损失的问题,使得训练过程更加高效。

Torch.fx 在这一问题上走在了前列,应用纯 Python 语言实现了对于 Torch.nn.Module 的解析和向 IR 的转换,也能够提供变换后的 IR 对应的 Python 代码,在内部则是提供了简洁易用的 API,大大不便了量化感知训练过程的搭建。此外,Torch.fx 也有助于打消动态图和动态图之间的 Gap,能够比拟不便地对图进行操作以及进行算子交融。

OneFlow 紧随其后增加了针对 OneFlow 的 fx,即 One-fx,在装置 One-fx 之后,用户能够间接调用 oneflow.fx,也能够间接通过import onefx as fx 进行应用。

one-fx 地址:
https://github.com/Oneflow-Inc/one-fx

One-fx 实现代码中绝大部分是对于 Torch.fx 的 fork,但依据 OneFlow 和 PyTorch 之间存在的差异进行了一些适配或优化。本文将围绕 One-fx 适配形式以及在 OneFlow 中的利用开展

2

FX 次要模块

  • Symbolioc Trace
  • Graph Module
  • Interpreter
  • Proxy
  • Passes 

其中,前 4 个模块独特实现了 fx 的基本功能,Graph Module 和 Proxy 又是 Symbolic Trace 的根底,Passes 则是在此基础上的裁减。

Symbolic Trace 的基本概念如上图所示,最根本的模型运行过程就是从模型定义到模型执行这样一个流程。

fx 则是进行了非侵入式的解析,将模型执行过程转成一张图,这张图中蕴含了很多个 Node,每一个 Node 都蕴含了模型中的子模块或者函数调用信息,而后用户能够很不便地获取到所有的 Node,并对其进行一些变换操作,最初通过 GraphModule 从新生成一个模型定义,并对其执行。

其中,在进行模型解析的时候,节点之间变量传递也均应用代理后的变量,如 y = oneflow.relu(x),实际上 x 和 y 是Proxy(x)Proxy(y)

3

One-fx 实现形式

这里给出一个 Fx 最简略的用例,以不便后续对于实现形式的介绍。

import oneflow

class MyModule(oneflow.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = oneflow.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        y = oneflow.ones([2, 3])

        x = oneflow.relu(x)
        return y

m = MyModule()

traced = oneflow.fx.symbolic_trace(m)
print(traced.code)
"""
def forward(self, x):
    linear = self.linear(x);  x = None
    relu = oneflow.relu(linear);  linear = None
    _tensor_constant0 = self._tensor_constant0
    return _tensor_constant0
"""

函数代理

代理,即 fx 中的 Proxy 模块,目标是在每次进行函数或模块调用的时候增加一些额定操作,使得对模型的解析和重建得以进行,而包装则是适配代理的一种形式。

torch.fx 中,对于 nn.Module 的包装比拟易于了解,每当待解析 Module 中呈现了继承自 nn.Module 的对象,那么就将其 __call__ 函数替换成包装过的函数。然而,对于 pytorch 的函数的代理的实现要更“绕”一些,是借助了 __torch_function__ 这一机制
(https://github.com/pytorch/pytorch/blob/c7c723897658eda6298bb74d92e4bb18ab4a5fe3/torch/overrides.py),限于篇幅起因这里不专门对其进行介绍。比拟要害的点是,OneFlow 中没有这一机制,如果须要增加,那么会是规模很大的、侵入性的,于是 One-fx 的实现就须要找其它门路。

咱们应用的解决形式是搜寻 oneflowoneflow.nn.functionaloneflow._C 等模块中的 Callable,并去除其中属于类的局部,而后对其余函数进行包装,在每次解析模型之前,会将这些模块的 __dict__ 中对应项替换成包装后的函数,并且在解析模型之后从新将这些项进行还原。对于 constructor 类型的函数,如 ones,randn 等则不进行代理,间接运行,在最终构建图的时候作为 constant 来解决。

对于函数的包装局部源码实现如下,每次运行代理后的函数,会先判断该函数的入参中有没有 Proxy 变量,如果有,那么将会创立一个 call_function 类型的节点并返回 Proxy 包装后的节点,否则间接调用原函数并返回后果。

def _create_wrapped_func(orig_fn):
    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        # 判断参数中是否存在 proxy 变量
        proxy = _find_proxy(args, kwargs)
        if proxy is not None:
            # 如果参数中有 Proxy 变量,创立节点并返回 Proxy 包装后的节点
            return_proxy = proxy.tracer.create_proxy("call_function", orig_fn, args, kwargs)
            return_proxy.node.meta["is_wrapped"] = True
            return return_proxy
        # 如果没有 Proxy 变量,间接调用原函数
        return orig_fn(*args, **kwargs)

    return wrapped

其中,return_proxy = proxy.tracer.create_proxy("call_function", orig_fn, args, kwargs) 这行代码指定了应用与入参雷同的 Tracer 来创立节点并返回后果,create_proxy 函数定义的次要局部如下,创立节点并在 Proxy 包装后返回。

def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
                     name: Optional[str] = None, type_expr : Optional[Any] = None,
                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
    args_ = self.create_arg(args)
    kwargs_ = self.create_arg(kwargs)
    assert isinstance(args_, tuple)
    assert isinstance(kwargs_, dict)

    # 创立节点
    node = self.create_node(kind, target, args_, kwargs_, name, type_expr)

    if not proxy_factory_fn:
        proxy = self.proxy(node)
    else:
        proxy = proxy_factory_fn(node)

    return proxy

而其中的 create_node 办法,实际上是调用了Tracer.graph.create_node,在图中创立节点,次要局部代码如下,其中 op 就是 fx IR 中的 op,代表了节点类型,而 target 则是节点的操作主体,在下面的例子中就是orig_func

因而,当咱们自定义的 Module 中的 forward 函数中的所有调用都被包装之后,实际上再运行 forward 的时候,就会顺次在 Tracer.graph 中创立节点,这也正是 symbolic_trace 的基本思路。

def create_node(self, op: str, target: 'Target',
                    args: Optional[Tuple['Argument', ...]] = None,
                    kwargs: Optional[Dict[str, 'Argument']] = None,
                    name: Optional[str] = None,
                    type_expr: Optional[Any] = None) -> Node:
    # 此处有一些 assert

    # 创立一个节点名称,防止反复
    candidate = name if name is not None else self._target_to_str(target)
    name = self._graph_namespace.create_name(candidate, None)
    # 创立节点
    n = Node(self, name, op, target, args, kwargs, type_expr)

    # 建设名称与节点的映射关系
    self._graph_namespace.associate_name_with_obj(name, n)

    return n

而对于 symbolic_trace 过程,其外围就是Tracer.trace。这个办法能够分为两局部,一个是预处理局部,一个是骨干局部。其中预处理过程大抵定义如下,次要工作是初始化 Graph、确立模型以及 forward 函数和创立包装后的参数。

如后面所提及的,symbolic trace 的基本思路是借助 Proxy 变量以及包装后的函数,在每次调用的时候都创立一个节点,因而,forward 函数的输出也须要用 Proxy 进行包装,这一步定义在 Tracer.create_args_for_root 中。

def trace(
        self,
        root: Union[oneflow.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
    ) -> Graph:
    # 确定模块主体以及 forward 函数,其中 fn 即 forward 函数
    if isinstance(root, oneflow.nn.Module):
        self.root = root

        assert hasattr(type(root), self.traced_func_name
        ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"

        fn = getattr(type(root), self.traced_func_name)
        self.submodule_paths = {mod: name for name, mod in root.named_modules()}
    else:
        self.root = oneflow.nn.Module()
        fn = root

    tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None)
    # 在 Tracer 中初始化一张图
    self.graph = Graph(tracer_cls=tracer_cls)
    
    self.tensor_attrs: Dict[oneflow.Tensor, str] = {}
    # 这个子函数用于收集模型中所有 Tensor 类型的变量
    def collect_tensor_attrs(m: oneflow.nn.Module, prefix_atoms: List[str]):
        for k, v in m.__dict__.items():
            if isinstance(v, oneflow.Tensor):
                self.tensor_attrs[v] = ".".join(prefix_atoms + [k])
        for k, v in m.named_children():
            collect_tensor_attrs(v, prefix_atoms + [k])

    collect_tensor_attrs(self.root, [])

    assert isinstance(fn, FunctionType)

    # 获取 fn 所在模块的所有可读变量
    fn_globals = fn.__globals__
    # 创立包装后的参数
    fn, args = self.create_args_for_root(fn, isinstance(root, oneflow.nn.Module), concrete_args
    )

随后则是 trace 的骨干局部,这一部分大抵代码如下,次要工作是对函数、办法、模块进行必要的包装,而后在 Graph 中创立节点,实现整个图的信息。

其中,咱们会创立一个 Patcher 环境并在其中进行这些过程,这是因为对于函数和办法的包装会间接扭转掉某些包中对应函数或办法的行为,为了不让这种行为的扭转溢出到 trace 的范畴之外,在每次进行包装的时候会在 Patcher 中记录本次操作,而后在 _Patcher.__exit__ 中依据记录的操作一一还原现场。

# 上面代码依然是 `trace` 函数的一部分

# 定义对于 `nn.Module` 的 getattr 办法的包装
@functools.wraps(_orig_module_getattr)
def module_getattr_wrapper(mod, attr):
    attr_val = _orig_module_getattr(mod, attr)
    return self.getattr(attr, attr_val, parameter_proxy_cache)

# 定义对于 `nn.Module` 的 forward 办法的包装
@functools.wraps(_orig_module_call)
def module_call_wrapper(mod, *args, **kwargs):
    def forward(*args, **kwargs):
        return _orig_module_call(mod, *args, **kwargs)

    _autowrap_check(
        patcher,
        getattr(getattr(mod, "forward", mod), "__globals__", {}),
        self._autowrap_function_ids,
    )
    return self.call_module(mod, forward, args, kwargs)
# 这里 Patcher 的作用是在退出这一环境的时候复原现场,防止包装函数、办法的影响溢出到 `trace` 之外。with _Patcher() as patcher:
    # 对 `__getattr__` 和 `nn.Module.__call__` 这两个办法默认进行包装
    patcher.patch_method(
        oneflow.nn.Module,
        "__getattr__",
        module_getattr_wrapper,
        deduplicate=False,
    )
    patcher.patch_method(oneflow.nn.Module, "__call__", module_call_wrapper, deduplicate=False)
    # 对预约好须要进行包装的函数进行包装
    _patch_wrapped_functions(patcher)
    _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
    # 遍历所有须要对其中函数进行主动包装的 package
    for module in self._autowrap_search:
        if module is oneflow:
            dict = {}
            # 当 package 为 oneflow 时,对此进行非凡解决,独自分出一个字典寄存本来 `oneflow.__dict__` 中的内容
            for name, value in module.__dict__.items():
                if not isinstance(value, oneflow.nn.Module) and not value in _oneflow_no_wrapped_functions:
                    dict[name] = value
            _autowrap_check_oneflow(patcher, dict, module.__dict__, self._autowrap_function_ids)
        else:
            _autowrap_check(patcher, module.__dict__, self._autowrap_function_ids)
    # 创立节点,这里的 `create_node` 调用实际上只是创立了最初一个节点,即输入节点。# 然而这里 `fn` 就是 forward 函数,在运行这一函数的时候,就会如后面所说顺次创立节点。self.create_node(
        "output",
        "output",
        (self.create_arg(fn(*args)),),
        {},
        type_expr=fn.__annotations__.get("return", None),
    )

其中,_patch_wrapped_functions的实现如下:

def _patch_wrapped_functions(patcher: _Patcher):
    # `_wrapped_fns_to_patch` 中蕴含了所有须要主动包装的函数
    for frame_dict, name in _wrapped_fns_to_patch:
        if name not in frame_dict:
            if hasattr(builtins, name):
                # 对于 built-in 函数,不存在于 frame_dict 中,独自进行解决来依据名称获取函数自身
                orig_fn = getattr(builtins, name)
            else:
                # 如果是 oneflow 中指定须要包装的函数,那么就进行获取,否则抛出名称无奈辨认的异样
                is_oneflow_wrapped_function, func = is_oneflow_wrapped_function_and_try_get(name)
                if is_oneflow_wrapped_function:
                    orig_fn = func
                else:
                    raise NameError("Cannot deal with the function %s."%name)
        else:
            # 如果函数名称曾经存在于 frame_dict 中,间接通过字典查问来取得函数
            orig_fn = frame_dict[name]
        # 创立包装后的函数并进行 `patch`,即定义当 trace 过程完结的时候,如何还原现场
        patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))
    
    # 对于类中的办法,间接包装并 patch。for cls, name in _wrapped_methods_to_patch:
        patcher.patch_method(cls, name, _create_wrapped_method(cls, name))

全局包装

在模型的 forward 函数中,咱们有时不仅会用到框架自带的模块或者函数,有点时候还须要用到自定义的函数或者 built-in 函数,对于这种状况如果不进行解决,那么天然无奈承受 Proxy(x) 的入参。fx 中提供了 fx.wrap 这一 API,当用户须要调用这部分函数的时候,能够实现应用 fx.wrap(func) 使其被包装。

例如:

import oneflow

oneflow.fx.wrap(len)
class MyModule(oneflow.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = oneflow.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x) + len(x.shape)
        return x

traced = oneflow.fx.symbolic_trace(MyModule())
print(traced.code)
"""
def forward(self, x):
    linear = self.linear(x)
    getattr_1 = x.shape;  x = None
    len_1 = len(getattr_1);  getattr_1 = None
    add = linear + len_1;  linear = len_1 = None
    return add
"""

然而其局限性在于,如果 Module 的源代码是来自其它库,那么在调用的中央应用 fx.wrap 是不起作用的,在 oneflow 和 torch 中都会有这一问题。然而 flowvision 中有多处应用了 built-in function,因而咱们增加了一个 API,即 global_wrap,原理比较简单,就是间接对某个函数所在的包的__dict__ 进行批改,用法如下:

# MyModule 来自其它包
with oneflow.fx.global_wrap(len):
    m = MyModule()

    traced = oneflow.fx.symbolic_trace(m)
    print(traced.code)
    """
    def forward(self, x):
        linear = self.linear(x);  x = None
        getattr_1 = linear.shape
        len_1 = len(getattr_1);  getattr_1 = None
        relu = oneflow.relu(linear);  linear = None
        add = relu + len_1;  relu = len_1 = None
        return add
    """

应用 with 关键字的起因是这种实现形式是间接批改了某个包的__dict__,对于其它中央的调用也会产生影响,因而须要将其限度在肯定范畴内。此外,包装后的函数蕴含了对类型的断定等一系列操作,也会极大影响 built-in 函数的性能。

其它适配

其它中央的解决都比较简单,不须要对实现形式做批改,只须要将细节局部对齐即可,这也体现出 oneflow 和 pytorch 在前端局部的高度兼容性。

4

IR 设计

fx 的 IR 设计遵循以下几个准则:

  • 防止反对长尾散布,简单的样例。次要关注经典模型的程序捕捉和变换。
  • 应用机器学习从业者曾经相熟的工具和概念,例如 Python 的数据结构和 PyTorch 中公开记录的算子。
  • 使程序捕捉过程具备高度可配置性,以便用户能够为长尾需要实现本人的解决方案。

fx 的 IR 次要由几个局部组成;

  • opcode:即以后操作的类型,能够是 placeholder, get_attr, call_function, call_method, call_module, output
  • name:即给以后操作的命名。
  • target:以后操作的实体,例如对于 call_function 类型的操作,可能这一属性会是 <built-in function len>。
  • args 和 kwargs:指定以后操作的参数。

通过 print_tabular 这一 API 能够很不便好看地打印出 fx 中的 IR,例如对于以下的 MyModule 模型,咱们能够打印出其 IR:

import oneflow

class MyModule(oneflow.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = oneflow.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        y = oneflow.ones([2, 3])

        x = oneflow.topk(x, 10)
        return x.relu() + y

traced = oneflow.fx.symbolic_trace(MyModule())
traced.graph.print_tabular()

"""
opcode         name               target                    args                       kwargs
-------------  -----------------  ------------------------  -------------------------  --------
placeholder    x                  x                         ()                         {}
call_module    linear             linear                    (x,)                       {}
call_function  topk               <built-in function topk>  (linear, 10)               {}
call_method    relu               relu                      (topk,)                    {}
get_attr       _tensor_constant0  _tensor_constant0         ()                         {}
call_function  add                <built-in function add>   (relu, _tensor_constant0)  {}
output         output             output                    (add,)                     {}
"""

只管 fx 的 IR 不算弱小(例如不能解决动态控制流),然而定义十分简洁,实现简略,对于用户来讲上手门槛绝对低很多。

5

One-fx 利用举例

OP 替换

上面的例子展现了如何将 add 操作全副替换成 mul 操作。

import oneflow
from oneflow.fx import symbolic_trace
import operator

class M(oneflow.nn.Module):
    def forward(self, x, y):
        return x + y, oneflow.add(x, y), x.add(y)

if __name__ == '__main__':
    traced = symbolic_trace(M())

    patterns = set([operator.add, oneflow.add, "add"])

    for n in traced.graph.nodes:
        if any(n.target == pattern for pattern in patterns):
            with traced.graph.inserting_after(n):
                new_node = traced.graph.call_function(oneflow.mul, n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
            traced.graph.erase_node(n)

    traced.recompile()

    traced.graph.print_tabular()

    print(traced.code)

性能剖析

以下代码展现如何应用 fx 进行模型的性能剖析,将本来的模型通过 symbolic_trace 解析成各个节点,再在其中插入测试性能的操作。

import oneflow
import flowvision.models as models
import statistics, tabulate, time
from typing import Any, Dict, List

class ProfilingInterpreter(oneflow.fx.Interpreter):
    def __init__(self, mod : oneflow.nn.Module):
        gm = oneflow.fx.symbolic_trace(mod)
        super().__init__(gm)

        # 记录总运行工夫
        self.total_runtime_sec : List[float] = []
        # 记录各个节点运行工夫
        self.runtimes_sec : Dict[oneflow.fx.Node, List[float]] = {}

    # 重写 `run` 办法,实质上是对基类 `run` 办法的简略封装,在运行前后记录时间点。# 这一办法是 Graph 整体运行的入口。def run(self, *args) -> Any:
        t_start = time.time()
        return_val = super().run(*args)
        t_end = time.time()
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    # 同上,重写 `run_node` 办法,不须要本人写细节实现,只须要在对基类的 `run_node` 调用前后记录时间点即可
    # 这一办法是 Graph 中运行每个 Node 的入口。def run_node(self, n : oneflow.fx.Node) -> Any:
        t_start = time.time()
        return_val = super().run_node(n)
        t_end = time.time()
        self.runtimes_sec.setdefault(n, [])
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    # 定义如何打印性能测试后果
    def summary(self, should_sort : bool = False) -> str:
        # 存储每个节点的打印信息
        node_summaries : List[List[Any]] = []
        # 因为模块会被调用屡次,所以这里计算一下均匀的运行总时长
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        for node, runtimes in self.runtimes_sec.items():
            mean_runtime = statistics.mean(runtimes)
            # 计算节点运行工夫占总工夫的比例
            pct_total = mean_runtime / mean_total_runtime * 100
            # 记录节点信息、节点均匀运行时长和节点运行工夫占总工夫的比例
            node_summaries.append([node.op, str(node), mean_runtime, pct_total])

        # 如果须要,安依照运行工夫进行排序
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # 以下是借助 tabulate 库进行格式化来丑化显示成果
        headers : List[str] = ['Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)


if __name__ == '__main__':
    rn18 = models.resnet18()
    rn18.eval()
    input = oneflow.randn(5, 3, 224, 224)
    output = rn18(input)
    interp = ProfilingInterpreter(rn18)
    interp.run(input)
    print(interp.summary(True))

成果如下:

算子交融

以下代码演示如何借助 fx 将模型中的卷积层和 BN 层进行交融,对于这种组合,并不需要引入新的算子,只须要对本来 conv 的权重进行操作即可。能够参考:https://nenadmarkus.com/p/fusing-batchnorm-and-conv/。

import sys
import oneflow
import oneflow.nn as nn
import numpy as np
import copy
from typing import Dict, Any, Tuple

# 通过间接对权重进行运算的形式进行 Conv 和 BN 的交融
def fuse_conv_bn_eval(conv, bn):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv

# 权重交融形式
def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = oneflow.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = oneflow.ones_like(bn_rm)
    if bn_b is None:
        bn_b = oneflow.zeros_like(bn_rm)
    bn_var_rsqrt = oneflow.rsqrt(bn_rv + bn_eps)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return oneflow.nn.Parameter(conv_w), oneflow.nn.Parameter(conv_b)

# 依据字符串对名称进行宰割,比方 `foo.bar.baz` -> (`foo.bar`, `baz`)
def _parent_name(target : str) -> Tuple[str, str]:
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node: oneflow.fx.Node, modules: Dict[str, Any], new_module: oneflow.nn.Module):
    assert(isinstance(node.target, str))
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, new_module)

# 定义对模型进行交融操作的过程
def fuse(model: oneflow.nn.Module) -> oneflow.nn.Module:
    model = copy.deepcopy(model)
    # 先通过 fx.symbolic_trace 获取一个 GraphModule
    fx_model: oneflow.fx.GraphModule = oneflow.fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())

    # 遍历 GraphModule 中的所有节点,别离进行操作
    for node in fx_model.graph.nodes:
        # 跳过所有不是 module 的节点
        if node.op != 'call_module':
            continue
        # 检测到 conv+bn 的构造后进行交融操作
        if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:
            # conv 的输入同时被其它节点应用,即 conv 后连贯两个节点时无奈交融
            if len(node.args[0].users) > 1:
                continue
            conv = modules[node.args[0].target]
            bn = modules[node.target]
            fused_conv = fuse_conv_bn_eval(conv, bn)
            replace_node_module(node.args[0], modules, fused_conv)
            # 对图中的边进行置换,对于用到 bn 输入的节点,要更改它们的输出
            node.replace_all_uses_with(node.args[0])
            # 移除旧的节点
            fx_model.graph.erase_node(node)
    fx_model.graph.lint()
    # 从新建图(结构模型)fx_model.recompile()
    return fx_model


if __name__ == '__main__':
    # 以下引入 flowvision 中的 resnet 18 模型,并进行交融前后的 benchmark 比拟
    import flowvision.models as models
    import time

    rn18 = models.resnet18().cuda()
    rn18.eval()

    inp = oneflow.randn(10, 3, 224, 224).cuda()
    output = rn18(inp)

    def benchmark(model, iters=20):
        for _ in range(10):
            model(inp)
        oneflow.cuda.synchronize()
        begin = time.time()
        for _ in range(iters):
            model(inp)
        return str(time.time()-begin)

    fused_rn18 = fuse(rn18)
    unfused_time = benchmark(rn18)
    fused_time = benchmark(fused_rn18)
    print("Unfused time:", benchmark(rn18))
    print("Fused time:", benchmark(fused_rn18))
    assert unfused_time > fused_time

6

将来打算

  • 基于 fx 进行 8bit 量化感知训练和部署
  • 基于 fx 进行算子交融
  • eager 模式下基于 fx 取得模型更准确的 FLOPs 和 MACs 后果

参考文献
1.https://pytorch.org/docs/stable/fx.html
2.https://github.com/Oneflow-Inc/one-fx
3.https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html
4.https://pytorch.org/tutorials/intermediate/fx_profiling_tutor…
5.https://zhuanlan.zhihu.com/p/449908382

欢送 Star、试用 OneFlow 最新版本:https://github.com/Oneflow-Inc/oneflow/

退出移动版