Skip to content

Commit d011897

Browse files
Merge pull request #2853 from AI-Hypercomputer:fix_mlperf_data
PiperOrigin-RevId: 846816667
2 parents e6ba816 + 3b763bb commit d011897

1 file changed

Lines changed: 3 additions & 6 deletions

File tree

src/MaxText/configs/types.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ class DatasetType(str, Enum):
146146
HF = "hf"
147147
GRAIN = "grain"
148148
TFDS = "tfds"
149+
C4MLPERF = "c4_mlperf"
149150

150151

151152
class SamplingStrategy(str, Enum):
@@ -279,9 +280,7 @@ class Checkpointing(BaseModel):
279280
save_checkpoint_on_completion: bool = Field(
280281
True, description="If True, saves a final checkpoint upon training completion."
281282
)
282-
enable_continuous_checkpointing: bool = Field(
283-
False, description="If True, enables continuous checkpointing."
284-
)
283+
enable_continuous_checkpointing: bool = Field(False, description="If True, enables continuous checkpointing.")
285284

286285

287286
class OrbaxStorage(BaseModel):
@@ -463,9 +462,7 @@ class Attention(BaseModel):
463462
ragged_block_size: int = Field(256, description="Block size for ragged attention.")
464463
enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.")
465464
use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")
466-
use_jax_splash: bool = Field(
467-
False, description="Whether to use jax splash attention."
468-
)
465+
use_jax_splash: bool = Field(False, description="Whether to use jax splash attention.")
469466

470467

471468
class MoBa(BaseModel):

0 commit comments

Comments
 (0)