<>前言


在使用Pytorch训练模型的时候,经常会有在GPU上保存模型然后再CPU上运行的需求,在实验的过程中发现在多GPU上训练的Pytorch模型是不能在CPU上直接运行的,几次遇到了这种问题,这里研究和记录一下。

<>模型的保存与加载

例如我们创建了一个模型:
model = MyVggNet()
如果使用多GPU训练,我们需要使用这行代码:
model = nn.DataParallel(model).cuda()
执行这个代码之后,model就不在是我们原来的模型,而是相当于在我们原来的模型外面加了一层支持GPU运行的外壳,这时候真正的模型对象为:real_model
= model.module,
所以我们在保存模型的时候注意,如果保存的时候是否带有这层加的外壳,如果保存的时候带有的话,加载的时候也是带有的,如果保存的是真实的模型,加载的也是真是的模型。这里我建议保存真是的模型,因为加了module壳的模型在CPU上是不能运行的。
Pytorch有多种保存模型的方式,使用哪种进行保存,就要使用对应的加载方式。保存的时候模型的后缀名是无所谓的。
Pytorch官方的加载和保存模型的方式有两种:

* 保存和加载整个模型。这种方式再重新加载的时候不需要自定义网络结构,保存时已经把网络结构保存了下来,比较死板不能调整网络结构。 torch.save(
model_object, 'model.pkl') model = torch.load('model.pkl')
*
仅保存和加载模型参数(推荐使用)。这种方式再重新加载的时候需要自己定义网络,并且其中的参数名称与结构要与保存的模型中的一致(可以是部分网络,比如只使用VGG的前几层),相对灵活,便于对网络进行修改。
torch.save(model_object.state_dict(), 'params.pkl') model_object.load_state_dict
(torch.load('params.pkl'))
<>模型保存与加载对应方式

<>1. 第一种方式

保存使用:
real_model = model.module torch.save(real_model.state_dict(),os.path.join(args.
save_path,"cos_mnist_"+str(epoch+1)+"_weight.pth"))
cpu上加载使用:
args.weight=checkpoint/cos_mnist_10_weight.pth map_location = lambda storage,
loc: storage model.load_state_dict(torch.load(args.weight,map_location=
map_location))
<>2. 第二种方式

保存使用:
real_model = model.module save_model(real_model, os.path.join(args.save_path,
"cos_mnist_"+str(epoch+1)+"_weight_cpu.pth")) # 自定义的函数 def save_model(model,
filename): state = model.state_dict() for key in state: state[key] = state[key].
clone().cpu() torch.save(state, filename)
cpu上加载使用:
args.weight=checkpoint/cos_mnist_10_weight_cpu.pth model.load_state_dict(torch.
load(args.weight))
<>3. 第三种方式

保存使用:
real_model = model.module torch.save(real_model, os.path.join(args.save_path,
"cos_mnist_"+str(epoch+1)+"_whole.pth"))
cpu上加载使用:
args.weight=checkpoint/cos_mnist_10_whole.pth map_location = lambda storage,
loc: storage model = torch.load(args.weight,map_location=map_location)
<>参考文献

* pytorch学习笔记(五):保存和加载模型
<https://blog.csdn.net/u012436149/article/details/68948816/>
* Pytorch之GPU模型加载在CPU上
<https://blog.csdn.net/hardbird123/article/details/80549815>
* PyTorch使用cpu调用gpu训练的模型
<https://blog.csdn.net/c654528593/article/details/81539441>

友情链接
KaDraw流程图
API参考文档
OK工具箱
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:ixiaoyang8@qq.com
QQ群:637538335
关注微信