About PyTorch: a common PyTorch parallelization bug


state_dict = torch.load(opts.checkpoint)
try:
    # checkpoint saved from a plain (non-parallel) model
    trainer.net.load_state_dict(state_dict['net_param'])
except RuntimeError:
    # key mismatch: the checkpoint was saved from a DataParallel-wrapped
    # model, so wrap the net before loading
    trainer.net = torch.nn.DataParallel(trainer.net)
    trainer.net.load_state_dict(state_dict['net_param'])

This handles loading a checkpoint that was saved from a model trained with DataParallel.
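An alternative to re-wrapping the net is to strip the `module.` prefix that DataParallel prepends to every parameter key, so the checkpoint loads into a plain model. A minimal sketch (the checkpoint keys below are made up for illustration; real values would be tensors):

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that DataParallel adds to parameter keys."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len("module."):] if key.startswith("module.") else key
        cleaned[new_key] = value
    return cleaned

# hypothetical keys as saved by a DataParallel-wrapped model
checkpoint = OrderedDict([
    ("module.conv1.weight", "w1"),
    ("module.fc.bias", "b1"),
])
print(list(strip_module_prefix(checkpoint).keys()))
# → ['conv1.weight', 'fc.bias']
```

With the cleaned dict, `trainer.net.load_state_dict(strip_module_prefix(state_dict['net_param']))` should work without wrapping the net.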

try:
    out = trainer.net.forward()
except AttributeError:
    # the net is wrapped in DataParallel; custom attributes and methods
    # live on the underlying model at .module
    out = trainer.net.module.forward()

Similarly, the net must be accessed through its `.module` attribute to stay compatible with a model that was trained in parallel.
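Rather than a try/except at every call site, one can unwrap the model once. The sketch below uses a hypothetical `Wrapper` class as a stand-in for `torch.nn.DataParallel` (which stores the wrapped model as `.module`), so it runs without PyTorch; in real code the check would be `isinstance(net, torch.nn.DataParallel)`:

```python
class Wrapper:
    """Stand-in for torch.nn.DataParallel: keeps the real model at .module."""
    def __init__(self, module):
        self.module = module

def unwrap(net):
    """Return the underlying model whether or not it is wrapped."""
    return net.module if isinstance(net, Wrapper) else net

model = object()          # placeholder for an nn.Module instance
wrapped = Wrapper(model)
print(unwrap(wrapped) is model, unwrap(model) is model)
# → True True
```

Calling `unwrap(trainer.net)` once after loading makes the rest of the code indifferent to whether the checkpoint was trained in parallel.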
