共计 340 个字符,预计需要花费 1 分钟才能阅读完成。
间接本人写
model 是要正则的模型,reg_type 抉择 'l1'
还是l2
,coef 是系数。
def regularization(model:nn.Module, reg_type,coef):
int_type=int(reg_type[1])
reg_loss = 0
for module in model.modules():
for param in module.parameters():
reg_loss+=torch.norm(param,int_type)
return reg_loss*coef
代码是一个小例子,对哪个 module 进行正则,这都能够本人筛选,不用对每一个 module 都正则。
优化器中增加
一个是 Adam 或者 AdamW 优化器外面有 weight_decay 参数,那个是 l2 的正则系数
正文完