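Background
NumPy's reshape function takes an order parameter. With the default order='C', elements are read and written with the last axis index changing fastest (row-major order). With order='F', the first axis index changes fastest (column-major order). The official documentation gives this example: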
>>> a = np.arange(6).reshape((3, 2))
>>> np.reshape(a, (2, 3)) # C-like index ordering
array([[0, 1, 2],
       [3, 4, 5]])
>>> np.reshape(np.ravel(a), (2, 3)) # equivalent to C ravel then C reshape
array([[0, 1, 2],
       [3, 4, 5]])
>>> np.reshape(a, (2, 3), order='F') # Fortran-like index ordering
array([[0, 4, 3],
       [2, 1, 5]])
>>> np.reshape(np.ravel(a, order='F'), (2, 3), order='F')
array([[0, 4, 3],
       [2, 1, 5]])
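PyTorch solution
In PyTorch, torch.reshape() accepts only a tensor and a shape, and always uses row-major (C-style) ordering. To get a column-major reshape, you need to combine it with permute(). A Stack Overflow answer gives this solution: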
def reshape_fortran(x, shape):
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))
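The trick rests on an identity: a column-major reshape equals reversing the axes, doing a row-major reshape to the reversed target shape, and reversing the axes again. The sketch below checks this in NumPy only, where order='F' is available as a reference. The helper name reshape_fortran_np is hypothetical and introduced purely for illustration.

```python
import numpy as np


def reshape_fortran_np(x, shape):
    # Hypothetical NumPy analogue of the permute-based trick:
    # reverse axes, C-order reshape to the reversed shape, reverse axes again.
    return x.T.reshape(tuple(reversed(shape))).T


# Small case taken from the NumPy documentation example.
a = np.arange(6).reshape(3, 2)
out_small = reshape_fortran_np(a, (2, 3))
ref_small = np.reshape(a, (2, 3), order='F')

# Larger case using the shapes from the benchmark in this article.
b = np.arange(40 * 50).reshape(40, 50)
out_large = reshape_fortran_np(b, (50, 5, -1))
ref_large = b.reshape((50, 5, -1), order='F')

print(np.array_equal(out_small, ref_small))
print(np.array_equal(out_large, ref_large))
```

Both comparisons should agree with NumPy's built-in order='F' result, including when the target shape contains -1.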
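Performance test
However, the author of that answer suspected that permute() still creates a copy of the tensor internally, which would hurt efficiency. I therefore benchmarked this approach and compared it with NumPy's built-in function. The test environment was an i9-10900X with an RTX 2080 Ti.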
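Where a copy actually happens can be checked directly. The sketch below is a small CPU-only probe, assuming a recent PyTorch release: it compares data_ptr() values to see whether a result still points at the original storage. As documented, permute() returns a view, while reshape() returns a view only when the strides allow it and otherwise copies.

```python
import torch

x = torch.arange(6).reshape(3, 2)

# permute() only changes shape/stride metadata, so the storage is shared.
p = x.permute(1, 0)
shares_after_permute = p.data_ptr() == x.data_ptr()
permuted_is_contiguous = p.is_contiguous()

# Reshaping the permuted (non-contiguous) tensor to (3, 2) cannot be
# expressed with strides alone, so reshape() has to materialize a copy.
r = p.reshape(3, 2)
shares_after_reshape = r.data_ptr() == x.data_ptr()

print(shares_after_permute, permuted_is_contiguous, shares_after_reshape)
```

So in the permute-based approach the data movement comes from the reshape() call on the permuted tensor, not from permute() itself.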
Test code:
import time

import numpy as np
import torch

dim1 = 40
dim2 = 50
dim3 = 5


def reshape_fortran(x, shape):
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))


torch.cuda.set_device(0)
device = torch.device('cuda')

x = [torch.from_numpy(np.random.rand(dim1, dim2)).to(device) for _ in range(100)]
xx = [torch.from_numpy(np.random.rand(dim1, dim2)).to(device) for _ in range(100)]

# warm-up on a separate set of tensors (not timed)
for i in range(100):
    y = x[i].reshape([dim2, dim1])

# c reshape
t0 = time.time()
for i in range(100):
    y = xx[i].reshape([dim2, dim3, -1])
t1 = time.time()

# fortran reshape (t1 doubles as the start time of this loop)
for i in range(100):
    yy = reshape_fortran(xx[i], [dim2, dim3, -1])
t2 = time.time()

print(f'torch build-in reshape: {(t1 - t0)/100} s')
print(f'torch permute reshape: {(t2 - t1)/100} s')

# same comparison with NumPy arrays on the CPU
x = [np.random.rand(dim1, dim2) for _ in range(100)]
xx = [np.random.rand(dim1, dim2) for _ in range(100)]

# warm-up (not timed)
for i in range(100):
    y = x[i].reshape([dim2, dim3, -1])

t0 = time.time()
for i in range(100):
    yy = xx[i].reshape([dim2, dim3, -1])
t1 = time.time()

for i in range(100):
    yyy = xx[i].reshape([dim2, dim3, -1], order='F')
t2 = time.time()

print(f'numpy C reshape: {(t1 - t0)/100} s')
print(f'numpy F reshape: {(t2 - t1)/100} s')
Test results:
torch build-in reshape: 9.72747802734375e-07 s
torch permute reshape: 1.1897087097167968e-05 s
numpy C reshape: 3.0517578125e-07 s
numpy F reshape: 2.474784851074219e-06 s
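One caveat when timing GPU code with wall-clock timers is that CUDA kernels are launched asynchronously, so a loop can return before the device has finished. The sketch below is one possible way to account for that, not the method used for the numbers above: the helper time_op is hypothetical, it calls torch.cuda.synchronize() around the timed loop when the tensor lives on the GPU, and it falls back to the CPU when no GPU is available.

```python
import time

import torch


def reshape_fortran(x, shape):
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))


def time_op(fn, x, repeats=100):
    """Average seconds per call of fn(x); hypothetical helper for illustration."""
    fn(x)  # warm-up call, not timed
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for pending GPU work before starting
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for the timed GPU work to finish
    return (time.perf_counter() - start) / repeats


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.rand(40, 50, device=device)

t_c = time_op(lambda t: t.reshape(50, 5, -1), x)
t_f = time_op(lambda t: reshape_fortran(t, [50, 5, -1]), x)

print(f'C reshape: {t_c} s')
print(f'permute reshape: {t_f} s')
```

Absolute numbers will differ by machine and PyTorch version, so only the relative cost of the two variants is meaningful.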
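In this test, the permute()-based method in PyTorch took roughly 10 times as long as the built-in row-major reshape() (about 12x by the numbers above). In the NumPy test, however, the column-major reshape also took roughly 10 times as long as the row-major one (about 8x). Since NumPy's built-in implementation pays a similar relative cost, the permute()-based reshape in PyTorch can be considered efficient enough, with no need for further optimization.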
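The slowdown factors can be recomputed from the reported timings. The snippet below simply divides the per-call times listed in the test results; the variable names are introduced here for illustration.

```python
# Per-call timings in seconds, copied from the test results above.
torch_c = 9.72747802734375e-07
torch_f = 1.1897087097167968e-05
numpy_c = 3.0517578125e-07
numpy_f = 2.474784851074219e-06

torch_ratio = torch_f / torch_c  # permute-based reshape vs. built-in reshape
numpy_ratio = numpy_f / numpy_c  # order='F' vs. order='C'

print(f'torch slowdown: {torch_ratio:.1f}x')  # 12.2x
print(f'numpy slowdown: {numpy_ratio:.1f}x')  # 8.1x
```

Both ratios sit around one order of magnitude, which is the basis for the comparison above.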
References
NumPy documentation
Original Stack Overflow question