state_dict = torch.load(opts.checkpoint)
try:
    trainer.net.load_state_dict(state_dict['net_param'])
except Exception:
    trainer.net = torch.nn.DataParallel(trainer.net)
    trainer.net.load_state_dict(state_dict['net_param'])
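This handles a checkpoint saved from a model that was trained in parallel. Such a checkpoint stores its parameters under keys prefixed with "module.", so loading into the bare network fails; the network is then wrapped in torch.nn.DataParallel and the load is retried.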
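An alternative to wrapping the network is to rename the checkpoint keys so they match the bare network. The sketch below is not the approach used above; it assumes the only mismatch is a leading "module." prefix, and the helper name strip_module_prefix is hypothetical. Plain integers stand in for tensors so the snippet runs without PyTorch.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Strip the prefix only at the start of the key; other keys pass through.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in values; a real checkpoint would hold tensors here.
parallel_keys = OrderedDict([("module.conv1.weight", 1), ("module.conv1.bias", 2)])
print(list(strip_module_prefix(parallel_keys)))
# → ['conv1.weight', 'conv1.bias']
```

With a real checkpoint, the cleaned mapping could presumably be passed straight to trainer.net.load_state_dict(...), which leaves trainer.net unwrapped.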
try:
    out = trainer.net.forward()
except Exception:
    out = trainer.net.module.forward()
Similarly, the forward pass falls back to trainer.net.module, the network inside the wrapper,
to stay compatible with a model trained in parallel.
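The same fallback can be written as an explicit check instead of a try/except, which avoids hiding unrelated errors raised inside forward(). This is a minimal sketch: unwrap is a hypothetical helper, and PlainNet and Wrapped are hypothetical stand-ins for a real network and for torch.nn.DataParallel, used so the snippet runs without PyTorch.

```python
def unwrap(net):
    """Return net.module if net is a wrapper exposing one, otherwise net itself."""
    return getattr(net, "module", net)


class PlainNet:  # hypothetical stand-in for an unwrapped network
    def forward(self):
        return "plain"


class Wrapped:  # hypothetical stand-in for torch.nn.DataParallel
    def __init__(self, module):
        self.module = module


# Both calls reach the same underlying forward().
print(unwrap(PlainNet()).forward())
print(unwrap(Wrapped(PlainNet())).forward())
```

In real code, a stricter test such as isinstance(trainer.net, torch.nn.DataParallel) may be safer, since a network that happens to have its own attribute named module would be unwrapped by mistake here.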