Add MoE and MLA remat policies#3414
Open
abhinavgoel95 wants to merge 2 commits intoAI-Hypercomputer:mainfrom
Open
Add MoE and MLA remat policies#3414abhinavgoel95 wants to merge 2 commits intoAI-Hypercomputer:mainfrom
abhinavgoel95 wants to merge 2 commits intoAI-Hypercomputer:mainfrom
Conversation
- Added moe_mlpwi, moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo for MoE layers - Added query_wa_proj, kv_wa_proj for MLA layers - Updated base.yml, types.py, and pyconfig_deprecated.py
Wire up the remat config keys added in 717ddf5 to actual checkpoint_name call sites in the layer code: - moe.py: rename mlpwi_0/1 and mlpwo -> moe_mlpwi_0/1 and moe_mlpwo (9 sites) - attention_mla.py: add query_wa_proj after wq_a and kv_wa_proj after wkv_a
RissyRan
reviewed
Mar 14, 2026
| layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose") | ||
| if self.config.mlp_bias: | ||
| layer_w0 = layer_w0 + w0_bias | ||
| layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0") |
Collaborator
There was a problem hiding this comment.
Heads up: this might affect all legacy TPU recipes/performance for MoE models. We should make an announcement after it gets merged. Thanks!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR adds rematerialization policy support for Mixture of Experts (MoE) and Multi-head Latent Attention (MLA) layer tensors.
Previously, MaxText only supported remat policies for standard dense layer tensors. This prevented fine-grained memory optimization for MoE models (like Mixtral, DeepSeek V3) and models using MLA architecture (like DeepSeek V3).
This change adds six new configurable remat tensors:
moe_mlpwi,moe_mlpwi_0,moe_mlpwi_1,moe_mlpwoquery_wa_proj,kv_wa_projUsers can now configure these tensors with
device,offload, orrematpolicies in their config files, enabling better memory management for large MoE models (e.g., DeepSeek V3 671B).Files modified:
src/maxtext/configs/base.yml- Added default'remat'valuessrc/maxtext/configs/types.py- Added Field definitions with descriptionssrc/maxtext/configs/pyconfig_deprecated.py- Added to validation whitelistAll new tensors default to
'remat', maintaining backward compatibility.Tests
Tested with DeepSeek V3 671B (41 layers) on 128 GPUs with various remat configurations:
remat- ✅ Worksoffloadanddeviceplacement - ✅ WorksExample config usage:
Checklist
Before submitting this PR, please make sure (put X in square brackets):