Add fused_mla_lora_proj config flag to fuse MLA LoRA up-projections#3424

Open
abhinavgoel95 wants to merge 1 commit into AI-Hypercomputer:main from abhinavgoel95:abgoel/fused-mla-lora-proj
Conversation

@abhinavgoel95
Contributor

@abhinavgoel95 abhinavgoel95 commented Mar 16, 2026

Add a fused_mla_lora_proj boolean config (default False) that fuses the separate wq_a (emb→q_lora_rank) and wkv_a (emb→kv_lora_rank+rope_head_dim) MLA LoRA up-projection matmuls into a single wq_kv_a matmul (emb→q_lora_rank+kv_lora_rank+rope_head_dim), followed by a split.

This halves the number of matmul kernel launches for the LoRA up-projection step. The flag is modelled after the existing fused_qkv config and requires attention_type='mla' and q_lora_rank > 0.

Note: wq_kv_a uses a different weight name than wq_a/wkv_a, so checkpoints are not cross-compatible between fused and unfused modes.

This results in a 3% speedup in MLA computation on GB200 GPUs.

Description

Add a fused_mla_lora_proj boolean config flag (default False) that fuses
the two separate MLA LoRA up-projection matmuls — wq_a (emb → q_lora_rank)
and wkv_a (emb → kv_lora_rank + rope_head_dim) — into a single wq_kv_a
matmul (emb → q_lora_rank + kv_lora_rank + rope_head_dim), followed by a
jnp.split. This halves the number of matmul kernel launches for the LoRA
up-projection step when q_lora_rank > 0, which is the common case for MLA
models like DeepSeek-V3.
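
The fusion described above can be sketched in JAX as follows. This is a minimal illustration, not the actual MaxText implementation; the dimensions are toy values and the variable names merely mirror those in the PR description.

```python
import jax
import jax.numpy as jnp

# Toy dimensions standing in for (emb, q_lora_rank, kv_lora_rank, rope_head_dim).
emb, q_lora_rank, kv_lora_rank, rope_head_dim = 32, 8, 4, 2

key = jax.random.PRNGKey(0)
k1, k2, kx = jax.random.split(key, 3)
wq_a = jax.random.normal(k1, (emb, q_lora_rank))
wkv_a = jax.random.normal(k2, (emb, kv_lora_rank + rope_head_dim))
x = jax.random.normal(kx, (3, emb))  # a batch of 3 token embeddings

# Unfused path: two separate matmuls (two kernel launches).
q = x @ wq_a
kv = x @ wkv_a

# Fused path: concatenate the weights along the output dim, do one
# matmul, then split the result back into the Q and KV LoRA activations.
wq_kv_a = jnp.concatenate([wq_a, wkv_a], axis=1)
q_fused, kv_fused = jnp.split(x @ wq_kv_a, [q_lora_rank], axis=-1)

assert jnp.allclose(q, q_fused, atol=1e-5)
assert jnp.allclose(kv, kv_fused, atol=1e-5)
```

Because matrix multiplication distributes over the column-wise concatenation of the weight matrices, the fused and unfused paths are mathematically equivalent; only the number of kernel launches changes.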

The flag is modelled after the existing fused_qkv config:

  • configs/base.yml: default False
  • configs/types.py: field definition + validation (requires attention_type='mla'
    and q_lora_rank > 0)
  • layers/attention_mla.py: new elif self.config.fused_mla_lora_proj branch in
    _init_projections initialises wq_kv_a instead of wq_a/wkv_a; fused
    dispatch in __call__; passthrough guards in mla_query_projection and
    mla_kv_projection
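
The validation step in configs/types.py could look roughly like the sketch below. The function name and signature are hypothetical; only the two checks (attention_type='mla' and q_lora_rank > 0) come from the PR description.

```python
def validate_fused_mla_lora_proj(fused_mla_lora_proj: bool,
                                 attention_type: str,
                                 q_lora_rank: int) -> None:
  """Hypothetical sketch of the config checks described in this PR."""
  if not fused_mla_lora_proj:
    return  # nothing to validate when the flag is off
  if attention_type != "mla":
    raise ValueError("fused_mla_lora_proj requires attention_type='mla'")
  if q_lora_rank <= 0:
    raise ValueError("fused_mla_lora_proj requires q_lora_rank > 0")
```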

Shortcoming / compatibility note: wq_kv_a uses a different weight name than
wq_a/wkv_a, so checkpoints are not cross-compatible between
fused_mla_lora_proj=True and False. A future improvement would be a
checkpoint conversion utility.
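
Such a conversion utility could be as simple as splitting the fused weight back into its two halves. This is a hypothetical sketch: the flat dict layout and parameter names are illustrative, not the real MaxText checkpoint structure.

```python
import jax.numpy as jnp

def split_fused_weight(params: dict, q_lora_rank: int) -> dict:
  """Convert a fused checkpoint (wq_kv_a) to unfused form (wq_a, wkv_a).

  Assumes a flat {name: array} layout for illustration only.
  """
  wq_kv_a = params.pop("wq_kv_a")
  # First q_lora_rank output columns belong to the Q projection,
  # the rest to the KV projection (kv_lora_rank + rope_head_dim).
  wq_a, wkv_a = jnp.split(wq_kv_a, [q_lora_rank], axis=-1)
  params["wq_a"] = wq_a
  params["wkv_a"] = wkv_a
  return params
```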

Tests

Tested via config validation:

  • fused_mla_lora_proj=True, q_lora_rank=0 raises ValueError
  • fused_mla_lora_proj=True, attention_type='global' raises ValueError

End-to-end numerical equivalence between fused and unfused paths with a
DeepSeek-V3 config (q_lora_rank=1536, kv_lora_rank=512, qk_rope_head_dim=64) has been verified.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/fused-mla-lora-proj branch from 8e0fcfb to be3169b Compare March 16, 2026 18:51
@codecov

codecov bot commented Mar 16, 2026

Codecov Report

❌ Patch coverage is 82.60870% with 4 lines in your changes missing coverage. Please review.

Files with missing lines               | Patch %  | Lines
src/maxtext/layers/attention_mla.py    | 82.60%   | 2 Missing and 2 partials ⚠️


Collaborator

@RissyRan RissyRan left a comment

Thanks for the change! Could you please add one unit test in this file to verify that the outputs are the same with and without this config?

@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/fused-mla-lora-proj branch from 89a0056 to 14328b4 Compare March 18, 2026 19:43
Fuses the Q and KV LoRA up-projections (wq_a + wkv_a) into a single
matmul (wq_kv_a: emb → q_lora_rank + kv_lora_rank + rope_head_dim),
halving the number of kernel launches for the LoRA up-projection step.

Enabled via fused_mla_lora_proj: True (requires q_lora_rank > 0 and
attention_type=mla). Modeled after the existing fused_qkv flag.

Includes a unit test verifying that fused and unfused paths produce
numerically identical outputs given equivalent weights.
@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/fused-mla-lora-proj branch from 14328b4 to 3b46b3c Compare March 19, 2026 19:04