Pytorch模型保留与提取

Pytorch模型保留与提取

本篇笔记次要对应于莫凡Pytorch中的3.4节。次要讲了如何应用Pytorch保留和提取咱们的神经网络。
在Pytorch中,网络的存储次要应用torch.save函数来实现。
咱们将通过两种形式展现模型的保留和提取。
第一种保留形式是保留整个模型,在从新提取时间接加载整个模型。第二种保留办法是只保留模型的参数,这种形式只保留了参数,而不会保留模型的构造等信息。
两种形式各有优缺点。

保留残缺模型不须要晓得网络的构造,一次性保留一次性读入。毛病是模型比拟大时耗时较长,保留的文件也大。
而只保留参数的形式存储快捷,保留的文件也小一些,但毛病是失落了网络的构造信息,复原模型时须要提前建设一个特定构造的网络再读入参数。
以下应用代码展现。

数据生成与展现

import torchimport torch.nn.functional as Fimport matplotlib.pyplot as plt

复制代码
这里还是生成一组带有噪声的y=x2y=x^{2}y=x2数据进行回归拟合。

# torch.manual_seed(1)    # reproducible# fake datax = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)

复制代码

根本网络搭建与保留

咱们应用nn.Sequential模块来疾速搭建一个网络实现回归操作,网络由两层Linear层和两头的激活层ReLU组成。咱们设置输入输出的维度为1,两头暗藏层变量的维度为10,以放慢训练。
这里应用两种形式进行保留。

def save():    # save net1    net1 = torch.nn.Sequential(        torch.nn.Linear(1, 10),        torch.nn.ReLU(),        torch.nn.Linear(10, 1)    )    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)    loss_func = torch.nn.MSELoss()        for step in range(100):        prediction = net1(x)        loss = loss_func(prediction, y)        optimizer.zero_grad()        loss.backward()        optimizer.step()            # plot result    plt.figure(1, figsize=(10, 3))    plt.subplot(131)    plt.title('Net1')    plt.scatter(x.data.numpy(), y.data.numpy())    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)    plt.savefig("./img/05_save.png")            torch.save(net1, 'net.pkl')                        # entire network    torch.save(net1.state_dict(), 'net_params.pkl')    # parameters

复制代码
在这个save函数中,咱们首先应用nn.Sequential模块构建了一个根底的两层神经网络。而后对其进行训练,展现训练后果。之后应用两种形式进行保留。
第一种形式间接保留整个网络,代码为

torch.save(net1, 'net.pkl')                        # entire network复制代码第二种形式只保留网络参数,代码为torch.save(net1.state_dict(), 'net_params.pkl')    # parameters

复制代码

对保留的模型进行提取复原

这里咱们为两种不同存储形式保留的模型别离定义复原提取的函数
首先是对整个网络的提取。间接应用torch.load就能够,无需其余额定操作。

def restore_net():    # 提取神经网络    net2 = torch.load('net.pkl')    prediction = net2(x)        # plot result    plt.subplot(132)    plt.title('Net2')    plt.scatter(x.data.numpy(), y.data.numpy())    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)    plt.savefig("./img/05_res_net.png")

复制代码
而对于参数的读取,咱们首先须要先搭建好一个与之前保留的模型雷同架构的网络,而后应用这个网络的load_state_dict办法进行参数读取和复原。以下展现了应用参数形式读取网络的示例:

def restore_params():    # 提取神经网络    net3 = torch.nn.Sequential(        torch.nn.Linear(1, 10),        torch.nn.ReLU(),        torch.nn.Linear(10, 1)    )    net3.load_state_dict(torch.load('net_params.pkl'))    prediction = net3(x)        # plot result    plt.subplot(133)    plt.title('Net3')    plt.scatter(x.data.numpy(), y.data.numpy())    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)    plt.savefig("./img/05_res_para.png")    plt.show()

复制代码

比照不同提取办法的成果

接下来咱们比照一下这两种办法的提取成果

# save net1save()# restore entire net (may slow)restore_net()# restore only the net parametersrestore_params()