Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ class DecoderBlockType(enum.Enum):
LLAMA4 = "llama4"
OLMO3 = "olmo3"

LLAMA2LTI = "llama2-learn-to-init"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the naming convention, could you use "_" instead of "-"? e.g. LLAMA2_LTI = "llama2_lti"



class AttentionType(enum.Enum):
GLOBAL = "global" # default, with causality
Expand Down
20 changes: 20 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,6 +1194,26 @@ class Distillation(BaseModel):
"constant", description="Schedule type for beta annealing ('constant', 'linear', or 'cosine')."
)

# --- Learn-to-init related parameters ---
learn_to_init_mode: bool = Field(False, description="Runs in the learn-to-init mode only")

lti_use_general_linear_map: bool = Field(
False,
description="Enable the general map (i.e., a single learnable projection instead of the bi-linear mapping). "
"Needs much more HBM.",
)

distill_weights_copy_map: dict[str, Any] = Field(
default_factory=dict,
description="Dictionary specifying which original teacher weights are copied into the student model.",
)

distill_student_weights_share_map: dict[str, Any] = Field(
default_factory=dict,
description="Experimental weight-sharing map inside the student model for the learn-to-init phase.",
)
# ---------------------------------------

# --- Distillation freezing filter --
student_params_to_update: None | list = Field(
None,
Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,8 @@ def get_decoder_layers(self):
return [DecoderLayer]
case DecoderBlockType.LLAMA2:
return [llama2.LlamaDecoderLayerToLinen]
case DecoderBlockType.LLAMA2LTI:
return [llama2.LlamaLTIDecoderLayerToLinen]
case DecoderBlockType.MISTRAL:
# TODO(ranran): update to Mistral with sliding window attention
return [mistral.MistralDecoderLayerToLinen]
Expand Down Expand Up @@ -543,6 +545,7 @@ def get_norm_layer(self, num_features: int):
DecoderBlockType.SIMPLE_MLP,
DecoderBlockType.LLAMA4,
DecoderBlockType.OLMO3,
DecoderBlockType.LLAMA2LTI,
):
return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode)
elif self.config.decoder_block == DecoderBlockType.GPT3:
Expand Down
Loading
Loading