应仅保存模型权重state_dict而非整个模型对象,因其不依赖类定义、支持跨环境加载,且需配合eval()模式、正确后缀、map_location及独立管理优化器状态。

保存模型权重用 torch.save(model.state_dict(), path),别存整个模型对象
直接保存 model 本身看似方便,但会把类定义、模块结构甚至训练时的临时变量一并固化,导致加载时对代码结构极度敏感——换一个文件路径、改一个 import 名、甚至只是升级了 PyTorch 版本,就可能报 AttributeError: 'dict' object has no attribute 'forward' 或更模糊的 Missing key 错误。
真正稳定、可迁移的做法是只序列化权重:用 model.state_dict() 提取参数字典,再用 torch.save() 写入磁盘。它不依赖模型类的具体实现,只要加载时能重建出结构一致的模型实例,就能安全 load_state_dict()。
- 保存前确保模型在
eval()模式(避免BatchNorm和Dropout的训练态参数被意外保存) - 路径建议用
.pt或.pth后缀,不要用.pkl——虽然底层都是 pickle,但后缀影响工具链识别(如 Hugging Facefrom_pretrained默认找.bin或.pth) - 如果模型用了
nn.DataParallel,保存前记得用model.module.state_dict(),否则键名会多出module.前缀,后续单卡加载会不匹配
加载权重必须先初始化模型,再调用 load_state_dict()
不能跳过模型构造直接“反序列化”权重。PyTorch 不提供从 state_dict 自动还原网络结构的能力 —— 它只负责把字典里的值填进已有模型的对应参数位置。
常见错误是写成 model = torch.load(path),这会尝试 unpickle 整个对象,一旦环境不一致就崩溃;或者漏掉 model.load_state_dict(...),结果模型还是随机初始化的权重。
立即学习“Python免费学习笔记(深入)”;
- 加载后务必调用
model.eval()(如果用于推理),否则BatchNorm仍按训练模式运行,输出不稳定 - 键名不匹配时默认报错,可用
strict=False忽略多余或缺失的键,但要小心:missing keys可能意味着模型结构没对齐,unexpected keys可能是保存时混入了优化器状态 - 若保存时用了
torch.compile(),加载后需重新 compile,state_dict不包含编译后的图信息
保存/加载优化器状态也用 state_dict,但必须和模型分开管理
优化器也有自己的 state_dict(含动量、累积梯度等),常和模型权重一起保存用于断点续训。但它和模型的 state_dict 是两个独立字典,不能合并或混用。
典型做法是打包成一个 Python 字典再保存:
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
'loss': loss,
}, path)
加载时也需分别取出、分别加载:
checkpoint = torch.load(path) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-
optimizer.state_dict()包含张量,保存路径必须支持二进制写入(不能用文本模式打开) - 不同优化器的
state_dict结构差异大,比如AdamW有exp_avg_sq,而SGD没有;切换优化器时不能复用旧的optimizer_state_dict - 如果用了
torch.cuda.amp.GradScaler,它的state_dict也要一并保存,否则混合精度训练会中断
跨设备加载要注意 map_location,否则 CPU 加载 GPU 模型会卡死
在 GPU 上训练保存的模型,如果直接在 CPU 环境下加载,PyTorch 默认尝试把所有张量映射回原设备(即 CUDA),结果报 RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False。
解决方法是在 torch.load() 里加 map_location 参数,显式指定目标设备:
- CPU 加载:用
torch.load(path, map_location='cpu') - 指定某张 GPU:用
torch.load(path, map_location='cuda:1') - 自动适配当前设备:用
torch.load(path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
这个参数只影响 torch.load() 阶段,不影响后续 model.to(device) —— 但如果你先 load_state_dict() 再 to(device),中间参数会经历一次设备拷贝,多一次内存开销。










