关于大数据:用OneFlow实现数据类型自动提升

2次阅读

共计 6378 个字符,预计需要花费 16 分钟才能阅读完成。

一、问题引入
咱们先简略看下在 PyTorch 下的这几段代码,读者能够猜下最初输入的类型是什么:

x_tensor = torch.ones((3,), dtype=torch.int8)
y1_tensor = torch.tensor(1, dtype=torch.float64)
out1 = torch.mul(x_tensor, y1_tensor)

y2_tensor = torch.tensor(1, dtype=torch.int64)
out2 = torch.mul(x_tensor, y2_tensor)

out3 = torch.mul(x_tensor, 1.0)

out4 = torch.mul(x_tensor, 2^63-1(the max value of int64))
接下来揭晓答案:

out1.dtype: torch.float64
out2.dtype: torch.int8
out3.dtype: torch.float32
out4.dtype: torch.int8
能够察看到同样是 multiply 运算,有些后果的数据类型被晋升到更高的一级,有些并没有被晋升,还维持着 int8 类型。这其实是一种类型晋升零碎,零碎内会自定义一些类型晋升的规定,依据输出的数据类型来推导最终后果的数据类型。
**
二、Python Array API 规范 **

在这里咱们能够理解到 Python Array 的类型晋升规定

类型晋升
从上图能够看到:

不同数据类型的晋升遵循这个连贯的规定

虚线示意 python 标量在溢出的时候未定义

bool int float 之间没有连线,示意这种混合类型的晋升未定义

对于第一条,咱们能够看 int8 和 uint8,两者最终指向了 int16,示意两者运算后最终类型晋升到了 int16

而依据这一个规定,咱们能够列出一个类型晋升表格(这个表格很重要,后续看 Pytorch 源码也会用到)

以 unsigned int 系列和 signed int 系列为例,列出的表格为:

更多类型晋升规定表格可参考后面提到的链接

横坐标和纵坐标别离代表输出的数据类型,表格的值代表类型晋升后的数据类型,其中:

i1 : 8-bit signed integer (i.e., int8)

i2 : 16-bit signed integer (i.e., int16)

i4 : 32-bit signed integer (i.e., int32)

i8 : 64-bit signed integer (i.e., int64)

同理于 unsigned int

Python Array 和 Scalar 的类型晋升
上述这些都是 array 与 array 之间运算的类型晋升规定,而 array 与 scalar(就是独自一个 int,float 数值)的类型晋升规定则不一样。

如果两者同属于一个数据类型系列(比方都是 int 系列,蕴含 int8, int32, int64),则最终数据类型遵循数组的数据类型

如果两者同不属于一个数据类型系列(比方一个是 int32,一个是 float),则进行类型晋升

咱们能够看下简略的两个例子:

x_tensor = torch.ones((3,), dtype=torch.int16)
out1 = x_tensor + 2 # out.dtype = torch.int16
out2 = x_tensor + 2.0 # out.dtype = torch.float32
须要留神的是,Array 与 Scalar 的行为会和 Array 与 0d Array 的行为保持一致。

咱们能够再测试后面两个例子,不同之处是咱们将 scalar 改成一个 0d Array

x_tensor = torch.ones((3,), dtype=torch.int16)
y1_tensor = torch.tensor(2)
y2_tensor = torch.tensor(2.0)

out1 = x_tensor + y1_tensor # out.dtype = torch.int16
out2 = x_tensor + y2_tensor # out.dtype = torch.float32
对于与 Scalar 运算的行为,Pytorch 是和 Python Array API 规范统一的,然而 Numpy 则不同,他会依据 scalar 的数据范畴做一个正当的类型晋升:

import numpy as np

x = np.ones((3, 3), dtype=np.int32)
out = x + (2**31-1) # dtype: np.int32
out = x + (2**31) # dtype: np.int64
我集体更偏向于在类型晋升中,Scalar 是独自一种行为,而 Scalar Tensor 和 Tensor 的行为统一

三、其余状况

除了后面提到的规定,PyTorch 还存在以下两种状况:

要求两个输出的数据类型完全一致,如 torch.dot

RuntimeError: dot : expected both vectors to have same dtype, but found Short and Float
输出存在一个最低数据类型,比方 torch.sum,传任意 int 系列数据类型,最终输入后果均为 torch.int64。

以上就简略介绍了 Pytorch 的类型晋升规定,还想要更多的例子能够参考官网文档:

https://pytorch.org/docs/mast…

四、PyTorch 是怎么做类型晋升的?
理论运算的 Kernel,输出和输入的数据类型都是雷同的模板参数,不存在特化一个输出为 int32,输入为 float32 或其余类型的函数。

因而 PyTorch 外部会先推断出一个正当的 dtype,而后插入一个 to 这个 op,将输出 tensor 进行类型晋升,再进入到 Kernel 进行理论的运算。上面咱们会依据 PyTorch 的源码进行解说:

波及到的代码:

https://github.com/pytorch/py…

https://github.com/pytorch/py…

https://github.com/pytorch/py…

https://github.com/pytorch/py…

ScalarType.h
在这个头文件里定义了相干的数据类型,并且定义了一个类型晋升的二维矩阵,这样咱们就能够输出两个数据类型,依据索引拿到晋升后的数据类型。大数据培训

类型晋升矩阵
Activation.cpp
https://github.com/pytorch/py… 咱们以其中一个激活函数 threshold 为例子

TORCH_META_FUNC(threshold)(const Tensor& self, const Scalar& threshold, const Scalar& value) {
const Tensor& result = maybe_get_output();
build(TensorIteratorConfig()

...
.promote_inputs_to_common_dtype(true)

}
这里调用了一个 build 函数,函数承受一个 TensorIteratorConfig,这个 Config 类是用于配制各种属性,能够看到这里调用 promote_inputs_to_common_dtype 并设为 true。

TensorIterator.cpp
build 函数定义在:

https://github.com/pytorch/py…

在 1340 行,build 函数外部调用了 compute_type 函数


compute_types(config);

而该函数在 260 行开始,进行一系列类型推导

其中 TensorIterator 是一个容器类(Numpy 里也有一个相似的容器 NpyIter),用于存储输入,输出 tensor,外面用了多个 for 循环来推导失去一个 common_dtype。

并在最初进行条件判断:promote_inputs_to_common_dtype_为 true,以后 Tensor 不是输入 Tensor,且输出的 dtype 不等于推导失去的 common_dtype,则做一个类型晋升:

// Promotes inputs by creating temporaries of the correct dtype

  if (config.promote_inputs_to_common_dtype_ && !op.is_output && op.current_dtype != common_dtype_) {
    op.original_tensor = op.tensor;
    op.tensor = c10::MaybeOwned<Tensor>::owned(op.tensor->to(common_dtype_));
    op.current_dtype = common_dtype_;
    op.target_dtype = common_dtype_;
  }

五、OneFlow 的做法
相干 PR:https://github.com/Oneflow-In…

OneFlow 则将类型晋升的逻辑放在 c ++ 中 functional 前端局部,相似的咱们设计了一个 TensorProcessor 类,接口设计如下:

class TensorProcessor final {
public:
TensorProcessor()

  : common_dtype_(DType::InvalidDataType()), promote_inputs_to_common_dtype_(false){};

TensorProcessor& AddInputs(const TensorTuple& init_list);
TensorProcessor& AddInputs(const TensorTuple& init_list, Symbol<DType> tensor_lowest_dtype);

Maybe<void> Apply();
TensorProcessor& PromoteInputsToCommonDtype(bool is_promote);
Maybe<TensorTuple&> GetInputs() { return tensor_tuple_;};

private:
TensorTuple tensor_tuple_;
Symbol<DType> common_dtype_;
std::vector<Symbol<DType>> inputs_lowest_dtype_vec_;

bool promote_inputs_to_common_dtype_;
};
以二元操作 Functor 基类为例,在理论调用的时候,咱们能够这样:

class BinaryFunctor{
public:
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,

                       const std::shared_ptr<one::Tensor>& y) const {
TensorProcessor tensor_processor;
JUST(tensor_processor.PromoteInputsToCommonDtype(true).AddInputs({x, y}).Apply());
TensorTuple input_tuple = JUST(tensor_processor.GetInputs());
return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple);


}

};
PromoteInputsToCommonDtype 用于设置相干属性

AddInputs 函数将须要参加类型晋升的 Tensor 增加到容器中

Apply 函数执行理论的类型晋升等逻辑

tensor_processor.cpp 还有其余几个函数,这里简略介绍下性能:

CheckHasDifferentInputDType 遍历输出 Tensor,查看输出 Tensor 是否有不同的 dtype

ComputeCommonDType 依据输出 dtype 推导一个正当的晋升过的 dtype

CastToSameType 给输出 Tensor 插入一个 Cast 操作

Maybe<void> CastToSameType(TensorTuple& tensor_tuple, const Symbol<DType>& common_dtype) {
for (auto& tensor_ptr : tensor_tuple) {

if (tensor_ptr->dtype() != common_dtype) {tensor_ptr = JUST(functional::Cast(tensor_ptr, common_dtype));
}

}
return Maybe<void>::Ok();
}
Apply 函数逻辑如下:

Maybe<void> TensorProcessor::Apply() {
if (promote_inputs_to_common_dtype_) {

bool has_different_input_dtype = CheckHasDifferentInputDType(tensor_tuple_);
if (has_different_input_dtype) {common_dtype_ = ComputeCommonDType(tensor_tuple_);
  JUST(CastToSameType(tensor_tuple_, common_dtype_));
}

} else {

for (int i = 0; i < tensor_tuple_.size(); ++i) {
  // Cast all the inputs to it's attribute `lowest_dtype` if the input tensor dtype is lower
  // than attribute `lowest_dtype`.
  Symbol<DType> base_dtype = inputs_lowest_dtype_vec_.at(i);
  if (base_dtype->data_type()
      && DType::priority_order[base_dtype->data_type()]
             > DType::priority_order[tensor_tuple_.at(i)->dtype()->data_type()]) {tensor_tuple_.at(i) = JUST(one::functional::Cast(tensor_tuple_.at(i), base_dtype));
  }
}

}
return Maybe<void>::Ok();
}
if 内执行的是类型晋升,而 else 内逻辑则是对应后面提到的其余状况中的第二条,将 Tensor 类型晋升到设定好的一个最低数据类型。还是 sum 算子,咱们设定最低数据类型为 int64 是这么做的:

class ReduceSumFunctor{
public:
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,

                       const bool& keepdims) const {
...
TensorProcessor tensor_processor;
JUST(tensor_processor.AddInputs({x}, /*lowest_dtype=*/DType::Int64()).Apply());
TensorTuple input_tuple = JUST(tensor_processor.GetInputs());

}

};

总结
类型晋升是一个咱们不经意间会应用的一个操作,如果没有正确处理输入的数据类型,则可能导致后果溢出,呈现谬误的后果。看似很简略,但理论调研 + 斟酌细节也搞了两三周,最初感激共事在我实现这个性能的期间提供的许多帮忙。

正文完
 0