Resolving PyTorch CUDA-CPU Device Mismatch Errors
When resuming training from a saved PyTorch checkpoint, you might encounter the error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
This error occurs when some of the tensors involved in a computation live on the GPU while others remain on the CPU, so PyTorch cannot combine them.
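A quick way to locate the mismatch is to inspect the .device attribute of the tensors and parameters involved; the following is a minimal diagnostic sketch, where model, inputs, and targets stand in for your own objects:

import torch

# Model parameters report the device they currently live on
print(next(model.parameters()).device)   # e.g. cuda:0

# Plain tensors do the same
print(inputs.device, targets.device)     # e.g. cpu cpu -> the mismatch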
Common Causes and Solutions
1. Ensure Model and Data Are on the Same Device
The most frequent cause is inconsistent device placement between your model, data, and optimizer. Use .to(device) consistently:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move model to device
model = YourModel().to(device)

# Move each batch to the same device inside the training loop
for i, (inputs, targets) in enumerate(train_loader):
    inputs, targets = inputs.to(device), targets.to(device)
2. Handle Optimizer State Properly
Construct your optimizer after moving the model to the appropriate device:
# Correct order: model first, then optimizer
model = YourModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
When resuming training, ensure you move the model to the device before loading the optimizer state:
checkpoint = torch.load('checkpoint.pth', map_location=device)  # map saved tensors onto the target device
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)  # move the model before restoring the optimizer state
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
3. Check Tokenizers and Input Processing
For NLP models, ensure tokenized inputs are moved to the correct device:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("model_name").to(device)
tokenizer = AutoTokenizer.from_pretrained("model_name")
# Don't forget to move tokenized inputs to device
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
4. Verify Loss Function Parameters
Some loss functions have parameters that need explicit device placement:
# Incorrect - tensor remains on CPU
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([3]))
# Correct - explicitly move to device
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([3]).to(device))
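In recent PyTorch releases, pos_weight is registered as a buffer on the loss module, so moving the constructed module should carry the tensor along as well; a hedged alternative:

# Assumes pos_weight is a registered buffer (true for current nn.BCEWithLogitsLoss)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0])).to(device)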
5. Apple MPS Device Support
For Apple Silicon (M1/M2) users, use the Metal Performance Shaders backend:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    model.to(device)
Complete Training Setup Example
import torch
import torch.nn as nn

def setup_training():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model and move it to the device
    model = YourModel().to(device)

    # Initialize the optimizer after the model is on the device
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Initialize the criterion (with device-aware parameters if needed)
    criterion = YourCriterion().to(device)

    return model, optimizer, criterion, device

def train_epoch(model, optimizer, criterion, train_loader, device):
    model.train()
    for inputs, targets in train_loader:
        # Move each batch to the same device as the model
        inputs, targets = inputs.to(device), targets.to(device)

        # Standard training steps
        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
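A minimal driver that wires these together might look like this (YourModel, YourCriterion, train_loader, and num_epochs are placeholders, not part of any library):

model, optimizer, criterion, device = setup_training()

for epoch in range(num_epochs):
    train_epoch(model, optimizer, criterion, train_loader, device)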
Checkpoint Saving and Loading Best Practices
def save_checkpoint(model, optimizer, epoch, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path, device):
    checkpoint = torch.load(path, map_location=device)  # map saved tensors onto the target device
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)  # critical: move the model to the device first
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']
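Resuming then reduces to loading the checkpoint and continuing the loop from the saved epoch (the checkpoint path and num_epochs here are illustrative):

start_epoch = load_checkpoint(model, optimizer, 'checkpoint.pth', device)

for epoch in range(start_epoch + 1, num_epochs):
    train_epoch(model, optimizer, criterion, train_loader, device)
    save_checkpoint(model, optimizer, epoch, 'checkpoint.pth')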
WARNING
Always move your model to the target device before loading the optimizer state. The optimizer's state tensors are placed relative to the model parameters at load time, so loading the state while the model is still on the CPU leaves that state on the CPU even after you later move the model to the GPU, and the first optimizer.step() then fails with a device mismatch.
TIP
Use a consistent device management strategy throughout your code. Consider creating a device variable at the beginning and referencing it whenever moving tensors or models.
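As a sketch of that strategy, a single helper (the name get_device is just illustrative) can pick CUDA, Apple MPS, or CPU once and be reused everywhere:

import torch

def get_device():
    # Prefer CUDA, then Apple Silicon's MPS backend, then fall back to CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = get_device()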
By following these practices, you can avoid device mismatch errors and ensure smooth training resumption across CPU and GPU environments.