关于pytorch:Pytorch模型保存与提取

44次阅读

共计 2511 个字符,预计需要花费 7 分钟才能阅读完成。

Pytorch 模型保留与提取

Pytorch 模型保留与提取

本篇笔记次要对应于莫凡 Pytorch 中的 3.4 节。次要讲了如何应用 Pytorch 保留和提取咱们的神经网络。
在 Pytorch 中,网络的存储次要应用 torch.save 函数来实现。
咱们将通过两种形式展现模型的保留和提取。
第一种保留形式是保留整个模型,在从新提取时间接加载整个模型。第二种保留办法是只保留模型的参数,这种形式只保留了参数,而不会保留模型的构造等信息。
两种形式各有优缺点。

保留残缺模型不须要晓得网络的构造,一次性保留一次性读入。毛病是模型比拟大时耗时较长,保留的文件也大。
而只保留参数的形式存储快捷,保留的文件也小一些,但毛病是失落了网络的构造信息,复原模型时须要提前建设一个特定构造的网络再读入参数。
以下应用代码展现。

数据生成与展现

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

复制代码
这里还是生成一组带有噪声的 y =x2y=x^{2}y=x2 数据进行回归拟合。

# torch.manual_seed(1)    # reproducible

# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)

复制代码

根本网络搭建与保留

咱们应用 nn.Sequential 模块来疾速搭建一个网络实现回归操作,网络由两层 Linear 层和两头的激活层 ReLU 组成。咱们设置输入输出的维度为 1,两头暗藏层变量的维度为 10,以放慢训练。
这里应用两种形式进行保留。

def save():
    # save net1
    net1 = torch.nn.Sequential(torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()
    
    for step in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.savefig("./img/05_save.png")
        
    torch.save(net1, 'net.pkl')                        # entire network
    torch.save(net1.state_dict(), 'net_params.pkl')    # parameters

复制代码
在这个 save 函数中,咱们首先应用 nn.Sequential 模块构建了一个根底的两层神经网络。而后对其进行训练,展现训练后果。之后应用两种形式进行保留。
第一种形式间接保留整个网络,代码为

torch.save(net1, 'net.pkl')                        # entire network
复制代码
第二种形式只保留网络参数,代码为
torch.save(net1.state_dict(), 'net_params.pkl')    # parameters

复制代码

对保留的模型进行提取复原

这里咱们为两种不同存储形式保留的模型别离定义复原提取的函数
首先是对整个网络的提取。间接应用 torch.load 就能够,无需其余额定操作。

def restore_net():
    # 提取神经网络
    net2 = torch.load('net.pkl')
    prediction = net2(x)
    
    # plot result
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.savefig("./img/05_res_net.png")

复制代码
而对于参数的读取,咱们首先须要先搭建好一个与之前保留的模型雷同架构的网络,而后应用这个网络的 load_state_dict 办法进行参数读取和复原。以下展现了应用参数形式读取网络的示例:

def restore_params():
    # 提取神经网络
    net3 = torch.nn.Sequential(torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)
    
    # plot result
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.savefig("./img/05_res_para.png")
    plt.show()

复制代码

比照不同提取办法的成果

接下来咱们比照一下这两种办法的提取成果

# save net1
save()

# restore entire net (may slow)
restore_net()

# restore only the net parameters
restore_params()

正文完
 0