关于pytorch:20226-torch中使用l1-l2正则的写法

间接本人写

model是要正则的模型,reg_type抉择'l1'还是l2,coef是系数。

def regularization(model:nn.Module, reg_type,coef):
    int_type=int(reg_type[1])
    reg_loss = 0
    for module in model.modules():
        for param in module.parameters():
            reg_loss+=torch.norm(param,int_type)
        
    return reg_loss*coef

代码是一个小例子,对哪个module进行正则,这都能够本人筛选,不用对每一个module都正则。

优化器中增加

一个是Adam或者AdamW优化器外面有weight_decay参数,那个是l2的正则系数

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

这个站点使用 Akismet 来减少垃圾评论。了解你的评论数据如何被处理