Skip to content

PyTorch 训练恢复时解决设备不一致错误

问题描述

在使用 PyTorch 进行 GPU 训练时,保存检查点后重新加载继续训练,可能会遇到以下错误:

python
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

这个错误表明在计算过程中,部分张量位于 GPU (cuda:0) 上,而部分张量仍位于 CPU 上,导致设备不一致。

根本原因

设备不一致问题通常由以下几个原因引起:

  1. 模型和设备未正确关联:模型未正确移动到目标设备
  2. 优化器状态存储位置问题:优化器状态与模型不在同一设备上
  3. 输入数据未转移到设备:数据加载后忘记转移到 GPU
  4. 损失函数参数未同步:某些损失函数的参数(如权重)未同步到设备

解决方案

方案一:统一设备管理(推荐)

使用统一的设备管理方法,确保所有组件在同一设备上:

python
import torch

# 定义设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 将模型转移到设备
model = YourModel().to(device)

# 将数据转移到设备
for input, target in train_loader:
    input, target = input.to(device), target.to(device)
    # 训练代码...

方案二:正确保存和加载检查点

保存检查点时,PyTorch 会记录模型和优化器的状态,但不记录设备信息。加载时需要重新指定设备:

python
# 保存检查点
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'checkpoint.pth')

# 加载检查点
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# 重新将模型和优化器转移到设备
model = model.to(device)

重要提示

优化器应在模型转移到最终设备后创建,因为优化器会引用模型的参数位置。

方案三:处理特殊情况的设备转移

某些特定场景需要额外注意设备同步:

1. 分词器输出需要手动转移

python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("model_name").to(device)
tokenizer = AutoTokenizer.from_pretrained("model_name")

# 必须将分词结果也转移到设备
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

2. 损失函数的参数需要转移

python
# 错误的做法:权重张量在CPU上
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([3]))

# 正确的做法:权重张量也在设备上
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([3]).to(device))

3. 自定义损失函数需要转移

python
class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.Tensor([1.0, 2.0])  # 需要转移到设备
        
    def forward(self, input, target):
        # 计算损失
        return loss

# 使用前确保权重在正确设备上
criterion = CustomLoss().to(device)

方案四:MPS设备支持(Apple Silicon)

对于使用Apple Silicon芯片的Mac用户:

python
# 检查MPS可用性
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("使用MPS设备")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
model = model.to(device)

最佳实践

  1. 设备一致性检查:在训练前验证所有组件在同一设备上
  2. 设备选择函数:创建统一的设备选择逻辑
  3. 检查点验证:加载检查点后验证设备一致性
  4. 异常处理:添加设备不匹配时的恢复机制
python
def check_device_consistency(model, data, criterion):
    """检查所有组件是否在同一设备上"""
    model_device = next(model.parameters()).device
    data_device = data.device if hasattr(data, 'device') else 'cpu'
    
    if model_device != data_device:
        print(f"警告: 模型在 {model_device}, 数据在 {data_device}")
        return False
    return True

# 使用示例
input, target = next(iter(train_loader))
input, target = input.to(device), target.to(device)
if not check_device_consistency(model, input, criterion):
    # 处理设备不一致情况
    model = model.to(device)
    criterion = criterion.to(device)

总结

PyTorch训练恢复时的设备不一致错误通常是由于模型、数据和优化器状态未正确同步到同一设备导致的。通过统一的设备管理策略、正确的检查点处理流程以及特殊情况下的设备同步,可以有效避免这一问题。

关键要点

  • 创建模型后立即使用.to(device)指定设备
  • 优化器应在模型转移到设备后创建
  • 所有输入数据和损失函数参数都需要同步到设备
  • 加载检查点后需要重新确认设备一致性

遵循这些最佳实践,可以确保PyTorch训练过程的设备一致性,避免因设备不匹配导致的运行时错误。