Pytorch模型保留与提取
Pytorch模型保留与提取
本篇笔记次要对应于莫凡Pytorch中的3.4节。次要讲了如何应用Pytorch保留和提取咱们的神经网络。
在Pytorch中,网络的存储次要应用torch.save函数来实现。
咱们将通过两种形式展现模型的保留和提取。
第一种保留形式是保留整个模型,在从新提取时间接加载整个模型。第二种保留办法是只保留模型的参数,这种形式只保留了参数,而不会保留模型的构造等信息。
两种形式各有优缺点。
保留残缺模型不须要晓得网络的构造,一次性保留一次性读入。毛病是模型比拟大时耗时较长,保留的文件也大。
而只保留参数的形式存储快捷,保留的文件也小一些,但毛病是失落了网络的构造信息,复原模型时须要提前建设一个特定构造的网络再读入参数。
以下应用代码展现。
数据生成与展现
import torchimport torch.nn.functional as Fimport matplotlib.pyplot as plt
复制代码
这里还是生成一组带有噪声的y=x2y=x^{2}y=x2数据进行回归拟合。
# torch.manual_seed(1) # reproducible# fake datax = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape=(100, 1)y = x.pow(2) + 0.2*torch.rand(x.size()) # noisy y data (tensor), shape=(100, 1)
复制代码
根本网络搭建与保留
咱们应用nn.Sequential模块来疾速搭建一个网络实现回归操作,网络由两层Linear层和两头的激活层ReLU组成。咱们设置输入输出的维度为1,两头暗藏层变量的维度为10,以放慢训练。
这里应用两种形式进行保留。
def save(): # save net1 net1 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) ) optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) loss_func = torch.nn.MSELoss() for step in range(100): prediction = net1(x) loss = loss_func(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step() # plot result plt.figure(1, figsize=(10, 3)) plt.subplot(131) plt.title('Net1') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.savefig("./img/05_save.png") torch.save(net1, 'net.pkl') # entire network torch.save(net1.state_dict(), 'net_params.pkl') # parameters
复制代码
在这个save函数中,咱们首先应用nn.Sequential模块构建了一个根底的两层神经网络。而后对其进行训练,展现训练后果。之后应用两种形式进行保留。
第一种形式间接保留整个网络,代码为
torch.save(net1, 'net.pkl') # entire network复制代码第二种形式只保留网络参数,代码为torch.save(net1.state_dict(), 'net_params.pkl') # parameters
复制代码
对保留的模型进行提取复原
这里咱们为两种不同存储形式保留的模型别离定义复原提取的函数
首先是对整个网络的提取。间接应用torch.load就能够,无需其余额定操作。
def restore_net(): # 提取神经网络 net2 = torch.load('net.pkl') prediction = net2(x) # plot result plt.subplot(132) plt.title('Net2') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.savefig("./img/05_res_net.png")
复制代码
而对于参数的读取,咱们首先须要先搭建好一个与之前保留的模型雷同架构的网络,而后应用这个网络的load_state_dict办法进行参数读取和复原。以下展现了应用参数形式读取网络的示例:
def restore_params(): # 提取神经网络 net3 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1) ) net3.load_state_dict(torch.load('net_params.pkl')) prediction = net3(x) # plot result plt.subplot(133) plt.title('Net3') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.savefig("./img/05_res_para.png") plt.show()
复制代码
比照不同提取办法的成果
接下来咱们比照一下这两种办法的提取成果
# save net1save()# restore entire net (may slow)restore_net()# restore only the net parametersrestore_params()