Fix GPU memory leak from loss tensor autograd retention#486

Open

bigximik wants to merge 1 commit into `main` from `denis/fix-loss-autograd-retention`
Conversation

@bigximik (Collaborator)

Summary

  • Fix: Detach total_loss in head.py and individual loss values in loss.py before appending to context.losses. These scalar tensors retained FunctionBackward grad_fn references from wrap_forward_backward, keeping C++ autograd nodes (and their CUDA tensor references) alive across all microbatches in a training step.
  • Impact: With depth_first_micro_batches >= 128, memory grew ~164 MiB per microbatch, causing OOM on 80 GB H100s. The RL team needs 4K+ microbatches per step.
  • Also: Add per-microbatch memory logging in the schedule runner, and free context.batch entries after the last forward stage.
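The shape of the fix can be sketched as follows (`record_loss` is a hypothetical helper, not the actual Fast-LLM code; the real change detaches in `head.py` and `loss.py` before appending to `context.losses`):

```python
import torch

def record_loss(losses: dict[str, list[torch.Tensor]], name: str, loss: torch.Tensor) -> None:
    # Detach before storing: the value is only read back via .item() for
    # logging, so dropping grad_fn is safe and releases the backward graph.
    losses.setdefault(name, []).append(loss.detach())

losses: dict[str, list[torch.Tensor]] = {}
x = torch.ones((), requires_grad=True)
record_loss(losses, "lm_loss", x * 2.0)
assert losses["lm_loss"][0].grad_fn is None  # no autograd chain retained
```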

Root Cause

total_loss returned from _logits_loss_forward_backward is the same tensor object that wrap_forward_backward() wraps with a custom autograd Function. PyTorch attaches a FunctionBackward grad_fn to the returned tensor in-place after Function.forward completes. Since total_loss is stored in context.losses (for logging), each microbatch's
losses dict entry holds a live grad_fn chain back to that microbatch's backward context (ctx.context in the Function), which references CUDA tensors (stage input/output pairs).

These C++ autograd nodes are invisible to Python's gc.get_objects() but consume real GPU memory — approximately 14 C++ CUDA allocations (~164 MiB) per microbatch. With 128 microbatches: 128 × 164 MiB ≈ 21 GiB of leaked autograd state.

The fix is safe because context.losses values are only used for logging — they are reduced to .item() scalars at step end in _reduce_losses.
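The retention mechanism can be reproduced in isolation. The sketch below uses a hypothetical `WrapLoss` Function standing in for `wrap_forward_backward` (the saved tensor mimics the stage input/output references); it is an illustration of the mechanism, not the repository code:

```python
import torch

class WrapLoss(torch.autograd.Function):
    """Hypothetical stand-in for wrap_forward_backward's custom Function."""

    @staticmethod
    def forward(ctx, loss, activations):
        # ctx holds a reference to the large tensor, as the real Function
        # holds the microbatch's backward context.
        ctx.save_for_backward(activations)
        return loss * 1.0  # apply() attaches a WrapLossBackward grad_fn here

    @staticmethod
    def backward(ctx, grad_out):
        (activations,) = ctx.saved_tensors
        return grad_out, torch.zeros_like(activations)

activations = torch.randn(1024, 1024, requires_grad=True)  # stands in for CUDA tensors
loss = WrapLoss.apply(torch.ones((), requires_grad=True), activations)

# Storing this scalar pins the whole backward graph, including the saved tensor...
assert loss.grad_fn is not None
# ...while a detached copy keeps only the value, which .item() still reads.
logged = loss.detach()
assert logged.grad_fn is None
assert logged.item() == 1.0
```

Note that the retained node lives in C++, which is why `gc.get_objects()` never shows it even though the GPU allocator does.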

Test Results

| Config | Microbatches | MB 0 end_alloc | Last MB end_alloc | Growth | Peak | Status |
| --- | --- | --- | --- | --- | --- | --- |
| Before fix | 128 | 46.6 GiB | OOM | +0.16 GiB/MB | 78+ GiB | OOM |
| After fix (mb128) | 128 | 46.69 GiB | 46.72 GiB | +0.03 GiB total | 60.78 GiB | 6 steps, flat memory |
| After fix (mb1024) | 1024 | 46.69 GiB | 46.93 GiB | +0.24 GiB total | 61.00 GiB | Step 1 complete |

Tested on Qwen2.5-7B-Instruct, 8× H100, SDP=2, ZeRO-2, 16K sequence length, full recompute.
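The per-microbatch memory logging added to the schedule runner is not shown here; a rough sketch of what such logging measures (hypothetical function name, not the actual runner code) might be:

```python
import torch

def log_micro_batch_memory(index: int) -> dict:
    """Report CUDA allocator stats after a micro-batch (zeros on CPU-only hosts)."""
    if torch.cuda.is_available():
        stats = {
            "end_alloc_gib": torch.cuda.memory_allocated() / 2**30,
            "reserved_gib": torch.cuda.memory_reserved() / 2**30,
            "peak_gib": torch.cuda.max_memory_allocated() / 2**30,
        }
    else:
        stats = {"end_alloc_gib": 0.0, "reserved_gib": 0.0, "peak_gib": 0.0}
    print(f"micro-batch {index}: " + ", ".join(f"{k}={v:.2f} GiB" for k, v in stats.items()))
    return stats
```

A flat `end_alloc` across micro-batches is exactly the signal the table above reports for the fixed runs.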

Test plan

  • Verified mb128 runs 6 training steps with flat memory (62,240 MiB max allocated, identical across steps)
  • Verified mb1024 completes step 1 (62,462 MiB max allocated)
  • Projection: 4K microbatches → ~1 GiB residual growth → well within 80 GiB H100 budget
  • Run existing Fast-LLM test suite
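The 4K-microbatch projection follows from linearly extrapolating the residual growth measured in the mb1024 run:

```python
# Linear extrapolation of residual growth, assuming it scales with the
# micro-batch count (values taken from the mb1024 run above).
measured_growth_gib = 0.24    # total growth over 1024 micro-batches
measured_micro_batches = 1024
target_micro_batches = 4096   # the RL team's requirement

projected_gib = measured_growth_gib * target_micro_batches / measured_micro_batches
print(f"projected residual growth: {projected_gib:.2f} GiB")  # ~1 GiB
```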

…tention across microbatches

Loss scalars stored in context.losses retained FunctionBackward grad_fn
references from wrap_forward_backward, keeping C++ autograd nodes and
their CUDA tensor references alive across all microbatches. This caused
~164 MiB/microbatch growth, leading to OOM with depth_first_micro_batches>=128.

Fix: .detach() total_loss in head.py and individual losses in loss.py
before appending to the losses dict. These values are only used for
logging (reduced to .item() at step end), so detaching is safe.

Also adds per-microbatch memory logging in the schedule runner and
frees batch data after the last forward stage.
@bigximik bigximik requested a review from jlamypoirier April 16, 2026 09:12