共计 1384 个字符,预计需要花费 4 分钟才能阅读完成。
pytorch 代码仓库
pytorch 在 19 年 11 月份的时候合入了这部分剪枝的代码。pytorch 提供一些直接可用的 api,用户只需要传入需要剪枝的 module 实例和需要剪枝的参数名字,系统自动帮助完成剪枝操作,看起来接口挺简单。比如 def random_structured(module, name, amount, dim)
pytorch 支持的几种类型的剪枝策略:
详细分析
- pytorch 提供了一个剪枝的抽象基类‘‘class BasePruningMethod(ABC)’,所有剪枝策略都需要继承该基类,并重载部分函数就可以了
- 一般情况下需要重载__init__和 compute_mask 方法,__call__, apply_mask, apply, prune 和 remove 不需要重载,例如官方提供的 RandomUnstructured 剪枝方法
- 基类实现的 6 个方法:
- 剪枝的 API 接口,可以看到支持用户自定义的剪枝 mask,接口为 custom_from_mask
- API 的实现,使用 classmethod 的方法,剪枝策略的实例化在框架内部完成,不需要用户实例化
-
剪枝的大只过程:
- 根据用户选择的剪枝 API 生成对应的策略实例,此时会判断需要做剪枝操作的 module 上是否已经挂有前向回调函数,没有则生成新的,有了就在老的上面添加,并且生成 PruningContainer。从这里可以看出,对于同一个 module 使用多个剪枝策略时,pytorch 通过 PruningContainer 来对剪枝策略进行管理。PruningContainer 本身也是继承自 BasePruningMethod。同时设置前向计算的回调,便于后续训练时调用。
- 接着根据用户输入的 module 和 name,找到对应的参数 tensor。如果是第一次剪枝,那么需要生成_orig 结尾的 tensor,然后删除原始的 module 上的 tensor。如 name 为 bias,那么生成 bias_orig 存起来,然后删除 module.bias 属性。
- 获取 defaultmask,然后调用 method.computemask 生成当前策略的 mask 值。生成的 mask 会被存在特定的缓存 module.register_buffer(name + “_mask”, mask)。这里的 compute_mask 可能是两种情况:如果只有一个策略,那么调用的时候对应剪枝策略的 compute_mask 方法,如果一个 module 有多个剪枝策略组合,那么调用的应该是 PruningContainer 的 compute_mask
![file](/img/bVbHRbW)
4. 执行剪枝,保存剪枝结果到 module 的属性,注册训练时的剪枝回调函数,剪枝完成。新的 mask 应用在 orig 的 tensor 上面生成新的 tensor 保存的对应的 name 属性
![file](/img/bVbHRbX)
- remove 接口
pytorch 还提供各类一个 remove 接口,目的是把之前的剪枝结果持久化,具体操作就是删除之前生成的跟剪枝相关的缓存或者是回调 hook 接口,设置被剪枝的 name 参数(如 bias)为最后一次训练的值。
-
自己写一个剪枝策略接口也是可以的:
- 先写一个剪枝策略类继承 BasePruningMethod
- 然后重载基类的 compute_mask 方法,写自己的计算 mask 方法
官方完整教程在这里
正文完
发表至: tensorflow
2020-06-01