内容导读:特征提取是图像处理过程中常须要用到的一种办法,其成果好坏对模型的泛化能力有至关重要的影响。
本文首发自微信公众号「PyTorch 开发者社区」。
特征提取(Feature extraction)在机器学习、模式识别和图像处理中利用宽泛。
它从初始的一组测量数据开始,建构出提供信息且不冗余的派生值,即特征值,从而促成后续的学习和泛化步骤。
在应用 PyTorch 进行模型训练的过程中,常常须要提取模型中间层的特色。解决这个问题能够用到 3 种办法。
对中间层进行特征提取的 3 大办法
1、借助模型类的属性传递
办法:批改 forward 函数,通过增加一行代码将 feature 赋值给 self 变量,即 _self.feature_map = feature_,而后打印输出即可。
备注:实用于仅提取中间层特色,不须要提取梯度的状况。
代码示例:
# Define a Convolutional Neural Network
class
Net(nn.Module):
def __init__(self, kernel_size=5, n_filters=16, n_layers=3):
xxx
def forward(self, x):
x = self.body(self.head(x))
self.featuremap1 = x.detach() # 外围代码
return F.relu(self.fc(x))
model_ft = Net()
train_model(model_ft)
feature_output1 = model_ft.featuremap1.transpose(1,0).cpu()
2、借助 hook 机制
hook 是一个可调用对象,它能够在不批改主代码的前提下插入业务。PyTorch 中的 hook 包含三种:
torch.autograd.Variable.register_hook
torch.nn.Module.register_backward_hook
torch.nn.Module.register_forward_hook
第一个是针对 Variable 对象的,后两个是针对 nn.Module 对象的。
办法:在调用阶段对 Module 应用 forward_hook 函数,能够取得所需梯度或特色。
备注:较为简单、功能完善,须要对 PyTorch 有肯定水平的理解。
3、借助 torchextractor
torchextractor 是一个独立 Python 包,具备跟 nn.Module 性能相似的提取器,只需提供模块名称,就能够在 PyTorch 中对中间层进行特征提取。
与应用 forward_hook 进行中间层特征提取相比,torchextractor 更像是一个包装程序(wrapper),不像 torchvision IntermediateLayerGetter 有那么多的 _assumption_。
在性能方面 torchextractor 次要劣势在于反对嵌套模块(nested module)、自定义缓存操作,而且与 ONNX 兼容。
torchextractor 极大简化了在 PyTorch 中进行特征提取的流程,这防止了大量代码的粘贴复制,也不须要重写 forward 函数,它对初学者更敌对,可用性也更强。
torchextractor 上手实际
装置
pip install torchextractor # stable
pip install git+https://github.com/antoinebrl/torchextractor.git # latest
要求
Python 3.6 及以上版本
Torch 1.4.0 及以上版本
用法
import torch
import torchvision
import torchextractor as tx
model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"])
dummy_input = torch.rand(7, 3, 224, 224)
model_output, features = model(dummy_input)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)
# {# 'layer1': torch.Size([1, 64, 56, 56]),
# 'layer2': torch.Size([1, 128, 28, 28]),
# 'layer3': torch.Size([1, 256, 14, 14]),
# 'layer4': torch.Size([1, 512, 7, 7]),
# }
残缺文档请查看:
https://github.com/antoinebrl…
以上就是本期汇总的 3 个对中间层进行特征提取的办法,如果你有更好的解决思路,或者其余想要理解的 Pytorch 相干问题,欢送在下方留言或发私信。
参考:
https://www.reddit.com/r/Mach…
https://www.zhihu.com/questio…