乐趣区

关于pytorch:pytorch列优先fortranlikereshape的实现与性能

背景

numpy 中的 reshape 函数蕴含一个 order 变量,默认 order='C',即在变形中以后面的维度(行)为优先程序重新排列元素,而order='F' 时以前面的维度(列)为优先程序重新排列元素,官网文档中给出了示例:

>>> np.reshape(a, (2, 3)) # C-like index ordering
array([[0, 1, 2],
 [3, 4, 5]])
>>> np.reshape(np.ravel(a), (2, 3)) # equivalent to C ravel then C reshape
array([[0, 1, 2],
 [3, 4, 5]])
>>> np.reshape(a, (2, 3), order='F') # Fortran-like index ordering
array([[0, 4, 3],
 [2, 1, 5]])
>>> np.reshape(np.ravel(a, order='F'), (2, 3), order='F')
array([[0, 4, 3],
 [2, 1, 5]])

pytorch 解决方案

pytorch 中,torch.reshape()函数只承受矩阵和形态两个参数,采纳了行优先 (C-Style) 的变换形式,如果须要应用列优先的变换,须要借助 permute() 函数,stackoverflow 上给出了解决方案:

def reshape_fortran(x, shape):
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

性能测试

然而下面的作者狐疑 permute() 函数外部依然会创立张量的正本,影响效率。因而笔者对这种办法做了测试,并与 numpy 的内置函数做了比照。测试环境为 i9-10900X/RTX2080Ti。

测试代码:

import numpy as np
import torch
import time

dim1 = 40
dim2 = 50
dim3 = 5

def reshape_fortran(x, shape):
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

torch.cuda.set_device(0)
device = torch.device('cuda')

x = [torch.from_numpy(np.random.rand(dim1, dim2)).to(device) for _ in range(100)]
xx = [torch.from_numpy(np.random.rand(dim1, dim2)).to(device) for _ in range(100)]
for i in range(100):
    y = x[i].reshape([dim2, dim1])
# c reshape
t0 = time.time()
for i in range(100):
    y = xx[i].reshape([dim2, dim3, -1])
t1 = time.time()

# fortran reshape
for i in range(100):
    yy = reshape_fortran(xx[i], [dim2, dim3, -1])
t2 = time.time()

print(f'torch build-in reshape: {(t1 - t0)/100} s')
print(f'torch permute reshape: {(t2 - t1)/100} s')

x = [np.random.rand(dim1, dim2) for _ in range(100)]
xx = [np.random.rand(dim1, dim2) for _ in range(100)]
for i in range(100):
    y = x[i].reshape([dim2, dim3, -1])
t0 = time.time()
for i in range(100):
    yy = xx[i].reshape([dim2, dim3, -1])
t1 = time.time()
for i in range(100):
    yyy = xx[i].reshape([dim2, dim3, -1], order='F')
t2 = time.time()

print(f'numpy C reshape: {(t1 - t0)/100} s')
print(f'numpy F reshape: {(t2 - t1)/100} s')

测试后果:

torch build-in reshape: 9.72747802734375e-07 s
torch permute reshape: 1.1897087097167968e-05 s
numpy C reshape: 3.0517578125e-07 s
numpy F reshape: 2.474784851074219e-06 s

测试中 pytorch 中基于 permute() 的办法的耗时是内置行优先 reshape() 函数的 10 倍,然而在 numpy 的测试中,列优先变换的耗时也是行优先的 10 倍。因而能够认为在 pytorch 中,基于 permute() 函数的变换计算效率很高,不须要持续优化。

参考文献

numpy 文档

stackoverflow 原问题

退出移动版