2020import os
2121
2222from flax import linen as nn
23+ from flax import nnx
2324from flax .linen import partitioning as nn_partitioning
2425from flax .training import train_state
2526
@@ -1030,7 +1031,7 @@ def init_initial_state(model, tx, config, is_training, key):
10301031 return init_decode_state (model .apply , model_vars )
10311032
10321033
1033- def get_abstract_param (model , config ):
1034+ def get_abstract_param (model : nn . Module | nnx . Module , config ):
10341035 """Get abstract model structure (name, shape) without materializing the weights to save memory"""
10351036 with model .mesh , nn_partitioning .axis_rules (config .logical_axis_rules ):
10361037 key = jax .random .PRNGKey (0 )
@@ -1039,14 +1040,17 @@ def get_abstract_param(model, config):
10391040 config .model_name , batch_size = config .micro_batch_size_to_train_on
10401041 )
10411042 audio_shape = mm_processor .get_dummy_audio_shape_for_init (config )
1042- abstract_vars = jax .eval_shape (
1043- model .init ,
1044- {"params" : key , "dropout" : key , "aqt" : key },
1045- jnp .ones (input_shape , dtype = jnp .int32 ),
1046- jnp .ones (input_shape , dtype = jnp .int32 ),
1047- encoder_images = np .ones (image_shape , dtype = jnp .int32 ) if config .use_multimodal else None ,
1048- encoder_audios = np .ones (audio_shape , dtype = jnp .float32 ) if config .use_audio else None ,
1049- )
1043+ if isinstance (model , nn .Module ):
1044+ abstract_vars = jax .eval_shape (
1045+ model .init ,
1046+ {"params" : key , "dropout" : key , "aqt" : key },
1047+ jnp .ones (input_shape , dtype = jnp .int32 ),
1048+ jnp .ones (input_shape , dtype = jnp .int32 ),
1049+ encoder_images = np .ones (image_shape , dtype = jnp .int32 ) if config .use_multimodal else None ,
1050+ encoder_audios = np .ones (audio_shape , dtype = jnp .float32 ) if config .use_audio else None ,
1051+ )
1052+ else : # nnx.Module
1053+ _ , abstract_vars = nnx .split (nnx .eval_shape (lambda : model ))
10501054 return abstract_vars
10511055
10521056
0 commit comments