Enable training strategy for Indexer #3415

Open

RissyRan wants to merge 1 commit into main from indexer_train_strategy

Conversation


@RissyRan RissyRan commented Mar 14, 2026

Description

Enable a selective parameter training strategy for the DeepSeek V3.2 Indexer - paper

  • Dense warm up stage:
    • Add trainable_parameters_mask flag, allowing specific parameters to be targeted for training while freezing the rest of the model.
    • Add TrainableParametersMaskTest unit tests for validation.
  • Sparse training stage:
    • Add sparse_indexer_training flag to select between the Dense Warm-up stage and the Sparse Training stage for DeepSeek V3.2.
    • Add test_indexer_gradients unit test to verify proper gradient isolation.
  • Renaming flags to avoid confusion
    • use_sparse_indexer --> use_indexer; index_head_dim --> indexer_head_dim; index_n_heads --> indexer_n_heads; index_topk --> indexer_topk
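As a rough illustration of what a trainable_parameters_mask-style flag can do (this is a hedged sketch with invented helper names, not MaxText's actual implementation): parameter paths matching a regex stay trainable, and gradients for every other parameter are zeroed so those weights are frozen.

```python
# Minimal sketch, assuming a regex-over-parameter-path masking scheme.
# `build_trainable_mask` and `apply_mask` are hypothetical names.
import re
import jax
import jax.numpy as jnp

def build_trainable_mask(params, pattern):
  """Pytree of bools: True where the parameter's path matches `pattern`."""
  regex = re.compile(pattern)
  def is_trainable(path, _leaf):
    name = "/".join(str(k) for k in path)
    return regex.search(name) is not None
  return jax.tree_util.tree_map_with_path(is_trainable, params)

def apply_mask(grads, mask):
  """Zero out gradients for frozen parameters, leave the rest untouched."""
  return jax.tree_util.tree_map(
      lambda g, m: g if m else jnp.zeros_like(g), grads, mask)

# Toy example: only indexer parameters remain trainable.
params = {"indexer": {"w": jnp.ones(3)}, "mlp": {"w": jnp.ones(3)}}
grads = jax.tree_util.tree_map(jnp.ones_like, params)
mask = build_trainable_mask(params, r"indexer")
masked = apply_mask(grads, mask)
```

With this toy mask, the mlp gradients come back all-zero while the indexer gradients pass through unchanged.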

Tests

  • Expect all added unit tests to pass
  • End-to-end functional - logs
  • Sanity check deepseek32_vs_reference_test - link, same as b/491486716
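A gradient-isolation check in the spirit of test_indexer_gradients might look like the following sketch (toy model and parameter names are invented for illustration): because the indexer loss stops gradients on the main model's output, only indexer parameters should receive nonzero gradients.

```python
# Toy gradient-isolation check; not the PR's actual test.
import jax
import jax.numpy as jnp

def indexer_loss(params, x):
  main = x @ params["attn_w"]       # stand-in for the main model's output
  idx = x @ params["indexer_w"]     # stand-in for the indexer's output
  # The indexer is trained to match the (frozen) main output.
  return jnp.mean((idx - jax.lax.stop_gradient(main)) ** 2)

params = {"attn_w": jnp.ones((4, 4)), "indexer_w": jnp.zeros((4, 4))}
x = jnp.ones((2, 4))
grads = jax.grad(indexer_loss)(params, x)
# Expectation: grads["attn_w"] is all zeros; grads["indexer_w"] is not.
```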

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


codecov bot commented Mar 14, 2026

@github-actions

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@RissyRan RissyRan changed the title Enable selective parameter training strategy Enable training strategy for Indexer Mar 14, 2026
@RissyRan RissyRan force-pushed the indexer_train_strategy branch from 85e4e0d to b0f353b Compare March 14, 2026 02:20

@github-actions github-actions bot left a comment


📋 Review Summary

This PR enables the selective parameter training strategy (dense warm-up and sparse training stages) for the DeepSeek V3.2 Indexer. It refactors parameter freezing flags and adds tests to verify proper isolation of indexer gradients from the rest of the model.

🔍 General Feedback

  • Memory Optimization in Selective Training: The current implementation of optimizer masking computes and stores Adam state parameters for the entire model before zeroing out the updates. I've suggested an explicit mapping with optax.multi_transform to avoid allocating massive memory blocks for frozen parameter states, which is critical for 671B model scaling.
  • Gradient Isolation in KL Divergence: I left an inline comment pointing out a gradient leak when calculating the KL divergence in calculate_indexer_loss. Ensure jax.lax.stop_gradient is applied to the target attention_probs distribution, so that the main model's queries and keys do not get updated by the indexer's loss.

@RissyRan RissyRan force-pushed the indexer_train_strategy branch 3 times, most recently from 0d0a638 to 4ec8a1e Compare March 14, 2026 03:16
@RissyRan RissyRan force-pushed the indexer_train_strategy branch from 4ec8a1e to 906f12f Compare March 14, 2026 03:18