关于pytorch:20226-torch中使用l1-l2正则的写法

28次阅读

共计 340 个字符,预计需要花费 1 分钟才能阅读完成。

间接本人写

model 是要正则的模型,reg_type 抉择 'l1' 还是l2,coef 是系数。

def regularization(model:nn.Module, reg_type,coef):
    int_type=int(reg_type[1])
    reg_loss = 0
    for module in model.modules():
        for param in module.parameters():
            reg_loss+=torch.norm(param,int_type)
        
    return reg_loss*coef

代码是一个小例子,对哪个 module 进行正则,这都能够本人筛选,不用对每一个 module 都正则。

优化器中增加

一个是 Adam 或者 AdamW 优化器外面有 weight_decay 参数,那个是 l2 的正则系数

正文完
 0