
Resolving PyTorch CUDA-CPU Device Mismatch Errors

When resuming training from a saved PyTorch checkpoint, you might encounter the error: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!. It means that a single operation received tensors living on different devices: some on the GPU while others remain in CPU memory, so the computation cannot proceed.
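
A minimal sketch that reproduces the error (assuming a CUDA-capable machine; the model and tensor shapes are arbitrary):

python
import torch
import torch.nn as nn

device = torch.device("cuda")
model = nn.Linear(4, 2).to(device)  # weights now live on cuda:0
x = torch.randn(1, 4)               # created on the CPU by default

model(x)  # RuntimeError: Expected all tensors to be on the same device...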

Common Causes and Solutions

1. Ensure Model and Data Are on the Same Device

The most frequent cause is inconsistent device placement between your model, data, and optimizer. Use .to(device) consistently:

python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move model to device
model = YourModel().to(device)

# Move each batch to the same device
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
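
A quick way to confirm placement is to inspect the .device attribute of the model's parameters and of each batch, for example:

python
print(next(model.parameters()).device)  # e.g. cuda:0
print(inputs.device, targets.device)    # should match the model's device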

2. Handle Optimizer State Properly

Construct your optimizer after moving the model to the appropriate device:

python
# Correct order: model first, then optimizer
model = YourModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

When resuming training, load the checkpoint with map_location and move the model to the device before loading the optimizer state:

python
# map_location remaps GPU-saved tensors onto the current device
checkpoint = torch.load('checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)  # Move model before loading optimizer state
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

3. Check Tokenizers and Input Processing

For NLP models, ensure tokenized inputs are moved to the correct device:

python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("model_name").to(device)
tokenizer = AutoTokenizer.from_pretrained("model_name")

# Don't forget to move tokenized inputs to device
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
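
With the inputs and the model on the same device, downstream calls succeed; a brief usage sketch, assuming a generation-capable checkpoint:

python
output_ids = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))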

4. Verify Loss Function Parameters

Some loss functions have parameters that need explicit device placement:

python
# Incorrect - the pos_weight tensor is created on the CPU
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))

# Correct - create the tensor directly on the target device
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0], device=device))

5. Apple MPS Device Support

For Apple Silicon (M1/M2) users, use the Metal Performance Shaders backend:

python
if torch.backends.mps.is_available():
    device = torch.device("mps")
    model.to(device)
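
You can fold all three backends into a single selection helper; a small sketch (the get_device name is just illustrative):

python
def get_device():
    # Prefer CUDA, fall back to Apple's MPS, then to the CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = get_device()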

Complete Training Setup Example

python
def setup_training():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Initialize model and move to device
    model = YourModel().to(device)
    
    # Initialize optimizer after model is on device
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    
    # Initialize criterion (with device-aware parameters if needed)
    criterion = YourCriterion().to(device)
    
    return model, optimizer, criterion, device

def train_epoch(model, optimizer, criterion, train_loader, device):
    model.train()
    for inputs, targets in train_loader:
        # Move each batch to the model's device
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Standard training steps
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
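
Putting the two helpers together (train_loader and num_epochs are assumed to be defined elsewhere):

python
model, optimizer, criterion, device = setup_training()
for epoch in range(num_epochs):
    train_epoch(model, optimizer, criterion, train_loader, device)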

Checkpoint Saving and Loading Best Practices

python
def save_checkpoint(model, optimizer, epoch, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path, device):
    # map_location remaps GPU-saved tensors onto the current device
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)  # Critical: move model to device first
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']
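
Typical usage when resuming (picking up at the epoch after the saved one):

python
save_checkpoint(model, optimizer, epoch, 'checkpoint.pth')
start_epoch = load_checkpoint(model, optimizer, 'checkpoint.pth', device) + 1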

WARNING

Always move your model to the target device before loading the optimizer state. The optimizer's state tensors (for example, Adam's momentum buffers) must live on the same device as the parameters they update.
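
If an optimizer state still ends up on the wrong device after loading, a common workaround is to move its tensors by hand; a minimal sketch:

python
# Move every optimizer state tensor (e.g. Adam's momentum buffers) to the target device
for state in optimizer.state.values():
    for key, value in state.items():
        if isinstance(value, torch.Tensor):
            state[key] = value.to(device)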

TIP

Use a consistent device management strategy throughout your code: create a single device variable at the start and reference it every time you move a tensor or model.

By following these practices, you can avoid device mismatch errors and ensure smooth training resumption across CPU and GPU environments.