在pytorch中获取模型的可训练和不可训练的参数,层名称,内核大小和数量。

Pytorch nn.Module 类中没有提供像与Keras那样的能够计算模型中可训练和不可训练的参数的数量并显示模型摘要的办法 。所以在这篇文章中,我将总结我晓得三种办法来计算Pytorch模型中可训练和不可训练的参数的数量。

间接手写代码

最间接的方法就是咱们本人手写代码代码实现这个性能,所以这里我本人实现了一个函数,函数中为了丑陋所以引入了PrettyTable的包

 from prettytable import PrettyTable  def count_parameters(model):     table = PrettyTable([“Modules”, “Parameters”])     total_params = 0     for name, parameter in model.named_parameters():         if not parameter.requires_grad: continue         params = parameter.numel()         table.add_row([name, params])         total_params+=params     print(table)     print(f”Total Trainable Params: {total_params}”)     return total_params

咱们拿RESNET18为例,以上函数的输入如下:

 +------------------------------+------------+  |           Modules            | Parameters |  +------------------------------+------------+  |         conv1.weight         |    9408    |  |          bn1.weight          |     64     |  |           bn1.bias           |     64     |  |    layer1.0.conv1.weight     |   36864    |  |     layer1.0.bn1.weight      |     64     |  |      layer1.0.bn1.bias       |     64     | . . . |          fc.weight           |   512000   |  |           fc.bias            |    1000    |  +------------------------------+------------+  Total Trainable Params: 11689512

输入以参数为单位,能够看到模型中存在的每个参数的可训练参数,是不是和keras的根本一样。

torchsummary

torchsummary呈现的时候的指标就是为了让torch有相似keras一样的打印模型参数的性能,它十分敌对并且非常简略。以后版本为1.5.1,能够间接应用pip装置:

 pip install torchsummary

装置实现后即可应用,咱们还是以resnet18为例

 from torchsummary import summary model = torchvision.models.resnet18().cuda()

在应用时,咱们须要生成一个模型的输出变量,也就是模仿模型的前向流传的过程:

 summary(model, input_size = (3, 64, 64), batch_size = -1)

后果如下:

 — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — — —  Layer (type)               Output Shape                  Param # ================================================================  Conv2d-1               [-1, 64, 112, 112]                  9,408  BatchNorm2d-2          [-1, 64, 112, 112]                    128  ReLU-3                 [-1, 64, 112, 112]                      0  MaxPool2d-4              [-1, 64, 56, 56]                      0  Conv2d-5                 [-1, 64, 56, 56]                 36,864 . . . AdaptiveAvgPool2d-67      [-1, 512, 1, 1]                      0 Linear-68                      [-1, 1000]                513,000 ================================================================ Total params: 11,689,512  Trainable params: 11,689,512  Non-trainable params: 0  ---------------------------------------------------------------- Input size (MB): 0.57  Forward/backward pass size (MB): 62.79  Params size (MB): 44.59  Estimated Total Size (MB): 107.96  ----------------------------------------------------------------

当初,如果你的根本模型有多个分支,每个分支都有不同的输出,例如

 class Model(torch.nn.Module):     def __init__(self):         super().__init__()         self.resnet1 = torchvision.models.resnet18().cuda()         self.resnet2 = torchvision.models.resnet18().cuda()         self.resnet3 = torchvision.models.resnet18().cuda()          def forward(self, *x):         out1 = self.resnet1(x[0])         out2 = self.resnet2(x[1])         out3 = self.resnet3(x[2])         out = torch.cat([out1, out2, out3], dim = 0)         return out

那么就须要这样:

 summary(Model().cuda(), input_size = [(3, 64, 64)]*3)

该输入将与前一个类似,但会有点凌乱,因为torchsummary将每个组成的ResNet模块的信息压缩到一个摘要中,而在两个间断模块的摘要之间没有任何适当的可辨别边界。

torchinfo

它看起来可能与torchsummary相似。但在我看来,它是我找到这三种办法中最好的。torchinfo以后版本是1.7.0,还是能够应用pip装置:

 pip install torchinfo

这个包也有一个名为summary的函数。但它有更多的参数。他的应用参数为model (nn.module)、input_size (Sequence of Sizes)、input_data (Sequence of Tensors)、batch_dim (int)、cache_forward_pass (bool)、col_names (Iterable[str])、col_width (int)、depth (int)、device (torch.Device)、dtypes (List[torch.dtype])、mode (str)、row_settings (Iterable[str])、verbose (int)和**kwargs。

参数很多,然而能够间接通过(" input_size ", " output_size ", " num_params ", " kernel_size ", " mult_add ", " trainable ")作为col_names参数来获取信息。

 import torchinfo torchinfo.summary(model, (3, 224, 224), batch_dim = 0, col_names = (“input_size”, “output_size”, “num_params”, “kernel_size”, “mult_adds”), verbose = 0)

须要阐明的是,如果不应用Jupyter或Google Colab,须要将verbose 更改为1。

上述代码段的输入看起来像这样

 ============================================================================================= Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds ============================================================================================= ResNet                                   [1, 3, 224, 224]          [1, 1000]                 --                        --                        -- ├─Conv2d: 1-1                            [1, 3, 224, 224]          [1, 64, 112, 112]         9,408                     [7, 7]                    118,013,952 ├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         [1, 64, 112, 112]         128                       --                        128 ├─ReLU: 1-3                              [1, 64, 112, 112]         [1, 64, 112, 112]         --                        --                        -- ├─MaxPool2d: 1-4                         [1, 64, 112, 112]         [1, 64, 56, 56]           --                        3                         -- ├─Sequential: 1-5                        [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        -- │    └─BasicBlock: 2-1                   [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        -- │    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504 │    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128 │    │    └─ReLU: 3-3                    [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        -- │    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504 │    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128 │    │    └─ReLU: 3-6                    [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        -- │    └─BasicBlock: 2-2                   [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        -- │    │    └─Conv2d: 3-7                  [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504 │    │    └─BatchNorm2d: 3-8             [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128 │    │    └─ReLU: 3-9                    [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        -- │    │    └─Conv2d: 3-10                 [1, 64, 56, 56]           [1, 64, 56, 56]           36,864                    [3, 3]                    115,605,504 │    │    └─BatchNorm2d: 3-11            [1, 64, 56, 56]           [1, 64, 56, 56]           128                       --                        128 │    │    └─ReLU: 3-12                   [1, 64, 56, 56]           [1, 64, 56, 56]           --                        --                        -- ├─Sequential: 1-6                        [1, 64, 56, 56]           [1, 128, 28, 28]          --                        --                        -- │    └─BasicBlock: 2-3                   [1, 64, 56, 56]           [1, 128, 28, 28]          --                        --                        -- │    │    └─Conv2d: 3-13                 [1, 64, 56, 56]           [1, 128, 28, 28]          73,728                    [3, 3]                    57,802,752 │    │    └─BatchNorm2d: 3-14            [1, 128, 28, 28]          [1, 128, 28, 28]          256                       --                        256 . . . │    │    └─Conv2d: 3-49                 [1, 512, 7, 7]            [1, 512, 7, 7]            2,359,296                 [3, 3]                    115,605,504 │    │    └─BatchNorm2d: 3-50            [1, 512, 7, 7]            [1, 512, 7, 7]            1,024                     --                        1,024 │    │    └─ReLU: 3-51                   [1, 512, 7, 7]            [1, 512, 7, 7]            --                        --                        -- ├─AdaptiveAvgPool2d: 1-9                 [1, 512, 7, 7]            [1, 512, 1, 1]            --                        --                        -- ├─Linear: 1-10                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000 ============================================================================================= Total params: 11,689,512 Trainable params: 11,689,512 Non-trainable params: 0 Total mult-adds (G): 1.81 ============================================================================================= Input size (MB): 0.60 Forward/backward pass size (MB): 39.75 Params size (MB): 46.76 Estimated Total Size (MB): 87.11 =============================================================================================

再持续查看多分支模型

 torchinfo.summary(Model().cuda(), [(3, 64, 64)]*3, batch_dim = 0, col_names = (“input_size”, “output_size”, “num_params”, “kernel_size”, “mult_adds”), verbose = 0)

产生以下输入

 ============================================================================================= Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds ============================================================================================= Model                                         [1, 3, 64, 64]            [1, 1000]                 --                        --                        -- ├─ResNet: 1-1                                 [1, 3, 64, 64]            [1, 1000]                 --                        --                        -- │    └─Conv2d: 2-1                            [1, 3, 64, 64]            [1, 64, 32, 32]           9,408                     [7, 7]                    9,633,792 │    └─BatchNorm2d: 2-2                       [1, 64, 32, 32]           [1, 64, 32, 32]           128                       --                        128 │    └─ReLU: 2-3                              [1, 64, 32, 32]           [1, 64, 32, 32]           --                        --                        -- │    └─MaxPool2d: 2-4                         [1, 64, 32, 32]           [1, 64, 16, 16]           --                        3                         -- │    └─Sequential: 2-5                        [1, 64, 16, 16]           [1, 64, 16, 16]           --                        --                        -- │    │    └─BasicBlock: 3-1                   [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624 │    │    └─BasicBlock: 3-2                   [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624 │    └─Sequential: 2-6                        [1, 64, 16, 16]           [1, 128, 8, 8]            --                        --                        -- │    │    └─BasicBlock: 3-3                   [1, 64, 16, 16]           [1, 128, 8, 8]            230,144                   --                        14,680,832 │    │    └─BasicBlock: 3-4                   [1, 128, 8, 8]            [1, 128, 8, 8]            295,424                   --                        18,874,880 │    └─Sequential: 2-7                        [1, 128, 8, 8]            [1, 256, 4, 4]            --                        --                        -- │    │    └─BasicBlock: 3-5                   [1, 128, 8, 8]            [1, 256, 4, 4]            919,040                   --                        14,681,600 │    │    └─BasicBlock: 3-6                   [1, 256, 4, 4]            [1, 256, 4, 4]            1,180,672                 --                        18,875,392 │    └─Sequential: 2-8                        [1, 256, 4, 4]            [1, 512, 2, 2]            --                        --                        -- │    │    └─BasicBlock: 3-7                   [1, 256, 4, 4]            [1, 512, 2, 2]            3,673,088                 --                        14,683,136 │    │    └─BasicBlock: 3-8                   [1, 512, 2, 2]            [1, 512, 2, 2]            4,720,640                 --                        18,876,416 │    └─AdaptiveAvgPool2d: 2-9                 [1, 512, 2, 2]            [1, 512, 1, 1]            --                        --                        -- │    └─Linear: 2-10                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000 ├─ResNet: 1-2                                 [1, 3, 64, 64]            [1, 1000]                 --                        --                        -- │    └─Conv2d: 2-11                           [1, 3, 64, 64]            [1, 64, 32, 32]           9,408                     [7, 7]                    9,633,792 │    └─BatchNorm2d: 2-12                      [1, 64, 32, 32]           [1, 64, 32, 32]           128                       --                        128 │    └─ReLU: 2-13                             [1, 64, 32, 32]           [1, 64, 32, 32]           --                        --                        -- │    └─MaxPool2d: 2-14                        [1, 64, 32, 32]           [1, 64, 16, 16]           --                        3                         -- │    └─Sequential: 2-15                       [1, 64, 16, 16]           [1, 64, 16, 16]           --                        --                        -- │    │    └─BasicBlock: 3-9                   [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624 │    │    └─BasicBlock: 3-10                  [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624 │    └─Sequential: 2-16                       [1, 64, 16, 16]           [1, 128, 8, 8]            --                        --                        -- │    │    └─BasicBlock: 3-11                  [1, 64, 16, 16]           [1, 128, 8, 8]            230,144                   --                        14,680,832 │    │    └─BasicBlock: 3-12                  [1, 128, 8, 8]            [1, 128, 8, 8]            295,424                   --                        18,874,880 │    └─Sequential: 2-17                       [1, 128, 8, 8]            [1, 256, 4, 4]            --                        --                        -- │    │    └─BasicBlock: 3-13                  [1, 128, 8, 8]            [1, 256, 4, 4]            919,040                   --                        14,681,600 │    │    └─BasicBlock: 3-14                  [1, 256, 4, 4]            [1, 256, 4, 4]            1,180,672                 --                        18,875,392 │    └─Sequential: 2-18                       [1, 256, 4, 4]            [1, 512, 2, 2]            --                        --                        -- │    │    └─BasicBlock: 3-15                  [1, 256, 4, 4]            [1, 512, 2, 2]            3,673,088                 --                        14,683,136 │    │    └─BasicBlock: 3-16                  [1, 512, 2, 2]            [1, 512, 2, 2]            4,720,640                 --                        18,876,416 │    └─AdaptiveAvgPool2d: 2-19                [1, 512, 2, 2]            [1, 512, 1, 1]            --                        --                        -- │    └─Linear: 2-20                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000 ├─ResNet: 1-3                                 [1, 3, 64, 64]            [1, 1000]                 --                        --                        -- │    └─Conv2d: 2-21                           [1, 3, 64, 64]            [1, 64, 32, 32]           9,408                     [7, 7]                    9,633,792 │    └─BatchNorm2d: 2-22                      [1, 64, 32, 32]           [1, 64, 32, 32]           128                       --                        128 │    └─ReLU: 2-23                             [1, 64, 32, 32]           [1, 64, 32, 32]           --                        --                        -- │    └─MaxPool2d: 2-24                        [1, 64, 32, 32]           [1, 64, 16, 16]           --                        3                         -- │    └─Sequential: 2-25                       [1, 64, 16, 16]           [1, 64, 16, 16]           --                        --                        -- │    │    └─BasicBlock: 3-17                  [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624 │    │    └─BasicBlock: 3-18                  [1, 64, 16, 16]           [1, 64, 16, 16]           73,984                    --                        18,874,624 │    └─Sequential: 2-26                       [1, 64, 16, 16]           [1, 128, 8, 8]            --                        --                        -- │    │    └─BasicBlock: 3-19                  [1, 64, 16, 16]           [1, 128, 8, 8]            230,144                   --                        14,680,832 │    │    └─BasicBlock: 3-20                  [1, 128, 8, 8]            [1, 128, 8, 8]            295,424                   --                        18,874,880 │    └─Sequential: 2-27                       [1, 128, 8, 8]            [1, 256, 4, 4]            --                        --                        -- │    │    └─BasicBlock: 3-21                  [1, 128, 8, 8]            [1, 256, 4, 4]            919,040                   --                        14,681,600 │    │    └─BasicBlock: 3-22                  [1, 256, 4, 4]            [1, 256, 4, 4]            1,180,672                 --                        18,875,392 │    └─Sequential: 2-28                       [1, 256, 4, 4]            [1, 512, 2, 2]            --                        --                        -- │    │    └─BasicBlock: 3-23                  [1, 256, 4, 4]            [1, 512, 2, 2]            3,673,088                 --                        14,683,136 │    │    └─BasicBlock: 3-24                  [1, 512, 2, 2]            [1, 512, 2, 2]            4,720,640                 --                        18,876,416 │    └─AdaptiveAvgPool2d: 2-29                [1, 512, 2, 2]            [1, 512, 1, 1]            --                        --                        -- │    └─Linear: 2-30                           [1, 512]                  [1, 1000]                 513,000                   --                        513,000 ============================================================================================= Total params: 35,068,536 Trainable params: 35,068,536 Non-trainable params: 0 Total mult-adds (M): 445.71 ============================================================================================= Input size (MB): 0.15 Forward/backward pass size (MB): 9.76 Params size (MB): 140.27 Estimated Total Size (MB): 150.18 =============================================================================================

能够看到depth 参数的默认值为3。并且在可视化方向上,多分支被从新进行了组织并且以层次结构形式出现,所以很容易辨别,所以他的成果要比torchsummary好很多。

https://avoid.overfit.cn/post/bfed756d1d5147a89f6d8d911b6d29dd

作者:Siladittya Manna