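import torch

state_dict = torch.load(opts.checkpoint)
try:
    # Works when the checkpoint keys match the plain (unwrapped) network.
    trainer.net.load_state_dict(state_dict['net_param'])
except Exception:
    # Keys saved from a DataParallel model carry a 'module.' prefix,
    # so wrap the network the same way and load again.
    trainer.net = torch.nn.DataParallel(trainer.net)
    trainer.net.load_state_dict(state_dict['net_param'])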
This handles a checkpoint that was saved from a model trained in parallel (wrapped in torch.nn.DataParallel).
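An alternative to wrapping the network is to rename the checkpoint keys instead. The sketch below assumes the parallel-trained checkpoint differs only by a leading `module.` on each key, which is how torch.nn.DataParallel names the parameters of the model it wraps. `strip_module_prefix` is a hypothetical helper, not part of PyTorch, and the keys in the demo are made up for illustration.

```python
from collections import OrderedDict

PREFIX = "module."


def strip_module_prefix(state_dict):
    """Return a copy of state_dict with a leading 'module.' removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Only a prefix at the very start of the key is removed.
        if key.startswith(PREFIX):
            key = key[len(PREFIX):]
        cleaned[key] = value
    return cleaned


# Made-up keys standing in for state_dict['net_param'] from a parallel run.
parallel_keys = OrderedDict([("module.conv1.weight", 1), ("module.fc.bias", 2)])
plain_keys = strip_module_prefix(parallel_keys)
print(list(plain_keys))
```

Under that assumption, the cleaned dictionary could be passed to trainer.net.load_state_dict without wrapping the network first.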
try:
    out = trainer.net.forward()
except Exception:
    # A DataParallel wrapper keeps the original network in .module.
    out = trainer.net.module.forward()
Similarly, when the net has been wrapped for a model trained in parallel,
the original network has to be reached through its .module attribute.
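The same idea can be written without try/except by unwrapping the network once. This is a minimal sketch, not the author's code: `unwrap`, `PlainNet`, and `Wrapper` are hypothetical names, and `Wrapper` only stands in for torch.nn.DataParallel (which likewise exposes the wrapped model as `.module`) so that the example runs without PyTorch. It assumes the plain network has no attribute of its own called `module`; an isinstance check against torch.nn.DataParallel would be stricter.

```python
def unwrap(net):
    """Return net.module if net exposes one, otherwise net itself."""
    return getattr(net, "module", net)


class PlainNet:
    # Hypothetical stand-in for a plain (unwrapped) network.
    def forward(self):
        return "out"


class Wrapper:
    # Hypothetical stand-in for torch.nn.DataParallel:
    # it keeps the real network in the .module attribute.
    def __init__(self, module):
        self.module = module


plain = PlainNet()
wrapped = Wrapper(plain)

# Both calls reach the same underlying network.
out_plain = unwrap(plain).forward()
out_wrapped = unwrap(wrapped).forward()
```

With such a helper, code that calls methods on the network does not need to know whether the checkpoint forced a wrap.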