
[Bug Report] Issues with PosEmbed device when used with accelerate #911

@davidquarel

Description


Describe the bug
PosEmbed indexes W_pos with position indices on the wrong device when a HookedTransformer is run on multiple GPUs through the accelerate library.

Code example

The following code is a minimal training loop that trains gpt2 on random tokens. Each sequence has an attention mask that zeroes out every position after a random length between 90 and 99.

import torch
from transformer_lens import HookedTransformer
from accelerate import Accelerator
from tqdm import tqdm
import sys
print(f"Python version: {sys.version}")
accelerator = Accelerator()
print(f"Running on device: {accelerator.device}")

model_name = "gpt2"
model = HookedTransformer.from_pretrained(model_name)

tokens = torch.randint(0, model.tokenizer.vocab_size, (64, 100))
random_lengths = torch.randint(90, 100, (64,))
attention_mask = torch.ones_like(tokens)
for i in range(64):
    attention_mask[i, random_lengths[i]:] = 0

dataset = torch.utils.data.TensorDataset(tokens, attention_mask)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer)

for batch in tqdm(dataloader):
    tokens, attention_mask = batch
    loss = model.forward(tokens, attention_mask=attention_mask, return_type="loss")
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    
    avg_loss = accelerator.gather(loss).mean().item()
    if accelerator.is_main_process:
        print(f"Loss: {avg_loss}")

Running the script as a single process with python minimal_example.py works fine.
Launching it on two GPUs with CUDA_VISIBLE_DEVICES=1,2 accelerate launch minimal_example.py fails on rank 1 with the following error:

(py310) guest@All:~/david_quarel/advint$ CUDA_VISIBLE_DEVICES=1,2 accelerate launch minimal_inference.py 
The following values were not passed to `accelerate launch` and had defaults used instead:
        `--num_processes` was set to a value of `2`
                More than one GPU was found, enabling multi-GPU training.
                If this was unintended please pass in `--num_processes=1`.
        `--num_machines` was set to a value of `1`
        `--mixed_precision` was set to a value of `'no'`
        `--dynamo_backend` was set to a value of `'no'`
To avoid this warning pass in values for each of the problematic parameters or run `accelerate config`.
Python version: 3.10.17 | packaged by conda-forge | (main, Apr 10 2025, 22:19:12) [GCC 13.3.0]
Python version: 3.10.17 | packaged by conda-forge | (main, Apr 10 2025, 22:19:12) [GCC 13.3.0]
Running on device: cuda:0
Running on device: cuda:1
Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda
Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda
  0%|          | 0/4 [00:00<?, ?it/s]
[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/HOME/guest/david_quarel/advint/minimal_inference.py", line 32, in <module>
[rank1]:     loss = model.forward(tokens, attention_mask=attention_mask, return_type="loss")
[rank1]:   File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1643, in forward
[rank1]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank1]:   File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1459, in _run_ddp_forward
[rank1]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank1]:   File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/workspace/HOME/guest/.conda/envs/py310/lib/python3.10/site-packages/transformer_lens/HookedTransformer.py", line 583, in forward
[rank1]:     ) = self.input_to_embed(
[rank1]:   File "/workspace/HOME/guest/.conda/envs/py310/lib/python3.10/site-packages/transformer_lens/HookedTransformer.py", line 410, in input_to_embed
[rank1]:     residual, shortformer_pos_embed = self.get_residual(
[rank1]:   File "/workspace/HOME/guest/.conda/envs/py310/lib/python3.10/site-packages/transformer_lens/HookedTransformer.py", line 302, in get_residual
[rank1]:     self.pos_embed(tokens, pos_offset, attention_mask)
[rank1]:   File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/workspace/HOME/guest/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/workspace/HOME/guest/.conda/envs/py310/lib/python3.10/site-packages/transformer_lens/components/pos_embed.py", line 58, in forward
[rank1]:     pos_embed = self.W_pos[offset_position_ids]  # [batch, pos, d_model]
[rank1]: RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)
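One possible explanation has not yet been confirmed against the TransformerLens source. Both processes log "Moving model to device:  cuda", so the model config stores a device with no GPU index. Suppose TransformerLens derives the device for its inputs (the attention mask, and hence the position indices built from it) from that string by treating a missing index as 0. Then rank 0 would be unaffected, while rank 1, whose weights accelerator.prepare placed on cuda:1, would receive indices on cuda:0. The pure-Python sketch below models only that hypothesis. parse_device, resolve_block_device and mismatched_ranks are hypothetical helpers, not TransformerLens functions. layers_per_device=12 simply assumes gpt2's 12 layers on a single device.

```python
from __future__ import annotations

from typing import Callable, Optional


def parse_device(device: str) -> tuple[str, Optional[int]]:
    """Split "cuda:1" into ("cuda", 1) and an index-less "cuda" into ("cuda", None)."""
    kind, _, index = device.partition(":")
    return kind, (int(index) if index else None)


def resolve_block_device(cfg_device: str, block_index: int = 0, layers_per_device: int = 12) -> str:
    """HYPOTHETICAL model of how a per-block device could be derived from cfg.device.

    Assumption being illustrated: a missing device index is treated as 0,
    no matter which GPU the current process was assigned.
    """
    kind, index = parse_device(cfg_device)
    if kind == "cpu":
        return "cpu"
    return f"{kind}:{(index or 0) + block_index // layers_per_device}"


def mismatched_ranks(cfg_device_for_rank: Callable[[int], str], world_size: int) -> list[int]:
    """Return the ranks whose inputs would land on a different GPU from their weights."""
    bad = []
    for rank in range(world_size):
        # One GPU per process after accelerator.prepare (cuda:0 and cuda:1 in the log).
        weights_device = f"cuda:{rank}"
        # Device the inputs would be moved to under the hypothesis above.
        inputs_device = resolve_block_device(cfg_device_for_rank(rank))
        if inputs_device != weights_device:
            bad.append(rank)
    return bad


# cfg.device is "cuda" on every rank, as in the log: only rank 1 mismatches.
print(mismatched_ranks(lambda rank: "cuda", world_size=2))          # -> [1]
# An explicit per-process device string removes the mismatch.
print(mismatched_ranks(lambda rank: f"cuda:{rank}", world_size=2))  # -> []
```

If the hypothesis holds, one untested workaround would be to load the model straight onto each process's GPU with HookedTransformer.from_pretrained(model_name, device=accelerator.device) before calling accelerator.prepare. That way the config device carries the correct index.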

System Info
Describe the characteristics of your environment:

  • transformer_lens version 2.15.0 installed via pip
  • OS: Linux
  • Python version: 3.10.17

Additional context
I ran into this bug while trying to train an SAE on a transformer_lens model across multiple GPUs with accelerate.

Checklist

  • I have checked that there is no similar issue in the repo (required)
