
[Distillation] base learn-to-init llama attention for distillation#3688

Open
vlad-karp wants to merge 17 commits into main from vladk/lti

Conversation

@vlad-karp
Collaborator

@vlad-karp vlad-karp commented Apr 17, 2026

Description

This PR introduces the base implementation of Learn-to-Init (LTI) attention for distillation (LLaMA only for now, but it can easily be generalized).

Relevant details and context:

  • Problem being solved: Setting up the foundational LTI components required for effective attention distillation of LLaMA models.
  • Implementation details:
    • Adds the core LTI logic in a new learn_to_init_layer.py module.
    • Updates the distillation pipeline (distillation_utils.py and train_distill.py) and decoders to support the new attention layer.
    • Two LTI implementations for GQA: a bi-linear map and a global linear map option.
    • Configures the system so the LTI student obtains the teacher shapes it needs at init time directly from the configuration.
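
For readers unfamiliar with the approach, here is a minimal, self-contained sketch of the "global linear map" idea for GQA. All names and shapes are illustrative assumptions, not the PR's actual API: a single learnable matrix over the head axis maps the teacher's KV-projection weights into the student's grouped-query shape, initialized to an even averaging of teacher heads.

```python
# Illustrative sketch only; the real learn_to_init_layer.py may differ.
import numpy as np

def global_linear_init(teacher_kv, num_student_kv_heads):
    """Map teacher KV weights [embed, n_teacher_heads, head_dim] into a
    student GQA shape [embed, n_student_kv_heads, head_dim] with a single
    learnable map over the head axis (the "global linear map" variant)."""
    embed, n_teacher, head_dim = teacher_kv.shape
    # Learnable map, initialized to an even grouping of teacher heads.
    m = np.zeros((num_student_kv_heads, n_teacher))
    group = n_teacher // num_student_kv_heads
    for s in range(num_student_kv_heads):
        m[s, s * group:(s + 1) * group] = 1.0 / group
    student_kv = np.einsum("eth,st->esh", teacher_kv, m)
    return m, student_kv
```

The bi-linear variant would instead learn separate maps on two axes (e.g. head and head_dim); the global variant above shares one map across all layers' KV heads.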

Tests

  • Added a new unit test for LearnToInitDense as well as for the teacher-to-student weight injection logic.
    src/maxtext/tests/post_training/unit/learn-to-init_test.py

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Apr 17, 2026

@vlad-karp vlad-karp changed the title base learn-to-init llama attention for distillation [Distillation] base learn-to-init llama attention for distillation Apr 20, 2026
Comment thread src/maxtext/layers/learn_to_init_layer.py
self,
raw_iterator: Any | None,
root_directory: str | None = None,
student_config: Any | None = None,
Collaborator


nit: can it be None?
Perhaps change the default and data type if it can't.

It effectively collapses the learn-to-init parameterization back into a standard
decoder architecture, modifying the `student_model` in-place.

NOTE: works for ToNXX decoder model and layer-scan mode only
Collaborator


could you throw an exception if it's not layer-scan mode?
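
A sketch of the suggested guard (the config attribute name `scan_layers` is an assumption about the MaxText config; the actual flag may differ):

```python
def collapse_learn_to_init(student_model, config):
    """Collapse the LTI parameterization back into a standard decoder.

    Only layer-scan mode is supported, so fail fast otherwise,
    as suggested in the review.
    """
    if not getattr(config, "scan_layers", False):
        raise ValueError(
            "learn-to-init collapse only supports layer-scan mode; "
            "set scan_layers=True in the student config."
        )
    # ... in-place collapse of student_model would follow here ...
```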

LLAMA4 = "llama4"
OLMO3 = "olmo3"

LLAMA2LTI = "llama2-learn-to-init"
Collaborator


For naming convention, could you use "_" instead of "-"? e.g. LLAMA2_LTI = "llama2_lti"

Collaborator


Could you name the test with "_" instead of "-"? e.g. learn_to_init_test.py

self._buffered_train_metrics.additional_metrics[name] = ([], distillation_utils.weighted_mean)

self._buffered_train_metrics.additional_metrics[name][0].append(value)
max_logging.log(f"Distillation metrics: {aux}")
Collaborator


Is it logged at every step or just once? It's inside _post_process_train_step.
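
As an aside, the buffering pattern quoted above pairs an accumulator list with a reduction function. A sketch of how such a pair could be reduced at log time follows; the `(value, weight)` shape is a guess, since the actual `distillation_utils.weighted_mean` implementation isn't shown in this thread:

```python
def weighted_mean(pairs):
    """Reduce buffered (value, weight) pairs to a weighted average.

    Illustrative only: the real distillation_utils.weighted_mean
    may accept a different shape (e.g. plain scalars).
    """
    total = sum(w for _, w in pairs)
    return sum(v * w for v, w in pairs) / total

# Mimic the buffering pattern from the snippet above.
additional_metrics = {}
name = "distill/kl_loss"  # hypothetical metric name
if name not in additional_metrics:
    additional_metrics[name] = ([], weighted_mean)
additional_metrics[name][0].append((0.5, 2.0))
additional_metrics[name][0].append((1.0, 2.0))

buffer, reducer = additional_metrics[name]
result = reducer(buffer)  # (0.5*2 + 1.0*2) / 4 = 0.75
```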



4 participants