乐趣区

关于人工智能:Pytorch中获取模型摘要的3种方法

在 pytorch 中获取模型的可训练和不可训练的参数,层名称,内核大小和数量。

Pytorch nn.Module 类中没有提供像与 Keras 那样的能够计算模型中可训练和不可训练的参数的数量并显示模型摘要的办法。所以在这篇文章中,我将总结我晓得三种办法来计算 Pytorch 模型中可训练和不可训练的参数的数量。

间接手写代码

最间接的方法就是咱们本人手写代码代码实现这个性能,所以这里我本人实现了一个函数,函数中为了丑陋所以引入了 PrettyTable 的包

 from prettytable import PrettyTable
 
 def count_parameters(model):
     table = PrettyTable([“Modules”,“Parameters”])
     total_params = 0
     for name, parameter in model.named_parameters():
         if not parameter.requires_grad: continue
         params = parameter.numel()
         table.add_row([name, params])
         total_params+=params
     print(table)
     print(f”Total Trainable Params: {total_params}”)
     return total_params

咱们拿 RESNET18 为例,以上函数的输入如下:

 +------------------------------+------------+ 
 |           Modules            | Parameters | 
 +------------------------------+------------+ 
 |         conv1.weight         |    9408    | 
 |          bn1.weight          |     64     | 
 |           bn1.bias           |     64     | 
 |    layer1.0.conv1.weight     |   36864    | 
 |     layer1.0.bn1.weight      |     64     | 
 |      layer1.0.bn1.bias       |     64     |
 .
 .
 .
 |          fc.weight           |   512000   | 
 |           fc.bias            |    1000    | 
 +------------------------------+------------+ 
 Total Trainable Params: 11689512

输入以参数为单位,能够看到模型中存在的每个参数的可训练参数,是不是和 keras 的根本一样。

torchsummary

torchsummary 呈现的时候的指标就是为了让 torch 有相似 keras 一样的打印模型参数的性能,它十分敌对并且非常简略。以后版本为 1.5.1,能够间接应用 pip 装置:

 pip install torchsummary

装置实现后即可应用,咱们还是以 resnet18 为例

 from torchsummary import summary
 model = torchvision.models.resnet18().cuda()

在应用时,咱们须要生成一个模型的输出变量,也就是模仿模型的前向流传的过程:

 summary(model, input_size = (3, 64, 64), batch_size = -1)

后果如下:

 — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — 
 Layer (type)               Output Shape                  Param # ================================================================ 
 Conv2d-1               [-1, 64, 112, 112]                  9,408 
 BatchNorm2d-2          [-1, 64, 112, 112]                    128 
 ReLU-3                 [-1, 64, 112, 112]                      0 
 MaxPool2d-4              [-1, 64, 56, 56]                      0 
 Conv2d-5                 [-1, 64, 56, 56]                 36,864
 .
 .
 .
 AdaptiveAvgPool2d-67      [-1, 512, 1, 1]                      0
 Linear-68                      [-1, 1000]                513,000 ================================================================
 Total params: 11,689,512 
 Trainable params: 11,689,512 
 Non-trainable params: 0 
 ----------------------------------------------------------------
 Input size (MB): 0.57 
 Forward/backward pass size (MB): 62.79 
 Params size (MB): 44.59 
 Estimated Total Size (MB): 107.96 
 ----------------------------------------------------------------

当初,如果你的根本模型有多个分支,每个分支都有不同的输出,例如

 class Model(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.resnet1 = torchvision.models.resnet18().cuda()
         self.resnet2 = torchvision.models.resnet18().cuda()
         self.resnet3 = torchvision.models.resnet18().cuda()
     
     def forward(self, *x):
         out1 = self.resnet1(x[0])
         out2 = self.resnet2(x[1])
         out3 = self.resnet3(x[2])
         out = torch.cat([out1, out2, out3], dim = 0)
         return out

那么就须要这样:

 summary(Model().cuda(), input_size = [(3, 64, 64)]*3)

该输入将与前一个类似,但会有点凌乱,因为 torchsummary 将每个组成的 ResNet 模块的信息压缩到一个摘要中,而在两个间断模块的摘要之间没有任何适当的可辨别边界。

torchinfo

它看起来可能与 torchsummary 相似。但在我看来,它是我找到这三种办法中最好的。torchinfo 以后版本是 1.7.0,还是能够应用 pip 装置:

 pip install torchinfo

这个包也有一个名为 summary 的函数。但它有更多的参数。他的应用参数为 model (nn.module)、input_size (Sequence of Sizes)、input_data (Sequence of Tensors)、batch_dim (int)、cache_forward_pass (bool)、col_names (Iterable[str])、col_width (int)、depth (int)、device (torch.Device)、dtypes (List[torch.dtype])、mode (str)、row_settings (Iterable[str])、verbose (int)和 **kwargs。

参数很多,然而能够间接通过 (” input_size “,” output_size “,” num_params “,” kernel_size “,” mult_add “,” trainable “) 作为 col_names 参数来获取信息。

 import torchinfo
 torchinfo.summary(model, (3, 224, 224), batch_dim = 0, col_names = (“input_size”,“output_size”,“num_params”,“kernel_size”,“mult_adds”), verbose = 0)

须要阐明的是,如果不应用 Jupyter 或 Google Colab,须要将verbose 更改为 1。

上述代码段的输入看起来像这样

 =============================================================================================
 Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
 =============================================================================================
 ResNet                                   [1, 3, 224, 224]          [1, 1000]                 --                        --                        --
 ├─Conv2d: 1-1                            [1, 3, 224, 224]          [1, 64, 112, 112]         9,408                     [7, 7]                    118,013,952
 ├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         [1, 64, 112, 112]         128                       --                        128
 ├─ReLU: 1-3                              [1, 64, 112, 112]         [1, 64, 112, 112]         --                        --                        --
 ├─MaxPool2d: 1-4                         [1, 64, 112, 112]         [1, 64, 56, 56]           --                        3                         --
 ├─Sequential: 1-5                        [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    └─BasicBlock: 2-1                   [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504
 │    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128
 │    │    └─ReLU: 3-3                    [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504
 │    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128
 │    │    └─ReLU: 3-6                    [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    └─BasicBlock: 2-2                   [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504
 │    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128
 │    │    └─ReLU: 3-9                    [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 │    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504
 │    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128
 │    │    └─ReLU: 3-12                   [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        --
 ├─Sequential: 1-6                        [1, 64, 56, 56]           [1, 128, 28, 28]          --                        --                        --
 │    └─BasicBlock: 2-3                   [1, 64, 56, 56]           [1, 128, 28, 28]          --                        --                        --
 │    │    └─Conv2d: 3-13                 [1, 64, 56, 56]           [1, 128, 28, 28]          73,728                    [3, 3]                    57,802,752
 │    │    └─BatchNorm2d: 3-14            [1, 128, 28, 28]          [1, 128, 28, 28]          256                       --                        256
 .
 .
 .
 │    │    └─Conv2d: 3-49                 [1, 512, 7, 7]            [1, 512, 7, 7]            2,359,296                 [3, 3]                    115,605,504
 │    │    └─BatchNorm2d: 3-50            [1, 512, 7, 7]            [1, 512, 7, 7]            1,024                     --                        1,024
 │    │    └─ReLU: 3-51                   [1, 512, 7, 7]            [1, 512, 7, 7]            --                        --                        --
 ├─AdaptiveAvgPool2d: 1-9                 [1, 512, 7, 7]            [1, 512, 1, 1]            --                        --                        --
 ├─Linear: 1-10                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000
 =============================================================================================
 Total params: 11,689,512
 Trainable params: 11,689,512
 Non-trainable params: 0
 Total mult-adds (G): 1.81
 =============================================================================================
 Input size (MB): 0.60
 Forward/backward pass size (MB): 39.75
 Params size (MB): 46.76
 Estimated Total Size (MB): 87.11
 =============================================================================================

再持续查看多分支模型

 torchinfo.summary(Model().cuda(), [(3, 64, 64)]*3, batch_dim = 0, col_names = (“input_size”,“output_size”,“num_params”,“kernel_size”,“mult_adds”), verbose = 0)

产生以下输入

 =============================================================================================
 Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
 =============================================================================================
 Model                                         [1, 3, 64, 64]            [1, 1000]                 --                        --                        --
 ├─ResNet: 1-1                                 [1, 3, 64, 64]            [1, 1000]                 --                        --                        --
 │    └─Conv2d: 2-1                            [1, 3, 64, 64]            [1, 64, 32, 32]           9,408                     [7, 7]                    9,633,792
 │    └─BatchNorm2d: 2-2                       [1, 64, 32, 32]           [1, 64, 32, 32]           128                       --                        128
 │    └─ReLU: 2-3                              [1, 64, 32, 32]           [1, 64, 32, 32]           --                        --                        --
 │    └─MaxPool2d: 2-4                         [1, 64, 32, 32]           [1, 64, 16, 16]           --                        3                         --
 │    └─Sequential: 2-5                        [1, 64, 16, 16]           [1, 64, 16, 16]           --                        --                        --
 │    │    └─BasicBlock: 3-1                   [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    │    └─BasicBlock: 3-2                   [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    └─Sequential: 2-6                        [1, 64, 16, 16]           [1, 128, 8, 8]            --                        --                        --
 │    │    └─BasicBlock: 3-3                   [1, 64, 16, 16]           [1, 128, 8, 8]            230,144                   --                        14,680,832
 │    │    └─BasicBlock: 3-4                   [1, 128, 8, 8]            [1, 128, 8, 8]            295,424                   --                        18,874,880
 │    └─Sequential: 2-7                        [1, 128, 8, 8]            [1, 256, 4, 4]            --                        --                        --
 │    │    └─BasicBlock: 3-5                   [1, 128, 8, 8]            [1, 256, 4, 4]            919,040                   --                        14,681,600
 │    │    └─BasicBlock: 3-6                   [1, 256, 4, 4]            [1, 256, 4, 4]            1,180,672                 --                        18,875,392
 │    └─Sequential: 2-8                        [1, 256, 4, 4]            [1, 512, 2, 2]            --                        --                        --
 │    │    └─BasicBlock: 3-7                   [1, 256, 4, 4]            [1, 512, 2, 2]            3,673,088                 --                        14,683,136
 │    │    └─BasicBlock: 3-8                   [1, 512, 2, 2]            [1, 512, 2, 2]            4,720,640                 --                        18,876,416
 │    └─AdaptiveAvgPool2d: 2-9                 [1, 512, 2, 2]            [1, 512, 1, 1]            --                        --                        --
 │    └─Linear: 2-10                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000
 ├─ResNet: 1-2                                 [1, 3, 64, 64]            [1, 1000]                 --                        --                        --
 │    └─Conv2d: 2-11                           [1, 3, 64, 64]            [1, 64, 32, 32]           9,408                     [7, 7]                    9,633,792
 │    └─BatchNorm2d: 2-12                      [1, 64, 32, 32]           [1, 64, 32, 32]           128                       --                        128
 │    └─ReLU: 2-13                             [1, 64, 32, 32]           [1, 64, 32, 32]           --                        --                        --
 │    └─MaxPool2d: 2-14                        [1, 64, 32, 32]           [1, 64, 16, 16]           --                        3                         --
 │    └─Sequential: 2-15                       [1, 64, 16, 16]           [1, 64, 16, 16]           --                        --                        --
 │    │    └─BasicBlock: 3-9                   [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    │    └─BasicBlock: 3-10                  [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    └─Sequential: 2-16                       [1, 64, 16, 16]           [1, 128, 8, 8]            --                        --                        --
 │    │    └─BasicBlock: 3-11                  [1, 64, 16, 16]           [1, 128, 8, 8]            230,144                   --                        14,680,832
 │    │    └─BasicBlock: 3-12                  [1, 128, 8, 8]            [1, 128, 8, 8]            295,424                   --                        18,874,880
 │    └─Sequential: 2-17                       [1, 128, 8, 8]            [1, 256, 4, 4]            --                        --                        --
 │    │    └─BasicBlock: 3-13                  [1, 128, 8, 8]            [1, 256, 4, 4]            919,040                   --                        14,681,600
 │    │    └─BasicBlock: 3-14                  [1, 256, 4, 4]            [1, 256, 4, 4]            1,180,672                 --                        18,875,392
 │    └─Sequential: 2-18                       [1, 256, 4, 4]            [1, 512, 2, 2]            --                        --                        --
 │    │    └─BasicBlock: 3-15                  [1, 256, 4, 4]            [1, 512, 2, 2]            3,673,088                 --                        14,683,136
 │    │    └─BasicBlock: 3-16                  [1, 512, 2, 2]            [1, 512, 2, 2]            4,720,640                 --                        18,876,416
 │    └─AdaptiveAvgPool2d: 2-19                [1, 512, 2, 2]            [1, 512, 1, 1]            --                        --                        --
 │    └─Linear: 2-20                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000
 ├─ResNet: 1-3                                 [1, 3, 64, 64]            [1, 1000]                 --                        --                        --
 │    └─Conv2d: 2-21                           [1, 3, 64, 64]            [1, 64, 32, 32]           9,408                     [7, 7]                    9,633,792
 │    └─BatchNorm2d: 2-22                      [1, 64, 32, 32]           [1, 64, 32, 32]           128                       --                        128
 │    └─ReLU: 2-23                             [1, 64, 32, 32]           [1, 64, 32, 32]           --                        --                        --
 │    └─MaxPool2d: 2-24                        [1, 64, 32, 32]           [1, 64, 16, 16]           --                        3                         --
 │    └─Sequential: 2-25                       [1, 64, 16, 16]           [1, 64, 16, 16]           --                        --                        --
 │    │    └─BasicBlock: 3-17                  [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    │    └─BasicBlock: 3-18                  [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624
 │    └─Sequential: 2-26                       [1, 64, 16, 16]           [1, 128, 8, 8]            --                        --                        --
 │    │    └─BasicBlock: 3-19                  [1, 64, 16, 16]           [1, 128, 8, 8]            230,144                   --                        14,680,832
 │    │    └─BasicBlock: 3-20                  [1, 128, 8, 8]            [1, 128, 8, 8]            295,424                   --                        18,874,880
 │    └─Sequential: 2-27                       [1, 128, 8, 8]            [1, 256, 4, 4]            --                        --                        --
 │    │    └─BasicBlock: 3-21                  [1, 128, 8, 8]            [1, 256, 4, 4]            919,040                   --                        14,681,600
 │    │    └─BasicBlock: 3-22                  [1, 256, 4, 4]            [1, 256, 4, 4]            1,180,672                 --                        18,875,392
 │    └─Sequential: 2-28                       [1, 256, 4, 4]            [1, 512, 2, 2]            --                        --                        --
 │    │    └─BasicBlock: 3-23                  [1, 256, 4, 4]            [1, 512, 2, 2]            3,673,088                 --                        14,683,136
 │    │    └─BasicBlock: 3-24                  [1, 512, 2, 2]            [1, 512, 2, 2]            4,720,640                 --                        18,876,416
 │    └─AdaptiveAvgPool2d: 2-29                [1, 512, 2, 2]            [1, 512, 1, 1]            --                        --                        --
 │    └─Linear: 2-30                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000
 =============================================================================================
 Total params: 35,068,536
 Trainable params: 35,068,536
 Non-trainable params: 0
 Total mult-adds (M): 445.71
 =============================================================================================
 Input size (MB): 0.15
 Forward/backward pass size (MB): 9.76
 Params size (MB): 140.27
 Estimated Total Size (MB): 150.18
 =============================================================================================

能够看到depth 参数的默认值为 3。并且在可视化方向上,多分支被从新进行了组织并且以层次结构形式出现,所以很容易辨别,所以他的成果要比 torchsummary 好很多。

https://avoid.overfit.cn/post/bfed756d1d5147a89f6d8d911b6d29dd

作者:Siladittya Manna

退出移动版