About PyTorch: Common PyTorch Parallelization Bugs

import torch

state_dict = torch.load(opts.checkpoint)
try:
    # Works if the checkpoint was saved from an unwrapped model.
    trainer.net.load_state_dict(state_dict['net_param'])
except RuntimeError:
    # A DataParallel checkpoint prefixes every key with "module.",
    # so wrap the net the same way before loading.
    trainer.net = torch.nn.DataParallel(trainer.net)
    trainer.net.load_state_dict(state_dict['net_param'])

This handles loading a checkpoint saved from a model trained with DataParallel: the saved keys all carry a module. prefix, so they only match once the net is wrapped the same way.
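An alternative that avoids wrapping the net is to strip the module. prefix from the checkpoint keys so they load into the plain model. This is a minimal sketch, assuming the same state_dict and trainer.net as above:

from collections import OrderedDict

state_dict = torch.load(opts.checkpoint)
# Drop the leading "module." that DataParallel adds to every key.
cleaned = OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in state_dict['net_param'].items()
)
trainer.net.load_state_dict(cleaned)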

try:
    # Calling the net directly works for a plain module
    # (inputs stands for the batch being fed to the model).
    out = trainer.net(inputs)
except AttributeError:
    # Attributes and methods of a wrapped model live on .module.
    out = trainer.net.module(inputs)

Similarly, a net wrapped in DataParallel must be accessed through its .module attribute to stay compatible with code written for the unwrapped model.
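Both workarounds can be sidestepped at save time by always storing the state_dict of the unwrapped model, so the checkpoint keys never carry the module. prefix. A sketch, reusing the names from the snippets above:

# Unwrap before saving so the checkpoint loads identically
# whether or not training used DataParallel.
if isinstance(trainer.net, torch.nn.DataParallel):
    net_to_save = trainer.net.module
else:
    net_to_save = trainer.net
torch.save({'net_param': net_to_save.state_dict()}, opts.checkpoint)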
