
Commit 82365f7: Support muon

Author: Charles Li
Parent: 9f9629c

2 files changed, 16 additions & 12 deletions


src/maxtext/trainers/pre_train/nnx_train.py (3 additions & 3 deletions)

@@ -464,8 +464,8 @@ def eval_step(
           "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate,
       },
   }
-  # if config.use_dpo:
-  #   metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"]
+  if config.use_dpo:
+    metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"]
 
   return metrics
 
@@ -497,7 +497,7 @@ def _create_and_shard_optimizer(model: nnx.Module, config, mesh: Mesh):
     learning_rate_schedule: Learning-rate schedule function.
   """
   learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
-  tx = optimizers.get_optimizer(config, learning_rate_schedule)
+  tx = optimizers.get_optimizer(config, learning_rate_schedule, model)
   # NNX 0.11+: wrt is mandatory; optimizer does not store a model reference.
   optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
 
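The commit title is "Support muon", and the functional change above is threading model into optimizers.get_optimizer. That extra argument fits Muon, which orthogonalizes updates for 2-D weight matrices and therefore needs a per-parameter split between matrices and everything else. The sketch below shows one plausible shape for such a dispatch; it is not the real optimizers.get_optimizer (this diff does not show it), and it assumes a config.opt_type field and an optax version recent enough to ship optax.contrib.muon.

# Hypothetical sketch, not MaxText's optimizers.get_optimizer.
# Assumes config.opt_type exists and optax provides contrib.muon.
import jax
import optax


def get_optimizer(config, learning_rate_schedule, model):
  if config.opt_type != "muon":
    return optax.adamw(learning_rate=learning_rate_schedule)

  # `model` is presumably what the real code inspects (e.g. to exclude
  # embedding or output-head params from Muon by name); this sketch
  # decides purely by parameter shape instead.
  del model

  def label_params(params):
    # Muon is defined for matrix-shaped parameters; biases, norm scales,
    # and other non-2-D params keep an Adam-style update.
    return jax.tree.map(lambda p: "muon" if p.ndim == 2 else "adamw", params)

  return optax.multi_transform(
      {
          "muon": optax.contrib.muon(learning_rate=learning_rate_schedule),
          "adamw": optax.adamw(learning_rate=learning_rate_schedule),
      },
      param_labels=label_params,
  )

Whatever the real split looks like, the returned transformation plugs into nnx.Optimizer(model, tx, wrt=nnx.Param) unchanged, which is why the call site above only needed one new argument.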
src/maxtext/utils/maxtext_utils.py (13 additions & 9 deletions)

@@ -20,6 +20,7 @@
 import os
 
 from flax import linen as nn
+from flax import nnx
 from flax.linen import partitioning as nn_partitioning
 from flax.training import train_state
 
@@ -1030,7 +1031,7 @@ def init_initial_state(model, tx, config, is_training, key):
   return init_decode_state(model.apply, model_vars)
 
 
-def get_abstract_param(model, config):
+def get_abstract_param(model: nn.Module | nnx.Module, config):
   """Get abstract model structure (name, shape) without materializing the weights to save memory"""
   with model.mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
     key = jax.random.PRNGKey(0)
@@ -1039,14 +1040,17 @@ def get_abstract_param(model, config):
         config.model_name, batch_size=config.micro_batch_size_to_train_on
     )
     audio_shape = mm_processor.get_dummy_audio_shape_for_init(config)
-    abstract_vars = jax.eval_shape(
-        model.init,
-        {"params": key, "dropout": key, "aqt": key},
-        jnp.ones(input_shape, dtype=jnp.int32),
-        jnp.ones(input_shape, dtype=jnp.int32),
-        encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None,
-        encoder_audios=np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None,
-    )
+    if isinstance(model, nn.Module):
+      abstract_vars = jax.eval_shape(
+          model.init,
+          {"params": key, "dropout": key, "aqt": key},
+          jnp.ones(input_shape, dtype=jnp.int32),
+          jnp.ones(input_shape, dtype=jnp.int32),
+          encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None,
+          encoder_audios=np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None,
+      )
+    else:  # nnx.Module
+      _, abstract_vars = nnx.split(nnx.eval_shape(lambda: model))
   return abstract_vars
 
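The nnx branch added above is the interesting part: nnx.eval_shape traces module construction with jax.ShapeDtypeStruct leaves, so parameter names, shapes, and dtypes come out without allocating any device memory, and nnx.split then separates the static graph definition from that abstract state. Below is a self-contained sketch of the same pattern; the toy nnx.Linear stands in for a MaxText model and is an assumption, not code from this repository.

# Standalone illustration of the nnx branch in get_abstract_param. The toy
# Linear module is an assumption; MaxText models are far larger, which is
# exactly why abstract initialization matters.
import jax
from flax import nnx

# nnx.eval_shape runs the constructor abstractly: every parameter becomes a
# jax.ShapeDtypeStruct, so the 4096x4096 kernel is never materialized.
abstract_model = nnx.eval_shape(
    lambda: nnx.Linear(in_features=4096, out_features=4096, rngs=nnx.Rngs(0))
)

# Mirrors the diff's `_, abstract_vars = nnx.split(nnx.eval_shape(lambda: model))`:
# discard the static GraphDef, keep the abstract parameter state.
_, abstract_vars = nnx.split(abstract_model)

# Walk the abstract state: names, shapes, and dtypes, but no weights.
for path, leaf in jax.tree_util.tree_flatten_with_path(abstract_vars)[0]:
  print(path, leaf.shape, leaf.dtype)

In the diff, the lambda closes over an already-constructed model rather than building one; either way nnx.eval_shape returns an abstract copy of whatever the callable produces, so get_abstract_param can report shapes for models too large to initialize for real.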
