2. PyTorch: Saving and Loading Models

2.1. Simple approach

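import torch

# save: pickle the whole model object (structure and parameters)
torch.save(model, 'model.pkl')

# load: the model's class definition must be importable at load time.
# In recent PyTorch releases (2.6 and later) torch.load defaults to
# weights_only=True, so a fully pickled model needs weights_only=False
# (only do this for files you trust).
model = torch.load('model.pkl')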

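This approach stores the whole model object, that is, both its structure and its parameters, so the resulting pkl file is comparatively large.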

2.2. Detailed approach

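Besides the model's own structure and parameters, a checkpoint should also record training information, such as the number of epochs completed and the optimizer state.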

import os
import shutil

import torch


# save
def save_checkpoint(state, is_best, save_path, filename):
    """Save `state` under save_path; copy it to model_best.pth.tar if is_best."""
    filename = os.path.join(save_path, filename)
    torch.save(state, filename)
    if is_best:
        bestname = os.path.join(save_path, 'model_best.pth.tar')
        shutil.copyfile(filename, bestname)


# Called from the training loop; cur_epoch, model, best_prec, loss_train,
# optimizer, is_best and save_path are defined by the training script.
save_checkpoint({
    'epoch': cur_epoch,
    'state_dict': model.state_dict(),
    'best_prec': best_prec,
    'loss_train': loss_train,
    'optimizer': optimizer.state_dict(),
}, is_best, save_path, 'epoch-{}_checkpoint.pth.tar'.format(cur_epoch))


# load
def load_checkpoint(load_path, model, optimizer):
    """Load state into model and optimizer (in place) and return:
    epoch, best_precision, loss_train[]
    e.g., model = alexnet(pretrained=False)
    """
    if os.path.isfile(load_path):
        print("=> loading checkpoint '{}'".format(load_path))
        checkpoint = torch.load(load_path)
        epoch = checkpoint['epoch']
        best_prec = checkpoint['best_prec']
        loss_train = checkpoint['loss_train']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(load_path, checkpoint['epoch']))
        return epoch, best_prec, loss_train
    else:
        print("=> no checkpoint found at '{}'".format(load_path))
        # defaults for a fresh run: epoch, best_precision, loss_train
        return 1, 0, []
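
The checkpoint pattern above can be tried end to end with a small round trip. The following is only a minimal sketch: the toy `nn.Linear` model, the SGD optimizer, the `build` helper, the temporary directory, and the stored values are all assumptions made for illustration, not part of the original training script.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build():
    """Hypothetical helper: a tiny model plus its optimizer."""
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


model, optimizer = build()

# One dummy training step so the optimizer has momentum buffers to save.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

save_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(save_dir, 'epoch-3_checkpoint.pth.tar')
torch.save({
    'epoch': 3,
    'state_dict': model.state_dict(),
    'best_prec': 0.5,
    'loss_train': [loss.item()],
    'optimizer': optimizer.state_dict(),
}, ckpt_path)

# Restore into freshly constructed objects, as when resuming training.
new_model, new_optimizer = build()
checkpoint = torch.load(ckpt_path, map_location='cpu')
new_model.load_state_dict(checkpoint['state_dict'])
new_optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1

# True when every restored tensor equals the one that was saved.
params_match = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(),
                    new_model.state_dict().values())
)
```

Storing the optimizer state alongside the weights is what lets training resume with the same momentum buffers instead of restarting the optimizer from scratch.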

2.3. Loading a subset of parameters

When only some of the parameters in a state_dict() need to be loaded into a model, the following approach can be used:

from torch.utils import model_zoo

# args has the model name, num classes and other irrelevant stuff;
# model_names[args.arch] is expected to be the URL of the pretrained weights
pretrained_state = model_zoo.load_url(model_names[args.arch])
model_state = my_model.state_dict()
# keep only the entries whose name and shape both match my_model
pretrained_state = {k: v for k, v in pretrained_state.items()
                    if k in model_state and v.size() == model_state[k].size()}
model_state.update(pretrained_state)
my_model.load_state_dict(model_state)
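
The same filtering step can be checked without downloading anything. In this minimal sketch, two small `nn.Sequential` networks stand in for the pretrained model and the target model (both are assumptions for illustration): they share the first layer but differ in the size of the final layer, so only the first layer's parameters should be transferred.

```python
import torch
import torch.nn as nn

# Stand-in for a pretrained network with a 10-class head.
pretrained = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 10))
pretrained_state = pretrained.state_dict()

# Target model: same first layer, but a 3-class head.
my_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
model_state = my_model.state_dict()

# Keep only entries whose name and shape both match the target model.
filtered = {
    k: v for k, v in pretrained_state.items()
    if k in model_state and v.size() == model_state[k].size()
}
# Names that were dropped because of a shape mismatch or a missing key.
skipped = sorted(set(pretrained_state) - set(filtered))

model_state.update(filtered)
my_model.load_state_dict(model_state)
```

Filtering by shape matters here because `load_state_dict(..., strict=False)` only tolerates missing or unexpected keys; a key that exists with a different shape still raises an error.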

Note

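The tensors in a state_dict() carry device information. If torch.save stored the state of a model that lived on the GPU, the saved parameters are GPU tensors, and torch.load will by default load them back onto the GPU. To avoid running out of GPU memory, use torch.load(checkpoint, map_location='cpu') to load the parameters onto the CPU first, and then call load_state_dict.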
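
A minimal sketch of the `map_location` behaviour described in this note, assuming a toy `nn.Linear` model and a temporary file (both illustrative). It uses the GPU when one is available and otherwise stays on the CPU, so it runs on either kind of machine.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Save from whichever device is available (falls back to CPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(4, 2).to(device)

path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth.tar')
torch.save({'state_dict': model.state_dict()}, path)

# Force every tensor in the checkpoint onto the CPU while loading.
checkpoint = torch.load(path, map_location='cpu')
devices = {v.device.type for v in checkpoint['state_dict'].values()}

# load_state_dict copies values into the target's existing parameters,
# so they end up on whatever device the target model already lives on.
target = nn.Linear(4, 2).to(device)
target.load_state_dict(checkpoint['state_dict'])
```

With this order of operations, the only GPU memory used is that of the target model's own parameters; the loaded checkpoint itself stays in host memory.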

2.4. References

  1. Saving and loading a model in Pytorch?

  2. How to load part of pre trained model?

  3. Serialization