PyTorch 训练恢复时解决设备不一致错误
问题描述
在使用 PyTorch 进行 GPU 训练时,保存检查点后重新加载继续训练,可能会遇到以下错误:
python
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
这个错误表明在计算过程中,部分张量位于 GPU (cuda:0) 上,而部分张量仍位于 CPU 上,导致设备不一致。
根本原因
设备不一致问题通常由以下几个原因引起:
- 模型和设备未正确关联:模型未正确移动到目标设备
- 优化器状态存储位置问题:优化器状态与模型不在同一设备上
- 输入数据未转移到设备:数据加载后忘记转移到 GPU
- 损失函数参数未同步:某些损失函数的参数(如权重)未同步到设备
解决方案
方案一:统一设备管理(推荐)
使用统一的设备管理方法,确保所有组件在同一设备上:
python
import torch
# 定义设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 将模型转移到设备
model = YourModel().to(device)
# 将数据转移到设备
for input, target in train_loader:
input, target = input.to(device), target.to(device)
# 训练代码...
方案二:正确保存和加载检查点
保存检查点时,PyTorch 会记录模型和优化器的状态,但不记录设备信息。加载时需要重新指定设备:
python
# 保存检查点
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, 'checkpoint.pth')
# 加载检查点
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# 重新将模型和优化器转移到设备
model = model.to(device)
重要提示
优化器应在模型转移到最终设备后创建,因为优化器会引用模型的参数位置。
方案三:处理特殊情况的设备转移
某些特定场景需要额外注意设备同步:
1. 分词器输出需要手动转移
python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("model_name").to(device)
tokenizer = AutoTokenizer.from_pretrained("model_name")
# 必须将分词结果也转移到设备
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
2. 损失函数的参数需要转移
python
# 错误的做法:权重张量在CPU上
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([3]))
# 正确的做法:权重张量也在设备上
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([3]).to(device))
3. 自定义损失函数需要转移
python
class CustomLoss(nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.Tensor([1.0, 2.0]) # 需要转移到设备
def forward(self, input, target):
# 计算损失
return loss
# 使用前确保权重在正确设备上
criterion = CustomLoss().to(device)
方案四:MPS设备支持(Apple Silicon)
对于使用Apple Silicon芯片的Mac用户:
python
# 检查MPS可用性
if torch.backends.mps.is_available():
device = torch.device("mps")
print("使用MPS设备")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
最佳实践
- 设备一致性检查:在训练前验证所有组件在同一设备上
- 设备选择函数:创建统一的设备选择逻辑
- 检查点验证:加载检查点后验证设备一致性
- 异常处理:添加设备不匹配时的恢复机制
python
def check_device_consistency(model, data, criterion):
"""检查所有组件是否在同一设备上"""
model_device = next(model.parameters()).device
data_device = data.device if hasattr(data, 'device') else 'cpu'
if model_device != data_device:
print(f"警告: 模型在 {model_device}, 数据在 {data_device}")
return False
return True
# 使用示例
input, target = next(iter(train_loader))
input, target = input.to(device), target.to(device)
if not check_device_consistency(model, input, criterion):
# 处理设备不一致情况
model = model.to(device)
criterion = criterion.to(device)
总结
PyTorch训练恢复时的设备不一致错误通常是由于模型、数据和优化器状态未正确同步到同一设备导致的。通过统一的设备管理策略、正确的检查点处理流程以及特殊情况下的设备同步,可以有效避免这一问题。
关键要点
- 创建模型后立即使用
.to(device)
指定设备 - 优化器应在模型转移到设备后创建
- 所有输入数据和损失函数参数都需要同步到设备
- 加载检查点后需要重新确认设备一致性
遵循这些最佳实践,可以确保PyTorch训练过程的设备一致性,避免因设备不匹配导致的运行时错误。