File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -146,6 +146,7 @@ class DatasetType(str, Enum):
146146 HF = "hf"
147147 GRAIN = "grain"
148148 TFDS = "tfds"
149+ C4MLPERF = "c4_mlperf"
149150
150151
151152class SamplingStrategy (str , Enum ):
@@ -279,9 +280,7 @@ class Checkpointing(BaseModel):
279280 save_checkpoint_on_completion : bool = Field (
280281 True , description = "If True, saves a final checkpoint upon training completion."
281282 )
282- enable_continuous_checkpointing : bool = Field (
283- False , description = "If True, enables continuous checkpointing."
284- )
283+ enable_continuous_checkpointing : bool = Field (False , description = "If True, enables continuous checkpointing." )
285284
286285
287286class OrbaxStorage (BaseModel ):
@@ -463,9 +462,7 @@ class Attention(BaseModel):
463462 ragged_block_size : int = Field (256 , description = "Block size for ragged attention." )
464463 enable_padding_causal_mask : bool = Field (True , description = "Temporary flag for TE padding." )
465464 use_tokamax_splash : bool = Field (False , description = "Whether to use tokamax splash attention." )
466- use_jax_splash : bool = Field (
467- False , description = "Whether to use jax splash attention."
468- )
465+ use_jax_splash : bool = Field (False , description = "Whether to use jax splash attention." )
469466
470467
471468class MoBa (BaseModel ):
You can’t perform that action at this time.
0 commit comments