@@ -345,7 +345,8 @@ class Checkpointing(BaseModel):
345345 description = "If True, enables checkpointing from remote TPU VMs instead of head node on pathways." ,
346346 )
347347 enable_autocheckpoint : bool = Field (
348- False , description = "If True, enables autocheckpoint or preemption induced checkpointing."
348+ False ,
349+ description = "If True, enables autocheckpoint or preemption induced checkpointing." ,
349350 )
350351
351352
@@ -495,9 +496,13 @@ class ModelArchitecture(BaseModel):
495496 )
496497 fused_mlp : bool = Field (False , description = "If supported, fuse the MLP layers." )
497498 qk_norm_with_scale : bool = Field (
498- True , description = "Whether to apply scale on query and key normalizations (default True)."
499+ True ,
500+ description = "Whether to apply scale on query and key normalizations (default True)." ,
501+ )
502+ v_norm_with_scale : bool = Field (
503+ True ,
504+ description = "Whether to apply scale on value normalization (default True)." ,
499505 )
500- v_norm_with_scale : bool = Field (True , description = "Whether to apply scale on value normalization (default True)." )
501506
502507
503508class MTP (BaseModel ):
@@ -542,9 +547,13 @@ class Attention(BaseModel):
542547 "global" , description = "The variant of attention to use."
543548 )
544549 share_kv_projections : bool = Field (
545- False , description = "If True, for global attention, Key and Value projections share the same weights."
550+ False ,
551+ description = "If True, for global attention, Key and Value projections share the same weights." ,
552+ )
553+ global_num_kv_heads : int = Field (
554+ 0 ,
555+ description = "If greater than 0, sets the number of KV heads for global attention." ,
546556 )
547- global_num_kv_heads : int = Field (0 , description = "If greater than 0, sets the number of KV heads for global attention." )
548557 attention_sink : bool = Field (False , description = "If True, enables attention sinks." )
549558 float32_qk_product : bool = Field (False , description = "In dot-product attention, cast query-key product to fp32." )
550559 float32_logits : bool = Field (
@@ -672,14 +681,18 @@ class MoEGeneral(BaseModel):
672681 num_experts : PositiveInt = Field (1 , description = "The total number of experts in each MoE layer." )
673682 num_experts_per_tok : PositiveInt = Field (1 , description = "The number of experts to route each token to." )
674683 capacity_factor : float = Field (- 1.0 , description = "Expert capacity factor. If < 0, no token dropping." )
675- ragged_buffer_factor : float = Field (- 1.0 , description = "Ragged buffer factor. If < 0, ragged buffer is worst case size." )
684+ ragged_buffer_factor : float = Field (
685+ - 1.0 ,
686+ description = "Ragged buffer factor. If < 0, ragged buffer is worst case size." ,
687+ )
676688 moe_expert_input_dim : int = Field (
677689 - 1 ,
678690 description = "Dimension of tokens entering the MoE layer. If < 0, defaults to emb_dim." ,
679691 )
680692 base_moe_mlp_dim : int = Field (- 1 , description = "Intermediate dimension at MoE layer." )
681693 padded_base_moe_mlp_dim : Optional [int ] = Field (
682- None , description = "Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution."
694+ None ,
695+ description = "Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution." ,
683696 )
684697 load_balance_loss_weight : NonNegativeFloat = Field (0.0 , description = "Weight for the load balancing auxiliary loss." )
685698 use_custom_sort_vjp : bool = Field (
@@ -860,7 +873,8 @@ class HardwareAndMesh(BaseModel):
860873 )
861874 custom_mesh : str = Field ("" , description = "Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']" )
862875 custom_mesh_and_rule : CustomRule = Field (
863- CustomRule .DEFAULT , description = "Customized mesh and logical rules for granularity."
876+ CustomRule .DEFAULT ,
877+ description = "Customized mesh and logical rules for granularity." ,
864878 )
865879 allow_split_physical_axes : bool = Field (False , description = "Allow splitting physical axes for device mesh creation." )
866880 enable_nnx : bool = Field (False , description = "Whether to use NNX for model definition." )
@@ -869,7 +883,8 @@ class HardwareAndMesh(BaseModel):
869883 pure_nnx_decoder : bool = Field (False , description = "Whether to enable pure NNX decoder." )
870884 pure_nnx : bool = Field (False , description = "Whether to enable pure NNX mode." )
871885 remove_size_one_mesh_axis_from_type : bool = Field (
872- True , description = "Whether to remove size one mesh axis from type through jax.config."
886+ True ,
887+ description = "Whether to remove size one mesh axis from type through jax.config." ,
873888 )
874889
875890
@@ -890,7 +905,10 @@ class LayoutAndSharding(BaseModel):
890905 description = "Allowed percentage of non-sharded parameters." ,
891906 )
892907 shard_optimizer_over_data : bool = Field (False , description = "Enable ZeRO-1 optimizer sharding over the data axis." )
893- internal_compile : bool = Field (False , description = "Use internal_compile to bypass open-source topology mappings." )
908+ internal_compile : bool = Field (
909+ False ,
910+ description = "Use internal_compile to bypass open-source topology mappings." ,
911+ )
894912 internal_compile_num_devices : int = Field (- 1 , description = "Number of devices when using internal_compile." )
895913 compile_xla_flags : str = Field ("" , description = "Compiler options for compilation only." )
896914
@@ -937,7 +955,8 @@ class PipelineParallelism(BaseModel):
937955 """Configuration for pipeline parallelism."""
938956
939957 pipeline_fsdp_ag_per_repeat : bool = Field (
940- False , description = "Enable weight prefetching for circular pipeline parallelism."
958+ False ,
959+ description = "Enable weight prefetching for circular pipeline parallelism." ,
941960 )
942961 num_layers_per_pipeline_stage : int = Field (1 , description = "Number of layers to place on each pipeline stage." )
943962 num_pipeline_repeats : int = Field (
@@ -1046,7 +1065,8 @@ class Tokenizer(BaseModel):
10461065 use_chat_template : bool = Field (False , description = "Whether to use the chat template for tokenization." )
10471066 chat_template_path : str = Field ("" , description = "Path to chat template json file." )
10481067 chat_template : str = Field (
1049- "" , description = "Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template."
1068+ "" ,
1069+ description = "Chat template to use with HF tokenizers. It should be a valid Jinja2-formatted template." ,
10501070 )
10511071 tokenize_train_data : bool = Field (True , description = "If False, assumes the training dataset is pre-tokenized." )
10521072 tokenize_eval_data : bool = Field (True , description = "If False, assumes the evaluation dataset is pre-tokenized." )
@@ -1138,7 +1158,8 @@ class GrainDataset(BaseModel):
11381158 description = "Path to a JSON file specifying the mixture weights for Grain training data." ,
11391159 )
11401160 grain_file_type : str = Field (
1141- "arrayrecord" , description = "File type for Grain data. Supported: arrayrecord, tfrecord, parquet."
1161+ "arrayrecord" ,
1162+ description = "File type for Grain data. Supported: arrayrecord, tfrecord, parquet." ,
11421163 )
11431164 grain_use_elastic_iterator : bool = Field (
11441165 False ,
@@ -1179,7 +1200,10 @@ class OlmoGrainDataset(BaseModel):
11791200 ``data_shuffle_seed``); only OLMo-specific fields are listed here.
11801201 """
11811202
1182- olmo_index_path : PathStr = Field ("" , description = "Path or gs:// URI to the JSON index from build_olmo_npy_index.py." )
1203+ olmo_index_path : PathStr = Field (
1204+ "" ,
1205+ description = "Path or gs:// URI to the JSON index from build_olmo_npy_index.py." ,
1206+ )
11831207 olmo_path_remap_from : PathStr = Field (
11841208 "" ,
11851209 description = "If set, rewrite index file paths starting with this prefix to olmo_path_remap_to." ,
@@ -1229,32 +1253,39 @@ class Distillation(BaseModel):
12291253
12301254 # --- Offline Distillation Field ---
12311255 offline_data_dir : Optional [str ] = Field (
1232- None , description = "GCS or local path to the pre-generated ArrayRecord teacher data."
1256+ None ,
1257+ description = "GCS or local path to the pre-generated ArrayRecord teacher data." ,
12331258 )
12341259
12351260 # --- Loss Params ---
12361261 distill_alpha : float = Field (0.5 , description = "Weight for the distillation loss component." )
12371262 distill_temperature : float = Field (1.0 , description = "Temperature for distillation softening." )
12381263 distill_beta : float = Field (0.0 , description = "Weight for the feature loss component. Use 0.0 to disable" )
12391264 distill_feature_loss_type : Literal ["cosine" , "l2" ] = Field (
1240- "cosine" , description = "The type of loss to use for feature distillation ('cosine' or 'l2')."
1265+ "cosine" ,
1266+ description = "The type of loss to use for feature distillation ('cosine' or 'l2')." ,
12411267 )
12421268 distill_layer_indices : None | list = Field (None , description = "Feature indices for feature loss." )
12431269 distill_alpha_end : Optional [float ] = Field (None , description = "Target alpha at end of training. None keeps alpha fixed." )
12441270 distill_alpha_schedule : Literal ["constant" , "linear" , "cosine" ] = Field (
1245- "constant" , description = "Schedule type for alpha annealing ('constant', 'linear', or 'cosine')."
1271+ "constant" ,
1272+ description = "Schedule type for alpha annealing ('constant', 'linear', or 'cosine')." ,
12461273 )
12471274 distill_temperature_end : Optional [float ] = Field (
1248- None , description = "Target temperature at end of training. None keeps temperature fixed."
1275+ None ,
1276+ description = "Target temperature at end of training. None keeps temperature fixed." ,
12491277 )
12501278 distill_temperature_schedule : Literal ["constant" , "linear" , "cosine" ] = Field (
1251- "constant" , description = "Schedule type for temperature annealing ('constant', 'linear', or 'cosine')."
1279+ "constant" ,
1280+ description = "Schedule type for temperature annealing ('constant', 'linear', or 'cosine')." ,
12521281 )
12531282 distill_beta_end : Optional [float ] = Field (
1254- None , description = "Target beta_feature at end of training. None keeps beta fixed."
1283+ None ,
1284+ description = "Target beta_feature at end of training. None keeps beta fixed." ,
12551285 )
12561286 distill_beta_schedule : Literal ["constant" , "linear" , "cosine" ] = Field (
1257- "constant" , description = "Schedule type for beta annealing ('constant', 'linear', or 'cosine')."
1287+ "constant" ,
1288+ description = "Schedule type for beta annealing ('constant', 'linear', or 'cosine')." ,
12581289 )
12591290
12601291 # --- Learn to init related parameters --
@@ -1277,11 +1308,13 @@ class Distillation(BaseModel):
12771308 )
12781309
12791310 attn_module_name : Optional [str ] = Field (
1280- None , description = "Attention nnx module attribute name to augment with LTI logic"
1311+ None ,
1312+ description = "Attention nnx module attribute name to augment with LTI logic" ,
12811313 )
12821314
12831315 lti_layer_indices : Optional [list [int ]] = Field (
1284- None , description = "List of layer indices to apply LTI modifications. If None, applied to all layers."
1316+ None ,
1317+ description = "List of layer indices to apply LTI modifications. If None, applied to all layers." ,
12851318 )
12861319 # ---------------------------------------
12871320
@@ -1328,6 +1361,10 @@ class ManifoldConstrainedHyperConnections(BaseModel):
13281361
13291362 mhc_expansion_rate : PositiveInt = Field (1 , description = "The number of parallel streams in Hyper Connection." )
13301363 sinkhorn_iterations : PositiveInt = Field (20 , description = "The number of iterations for the Sinkhorn-Knopp algorithm." )
1364+ enable_mhc_k4_shortcut : bool = Field (
1365+ True ,
1366+ description = "Whether to enable the permutation-based convex combination shortcut when mhc_expansion_rate is 4." ,
1367+ )
13311368
13321369
13331370class DilocoParams (BaseModel ):
@@ -1344,10 +1381,12 @@ class Optimizer(BaseModel):
13441381
13451382 opt_type : OptimizerType = Field (OptimizerType .ADAMW , description = "The type of optimizer to use." )
13461383 skip_step_on_spikes : bool = Field (
1347- False , description = "If True, skip the training step when loss or gradient spike is detected."
1384+ False ,
1385+ description = "If True, skip the training step when loss or gradient spike is detected." ,
13481386 )
13491387 skip_step_interval : PositiveInt = Field (
1350- 128 , description = "The rolling interval to calculate the mean and standard deviation."
1388+ 128 ,
1389+ description = "The rolling interval to calculate the mean and standard deviation." ,
13511390 )
13521391 skip_step_scaling_factor : float = Field (6.0 , description = "The scaling factor to determine if a spike occurred." )
13531392 gradient_accumulation_steps : PositiveInt = Field (
@@ -1616,7 +1655,8 @@ class Profiling(BaseModel):
16161655 tpu_num_chips_to_profile_per_task : int = Field (1 , description = "Specifies the number of TPU chips to profile per task." )
16171656 tpu_num_sparse_cores_to_trace : int = Field (2 , description = "Specifies the number of TPU chips to profile per task." )
16181657 tpu_num_sparse_core_tiles_to_trace : int = Field (
1619- 1 , description = "Specifies the number of tiles within each sparse core to trace on the TPU."
1658+ 1 ,
1659+ description = "Specifies the number of tiles within each sparse core to trace on the TPU." ,
16201660 )
16211661 xprof_tpu_power_trace_level : XProfTPUPowerTraceMode = Field (
16221662 XProfTPUPowerTraceMode .POWER_TRACE_NONE ,
@@ -1796,7 +1836,10 @@ class VisionTower(BaseModel):
17961836 temporal_patch_size_for_vit : int = Field (2 , description = "Temporal patch size for video inputs." )
17971837 num_position_embeddings_for_vit : int = Field (1024 , description = "Number of position embeddings for ViT." )
17981838 deepstack_visual_indexes_for_vit : list [int ] = Field ([], description = "Layer indices to extract deep visual features." )
1799- vision_output_length : int = Field (- 1 , description = "The output length (number of soft tokens) from the vision encoder." )
1839+ vision_output_length : int = Field (
1840+ - 1 ,
1841+ description = "The output length (number of soft tokens) from the vision encoder." ,
1842+ )
18001843
18011844
18021845class VisionProjector (BaseModel ):
@@ -1900,18 +1943,28 @@ class RL(BaseModel):
19001943 grpo_epsilon : float = Field (0.2 , description = "Epsilon value for clipping in the GRPO loss." )
19011944 loss_algo : Literal ["grpo" , "gspo-token" ] = Field ("grpo" , description = "Loss algorithm, i.e., 'grpo' or 'gspo-token'." )
19021945 use_agentic_rollout : bool = Field (
1903- False , description = "If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts."
1946+ False ,
1947+ description = "If True, uses the asynchronous AgenticGRPOLearner for online vLLM rollouts." ,
1948+ )
1949+ max_concurrency : int = Field (
1950+ 256 ,
1951+ description = "Maximum number of concurrent rollout requests (agentic rollout only)." ,
19041952 )
1905- max_concurrency : int = Field (256 , description = "Maximum number of concurrent rollout requests (agentic rollout only)." )
19061953 off_policy_steps : int = Field (
1907- 0 , description = "Number of off-policy steps tolerated before requiring a policy update (agentic only)."
1954+ 0 ,
1955+ description = "Number of off-policy steps tolerated before requiring a policy update (agentic only)." ,
1956+ )
1957+ system_prompt : str = Field (
1958+ "" ,
1959+ description = "System prompt injected into the agent at rollout time (agentic only)." ,
19081960 )
1909- system_prompt : str = Field ("" , description = "System prompt injected into the agent at rollout time (agentic only)." )
19101961 degenerate_group_masking : bool = Field (
1911- True , description = "Mask degenerate groups (all-zero advantages) from contributing to loss (agentic only)."
1962+ True ,
1963+ description = "Mask degenerate groups (all-zero advantages) from contributing to loss (agentic only)." ,
19121964 )
19131965 epsilon_high : Optional [float ] = Field (
1914- None , description = "Upper-bound clipping epsilon for GRPO loss. Defaults to epsilon when None (agentic only)."
1966+ None ,
1967+ description = "Upper-bound clipping epsilon for GRPO loss. Defaults to epsilon when None (agentic only)." ,
19151968 )
19161969 reshard_chunk_size : Optional [int ] = Field (
19171970 None ,
@@ -2435,7 +2488,11 @@ def validate_and_set_hlo_dump_defaults():
24352488 )
24362489 for param_name , schedule , end_value in [
24372490 ("distill_alpha" , self .distill_alpha_schedule , self .distill_alpha_end ),
2438- ("distill_temperature" , self .distill_temperature_schedule , self .distill_temperature_end ),
2491+ (
2492+ "distill_temperature" ,
2493+ self .distill_temperature_schedule ,
2494+ self .distill_temperature_end ,
2495+ ),
24392496 ("distill_beta" , self .distill_beta_schedule , self .distill_beta_end ),
24402497 ]:
24412498 if schedule != "constant" and end_value is None :
@@ -2948,7 +3005,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
29483005 self .use_grpo = False
29493006
29503007 if self .use_batch_split_schedule :
2951- if self .quantization and not self .quantization = = "fp8_full" :
3008+ if self .quantization and self .quantization ! = "fp8_full" :
29523009 raise ValueError ("Batch split quantization only supports `quantization=fp8_full`" )
29533010
29543011 if self .opt_type == "muon" and self .decoder_block not in [
0 commit comments