起因
应用torch的optimizer增加了2组parameter,传参进入FATE的trainer后,optimizer被扭转,且FATE框架无提醒。
代码差不多是上面这样:
# optimizer中退出2组优化参数(param)optimizer = torch.optim.SGD([{'params':base, 'lr':0.1*train_args['lr']},\ {'params':head, 'lr':train_args['lr']}])nn_component = HomoNN(name='sanet', model=model, # model loss=loss, optimizer=optimizer, # 传入trainer后 dataset=dataset_param, # dataset trainer=TrainerParam(trainer_name='sa_trainer', cuda=True, checkpoint_save_freqs=1, **params), torch_seed=100, # random seed )# optimizer的param_group在trainer中就只变成1组了,其余的不见了。
github上反馈给社区了:我提的issue
[外链图片转存失败,源站可能有防盗链机制,倡议将图片保留下来间接上传(img-I9vA6ec9-1688611983896)(https://user-images.githubusercontent.com/31330044/251026326-...)]
解决
解决办法是不应用FATE给的接口,而本人间接在trainer外面提供optimizer。
class Trainer(): def init(opt_name='sgd'): xxxx def train(): self.optimizer = make_optimizer(self.model, self.opt_name)
能够在trainer中本人实现,提交工作时不提供optimizer参数即可.