@@ -506,7 +506,10 @@ class ModelArchitecture(BaseModel):
506506 True ,
507507 description = "Whether to apply scale on query and key normalizations (default True)." ,
508508 )
509- v_norm_with_scale : bool = Field (True , description = "Whether to apply scale on value normalization (default True)." )
509+ v_norm_with_scale : bool = Field (
510+ True ,
511+ description = "Whether to apply scale on value normalization (default True)." ,
512+ )
510513
511514
512515class MTP (BaseModel ):
@@ -685,14 +688,18 @@ class MoEGeneral(BaseModel):
685688 num_experts : PositiveInt = Field (1 , description = "The total number of experts in each MoE layer." )
686689 num_experts_per_tok : PositiveInt = Field (1 , description = "The number of experts to route each token to." )
687690 capacity_factor : float = Field (- 1.0 , description = "Expert capacity factor. If < 0, no token dropping." )
688- ragged_buffer_factor : float = Field (- 1.0 , description = "Ragged buffer factor. If < 0, ragged buffer is worst case size." )
691+ ragged_buffer_factor : float = Field (
692+ - 1.0 ,
693+ description = "Ragged buffer factor. If < 0, ragged buffer is worst case size." ,
694+ )
689695 moe_expert_input_dim : int = Field (
690696 - 1 ,
691697 description = "Dimension of tokens entering the MoE layer. If < 0, defaults to emb_dim." ,
692698 )
693699 base_moe_mlp_dim : int = Field (- 1 , description = "Intermediate dimension at MoE layer." )
694700 padded_base_moe_mlp_dim : Optional [int ] = Field (
695- None , description = "Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution."
701+ None ,
702+ description = "Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution." ,
696703 )
697704 load_balance_loss_weight : NonNegativeFloat = Field (0.0 , description = "Weight for the load balancing auxiliary loss." )
698705 use_custom_sort_vjp : bool = Field (
@@ -873,7 +880,8 @@ class HardwareAndMesh(BaseModel):
873880 )
874881 custom_mesh : str = Field ("" , description = "Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']" )
875882 custom_mesh_and_rule : CustomRule = Field (
876- CustomRule .DEFAULT , description = "Customized mesh and logical rules for granularity."
883+ CustomRule .DEFAULT ,
884+ description = "Customized mesh and logical rules for granularity." ,
877885 )
878886 allow_split_physical_axes : bool = Field (False , description = "Allow splitting physical axes for device mesh creation." )
879887 enable_nnx : bool = Field (False , description = "Whether to use NNX for model definition." )
@@ -882,7 +890,8 @@ class HardwareAndMesh(BaseModel):
882890 pure_nnx_decoder : bool = Field (False , description = "Whether to enable pure NNX decoder." )
883891 pure_nnx : bool = Field (False , description = "Whether to enable pure NNX mode." )
884892 remove_size_one_mesh_axis_from_type : bool = Field (
885- True , description = "Whether to remove size one mesh axis from type through jax.config."
893+ True ,
894+ description = "Whether to remove size one mesh axis from type through jax.config." ,
886895 )
887896
888897
@@ -903,7 +912,10 @@ class LayoutAndSharding(BaseModel):
903912 description = "Allowed percentage of non-sharded parameters." ,
904913 )
905914 shard_optimizer_over_data : bool = Field (False , description = "Enable ZeRO-1 optimizer sharding over the data axis." )
906- internal_compile : bool = Field (False , description = "Use internal_compile to bypass open-source topology mappings." )
915+ internal_compile : bool = Field (
916+ False ,
917+ description = "Use internal_compile to bypass open-source topology mappings." ,
918+ )
907919 internal_compile_num_devices : int = Field (- 1 , description = "Number of devices when using internal_compile." )
908920 compile_xla_flags : str = Field ("" , description = "Compiler options for compilation only." )
909921
@@ -950,7 +962,8 @@ class PipelineParallelism(BaseModel):
950962 """Configuration for pipeline parallelism."""
951963
952964 pipeline_fsdp_ag_per_repeat : bool = Field (
953- False , description = "Enable weight prefetching for circular pipeline parallelism."
965+ False ,
966+ description = "Enable weight prefetching for circular pipeline parallelism." ,
954967 )
955968 num_layers_per_pipeline_stage : int = Field (1 , description = "Number of layers to place on each pipeline stage." )
956969 num_pipeline_repeats : int = Field (
@@ -1194,7 +1207,10 @@ class OlmoGrainDataset(BaseModel):
11941207 ``data_shuffle_seed``); only OLMo-specific fields are listed here.
11951208 """
11961209
1197- olmo_index_path : PathStr = Field ("" , description = "Path or gs:// URI to the JSON index from build_olmo_npy_index.py." )
1210+ olmo_index_path : PathStr = Field (
1211+ "" ,
1212+ description = "Path or gs:// URI to the JSON index from build_olmo_npy_index.py." ,
1213+ )
11981214 olmo_path_remap_from : PathStr = Field (
11991215 "" ,
12001216 description = "If set, rewrite index file paths starting with this prefix to olmo_path_remap_to." ,
@@ -1279,19 +1295,24 @@ class Distillation(BaseModel):
12791295 distill_layer_indices : None | list = Field (None , description = "Feature indices for feature loss." )
12801296 distill_alpha_end : Optional [float ] = Field (None , description = "Target alpha at end of training. None keeps alpha fixed." )
12811297 distill_alpha_schedule : Literal ["constant" , "linear" , "cosine" ] = Field (
1282- "constant" , description = "Schedule type for alpha annealing ('constant', 'linear', or 'cosine')."
1298+ "constant" ,
1299+ description = "Schedule type for alpha annealing ('constant', 'linear', or 'cosine')." ,
12831300 )
12841301 distill_temperature_end : Optional [float ] = Field (
1285- None , description = "Target temperature at end of training. None keeps temperature fixed."
1302+ None ,
1303+ description = "Target temperature at end of training. None keeps temperature fixed." ,
12861304 )
12871305 distill_temperature_schedule : Literal ["constant" , "linear" , "cosine" ] = Field (
1288- "constant" , description = "Schedule type for temperature annealing ('constant', 'linear', or 'cosine')."
1306+ "constant" ,
1307+ description = "Schedule type for temperature annealing ('constant', 'linear', or 'cosine')." ,
12891308 )
12901309 distill_beta_end : Optional [float ] = Field (
1291- None , description = "Target beta_feature at end of training. None keeps beta fixed."
1310+ None ,
1311+ description = "Target beta_feature at end of training. None keeps beta fixed." ,
12921312 )
12931313 distill_beta_schedule : Literal ["constant" , "linear" , "cosine" ] = Field (
1294- "constant" , description = "Schedule type for beta annealing ('constant', 'linear', or 'cosine')."
1314+ "constant" ,
1315+ description = "Schedule type for beta annealing ('constant', 'linear', or 'cosine')." ,
12951316 )
12961317
12971318 # --- Learn to init related parameters --
@@ -1314,11 +1335,13 @@ class Distillation(BaseModel):
13141335 )
13151336
13161337 attn_module_name : Optional [str ] = Field (
1317- None , description = "Attention nnx module attribute name to augment with LTI logic"
1338+ None ,
1339+ description = "Attention nnx module attribute name to augment with LTI logic" ,
13181340 )
13191341
13201342 lti_layer_indices : Optional [list [int ]] = Field (
1321- None , description = "List of layer indices to apply LTI modifications. If None, applied to all layers."
1343+ None ,
1344+ description = "List of layer indices to apply LTI modifications. If None, applied to all layers." ,
13221345 )
13231346 # ---------------------------------------
13241347
@@ -1365,6 +1388,10 @@ class ManifoldConstrainedHyperConnections(BaseModel):
13651388
13661389 mhc_expansion_rate : PositiveInt = Field (1 , description = "The number of parallel streams in Hyper Connection." )
13671390 sinkhorn_iterations : PositiveInt = Field (20 , description = "The number of iterations for the Sinkhorn-Knopp algorithm." )
1391+ enable_mhc_k4_shortcut : bool = Field (
1392+ True ,
1393+ description = "Whether to enable the permutation-based convex combination shortcut when mhc_expansion_rate is 4." ,
1394+ )
13681395
13691396
13701397class DilocoParams (BaseModel ):
@@ -1655,7 +1682,8 @@ class Profiling(BaseModel):
16551682 tpu_num_chips_to_profile_per_task : int = Field (1 , description = "Specifies the number of TPU chips to profile per task." )
16561683 tpu_num_sparse_cores_to_trace : int = Field (2 , description = "Specifies the number of sparse cores to trace on the TPU." )
16571684 tpu_num_sparse_core_tiles_to_trace : int = Field (
1658- 1 , description = "Specifies the number of tiles within each sparse core to trace on the TPU."
1685+ 1 ,
1686+ description = "Specifies the number of tiles within each sparse core to trace on the TPU." ,
16591687 )
16601688 xprof_tpu_power_trace_level : XProfTPUPowerTraceMode = Field (
16611689 XProfTPUPowerTraceMode .POWER_TRACE_NONE ,
@@ -2491,7 +2519,11 @@ def validate_and_set_hlo_dump_defaults():
24912519 )
24922520 for param_name , schedule , end_value in [
24932521 ("distill_alpha" , self .distill_alpha_schedule , self .distill_alpha_end ),
2494- ("distill_temperature" , self .distill_temperature_schedule , self .distill_temperature_end ),
2522+ (
2523+ "distill_temperature" ,
2524+ self .distill_temperature_schedule ,
2525+ self .distill_temperature_end ,
2526+ ),
24952527 ("distill_beta" , self .distill_beta_schedule , self .distill_beta_end ),
24962528 ]:
24972529 if schedule != "constant" and end_value is None :
@@ -3004,7 +3036,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
30043036 self .use_grpo = False
30053037
30063038 if self .use_batch_split_schedule :
3007- if self .quantization and not self .quantization = = "fp8_full" :
3039+ if self .quantization and self .quantization ! = "fp8_full" :
30083040 raise ValueError ("Batch split quantization only supports `quantization=fp8_full`" )
30093041
30103042 if self .opt_type == "muon" and self .decoder_block not in [
0 commit comments