
Add MoE and MLA remat policies #3414

Open

abhinavgoel95 wants to merge 2 commits into AI-Hypercomputer:main from abhinavgoel95:abgoel/add-moe-mla-remat-policies

Conversation


@abhinavgoel95 abhinavgoel95 commented Mar 13, 2026

  • Added moe_mlpwi, moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo for MoE layers
  • Added query_wa_proj, kv_wa_proj for MLA layers
  • Updated base.yml, types.py, and pyconfig_deprecated.py

Description

This PR adds rematerialization policy support for Mixture of Experts (MoE) and Multi-head Latent Attention (MLA) layer tensors.

Previously, MaxText only supported remat policies for standard dense layer tensors. This prevented fine-grained memory optimization for MoE models (like Mixtral, DeepSeek V3) and models using MLA architecture (like DeepSeek V3).

This change adds six new configurable remat tensors:

  • MoE tensors: moe_mlpwi, moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo
  • MLA tensors: query_wa_proj, kv_wa_proj

Users can now configure these tensors with device, offload, or remat policies in their config files, enabling better memory management for large MoE models (e.g., DeepSeek V3 671B).

Files modified:

  • src/maxtext/configs/base.yml - Added default 'remat' values
  • src/maxtext/configs/types.py - Added Field definitions with descriptions
  • src/maxtext/configs/pyconfig_deprecated.py - Added to validation whitelist

All new tensors default to 'remat', maintaining backward compatibility.
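
The resulting base.yml entries presumably look like the following sketch (key names are from this PR; the exact file layout and comments are assumptions):

```yaml
# Sketch of the new base.yml defaults (layout assumed; names from this PR).
# Each tensor accepts 'remat', 'device', or 'offload'.
moe_mlpwi: 'remat'
moe_mlpwi_0: 'remat'
moe_mlpwi_1: 'remat'
moe_mlpwo: 'remat'
query_wa_proj: 'remat'
kv_wa_proj: 'remat'
```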

Tests

Tested with DeepSeek V3 671B (41 layers) on 128 GPUs with various remat configurations:

  • Baseline with all tensors set to remat - ✅ Works
  • Custom policies with selective offload and device placement - ✅ Works
  • Verified backward compatibility with Llama models (no regression)

Example config usage:

```yaml
moe_mlpwi: 'offload'
query_wa_proj: 'device'
```
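
Under the hood, per-tensor settings like these are typically split into a "save on device" list and an "offload to host" list before being handed to an offload-aware JAX checkpoint policy such as `jax.ad_checkpoint.checkpoint_policies.save_and_offload_only_these_names`. A minimal stand-alone sketch (the helper name and the example dict are illustrative, not MaxText's actual code):

```python
# Hypothetical helper: partition per-tensor remat settings into the name
# lists an offload-aware checkpoint policy expects. Anything left as
# 'remat' is simply recomputed in the backward pass.
def partition_remat_config(remat_config):
    saved = [n for n, p in sorted(remat_config.items()) if p == "device"]
    offloaded = [n for n, p in sorted(remat_config.items()) if p == "offload"]
    return saved, offloaded

saved, offloaded = partition_remat_config({
    "moe_mlpwi": "offload",    # large MoE input projection: park on host
    "moe_mlpwo": "remat",      # cheap enough to recompute
    "query_wa_proj": "device", # keep the MLA query latent in HBM
})
# saved -> ["query_wa_proj"], offloaded -> ["moe_mlpwi"]
```

The two lists would then be passed as `names_which_can_be_saved` and `names_which_can_be_offloaded` to the JAX policy constructor.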

Checklist

Before submitting this PR, please make sure (put X in square brackets):

- [ ] I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
- [ ] I have necessary comments in my code, particularly in hard-to-understand areas.
- [ ] I have run end-to-end tests and provided workload links above if applicable.
- [ ] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files.

Commit 1:
- Added moe_mlpwi, moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo for MoE layers
- Added query_wa_proj, kv_wa_proj for MLA layers
- Updated base.yml, types.py, and pyconfig_deprecated.py

Commit 2: Wire up the remat config keys added in 717ddf5 to actual checkpoint_name call sites in the layer code:
- moe.py: rename mlpwi_0/1 and mlpwo -> moe_mlpwi_0/1 and moe_mlpwo (9 sites)
- attention_mla.py: add query_wa_proj after wq_a and kv_wa_proj after wkv_a
```python
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
if self.config.mlp_bias:
    layer_w0 = layer_w0 + w0_bias
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
```
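
The second commit's MLA wiring can be sketched as below. Here `checkpoint_name` is a pure-Python stand-in for `jax.ad_checkpoint.checkpoint_name` (functionally an identity that lets a remat policy match the tensor by name), and the down-projections are reduced to scalar multiplies for illustration; the function and argument names are assumptions, not the actual attention_mla.py code:

```python
def checkpoint_name(x, name):
    # Stand-in for jax.ad_checkpoint.checkpoint_name: identity that tags
    # an intermediate so a checkpoint policy can match it by name.
    return x

def mla_down_proj(hidden, wq_a, wkv_a):
    # Hypothetical shape of the attention_mla.py change: tag the query
    # latent right after wq_a and the KV latent right after wkv_a.
    q_latent = checkpoint_name(hidden * wq_a, "query_wa_proj")
    kv_latent = checkpoint_name(hidden * wkv_a, "kv_wa_proj")
    return q_latent, kv_latent

q, kv = mla_down_proj(2.0, 3.0, 5.0)
# q -> 6.0, kv -> 10.0 (values unchanged; only the names are added)
```

With a policy like the one configured above, the tagged latents can then be kept on device or offloaded instead of being recomputed.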
A collaborator commented:

Heads up: this might affect all legacy TPU recipes/performance for MoE models. We should make an announcement after it gets merged. Thanks!

