共计 2511 个字符,预计需要花费 7 分钟才能阅读完成。
Pytorch 模型保留与提取
Pytorch 模型保留与提取
本篇笔记次要对应于莫凡 Pytorch 中的 3.4 节。次要讲了如何应用 Pytorch 保留和提取咱们的神经网络。
在 Pytorch 中,网络的存储次要应用 torch.save 函数来实现。
咱们将通过两种形式展现模型的保留和提取。
第一种保留形式是保留整个模型,在从新提取时间接加载整个模型。第二种保留办法是只保留模型的参数,这种形式只保留了参数,而不会保留模型的构造等信息。
两种形式各有优缺点。
保留残缺模型不须要晓得网络的构造,一次性保留一次性读入。毛病是模型比拟大时耗时较长,保留的文件也大。
而只保留参数的形式存储快捷,保留的文件也小一些,但毛病是失落了网络的构造信息,复原模型时须要提前建设一个特定构造的网络再读入参数。
以下应用代码展现。
数据生成与展现
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
复制代码
这里还是生成一组带有噪声的 y =x2y=x^{2}y=x2 数据进行回归拟合。
# torch.manual_seed(1) # reproducible
# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
复制代码
根本网络搭建与保留
咱们应用 nn.Sequential 模块来疾速搭建一个网络实现回归操作,网络由两层 Linear 层和两头的激活层 ReLU 组成。咱们设置输入输出的维度为 1,两头暗藏层变量的维度为 10,以放慢训练。
这里应用两种形式进行保留。
def save():
# save net1
net1 = torch.nn.Sequential(torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
for step in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.savefig("./img/05_save.png")
torch.save(net1, 'net.pkl') # entire network
torch.save(net1.state_dict(), 'net_params.pkl') # parameters
复制代码
在这个 save 函数中,咱们首先应用 nn.Sequential 模块构建了一个根底的两层神经网络。而后对其进行训练,展现训练后果。之后应用两种形式进行保留。
第一种形式间接保留整个网络,代码为
torch.save(net1, 'net.pkl') # entire network
复制代码
第二种形式只保留网络参数,代码为
torch.save(net1.state_dict(), 'net_params.pkl') # parameters
复制代码
对保留的模型进行提取复原
这里咱们为两种不同存储形式保留的模型别离定义复原提取的函数
首先是对整个网络的提取。间接应用 torch.load 就能够,无需其余额定操作。
def restore_net():
# 提取神经网络
net2 = torch.load('net.pkl')
prediction = net2(x)
# plot result
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.savefig("./img/05_res_net.png")
复制代码
而对于参数的读取,咱们首先须要先搭建好一个与之前保留的模型雷同架构的网络,而后应用这个网络的 load_state_dict 办法进行参数读取和复原。以下展现了应用参数形式读取网络的示例:
def restore_params():
# 提取神经网络
net3 = torch.nn.Sequential(torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
# plot result
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.savefig("./img/05_res_para.png")
plt.show()
复制代码
比照不同提取办法的成果
接下来咱们比照一下这两种办法的提取成果
# save net1
save()
# restore entire net (may slow)
restore_net()
# restore only the net parameters
restore_params()