乐趣区

关于人工智能:FATE联邦学习-optimizer在FATE的自定义trainer中被改变

起因

应用 torch 的 optimizer 增加了 2 组 parameter,传参进入 FATE 的 trainer 后,optimizer 被扭转,且 FATE 框架无提醒。

代码差不多是上面这样:

# optimizer 中退出 2 组优化参数(param)optimizer = torch.optim.SGD([{'params':base, 'lr':0.1*train_args['lr']},\
                              {'params':head, 'lr':train_args['lr']}])


nn_component = HomoNN(name='sanet',
                      model=model, # model
                      loss=loss,
                      optimizer=optimizer, # 传入 trainer 后
                      dataset=dataset_param,  # dataset
                      trainer=TrainerParam(trainer_name='sa_trainer', cuda=True, checkpoint_save_freqs=1, **params),
                      torch_seed=100, # random seed
                      
                      )

# optimizer 的 param_group 在 trainer 中就只变成 1 组了,其余的不见了。

github 上反馈给社区了:我提的 issue
[外链图片转存失败, 源站可能有防盗链机制, 倡议将图片保留下来间接上传 (img-I9vA6ec9-1688611983896)(https://user-images.githubusercontent.com/31330044/251026326-…)]

解决

解决办法是不应用 FATE 给的接口,而本人间接在 trainer 外面提供 optimizer。

class Trainer():
    def init(opt_name='sgd'):
        xxxx
    def train():
        self.optimizer = make_optimizer(self.model, self.opt_name)

能够在 trainer 中本人实现,提交工作时不提供 optimizer 参数即可.

退出移动版