关于python:torchcat-速度太慢

40次阅读

共计 1114 个字符,预计需要花费 3 分钟才能阅读完成。

torch.cat 是 PyTorch 中用于连贯多个张量的函数。如果须要频繁地执行 torch.cat 操作,可能会影响程序的性能。以下是一些优化 torch.cat 速度的办法:

  1. 事后调配输入张量空间

当应用 torch.cat 连贯多个张量时,每次操作都会重新分配输入张量的空间,这会导致额定的内存调配和拷贝。如果已知输入张量的形态,能够在执行 torch.cat 操作之前先事后调配输入张量的空间,防止反复分配内存。

例如,假如要连贯三个形态为 (3, 64, 64) 的张量,能够先创立一个形态为 (9, 64, 64) 的输入张量,并将三个输出张量复制到输入张量的不同局部:

import torch

x1 = torch.randn(3, 64, 64)
x2 = torch.randn(3, 64, 64)
x3 = torch.randn(3, 64, 64)

out = torch.empty(9, 64, 64)
out[:3] = x1
out[3:6] = x2
out[6:] = x3

这样能够防止 torch.cat 操作中的反复内存调配和拷贝,进步程序性能。

  1. 应用 torch.stack 代替 torch.cat

torch.stack 是另一个用于连贯多个张量的函数,它与 torch.cat 相似,但会在新的维度上重叠输出张量。在一些状况下,应用 torch.stack 能够比 torch.cat 更快地连贯张量。

例如,假如要连贯三个形态为 (3, 64, 64) 的张量,能够应用 torch.stack 在新的维度上重叠三个张量,造成一个形态为 (3, 3, 64, 64) 的输入张量:

import torch

x1 = torch.randn(3, 64, 64)
x2 = torch.randn(3, 64, 64)
x3 = torch.randn(3, 64, 64)

out = torch.stack([x1, x2, x3])

须要留神的是,应用 torch.stack 可能会减少输入张量的维度,须要依据具体情况抉择适合的操作。

  1. 应用 GPU 减速

如果应用 GPU 进行张量操作,能够减速 torch.cat 操作的速度。能够应用 tensor.to(device) 将张量挪动到 GPU 上,并在操作完结后应用 tensor.to('cpu') 将张量移回 CPU。

例如,假如应用 GPU 进行张量操作:

import torch

x1 = torch.randn(3, 64, 64).cuda()
x2 = torch.randn(3, 64, 64).cuda()
x3 = torch.randn(3, 64, 64).cuda()

out = torch.cat([x1, x2, x3], dim=0)
out = out.to('cpu')

以上是一些优化 torch.cat 速度的办法,依据具体情况抉择适合的办法能够无效进步程序性能。

正文完
 0