Commit 151b03d

Add back DistillationProvider patch
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent d175f44 commit 151b03d

File tree

1 file changed: +42, -5 lines

examples/megatron_bridge/distill.py

Lines changed: 42 additions & 5 deletions
@@ -15,18 +15,22 @@
 """Distillation script for Megatron-Bridge.
 
 Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model
-to `<output_dir>/checkpoints` in megatron distributed checkpoint format.
+to `<output_dir>/checkpoints` in megatron distributed checkpoint or HuggingFace format.
 
 See `README.md` in this directory for example usage and data preparation instructions.
 """
 
 import argparse
 import contextlib
 import os
+from dataclasses import fields
 
 import torch
 from megatron.bridge import AutoBridge
-from megatron.bridge.models.distillation_provider import convert_to_distillation_provider
+from megatron.bridge.models.distillation_provider import (
+    DistillationProvider,
+    convert_to_distillation_provider,
+)
 from megatron.bridge.recipes.utils.optimizer_utils import (
     distributed_fused_adam_with_cosine_annealing,
 )
@@ -46,15 +50,48 @@
 from megatron.core.distributed import DistributedDataParallelConfig
 from transformers import AutoConfig
 
-with contextlib.suppress(ImportError):
-    import modelopt.torch.puzzletron.plugins.mbridge  # noqa: F401
-
 import modelopt.torch.utils.distributed as dist
 from modelopt.torch.utils import print_rank_0
 
+with contextlib.suppress(ImportError):
+    import modelopt.torch.puzzletron.plugins.mbridge  # noqa: F401
+
 SEED = 1234
 
 
+def _patched_to_cfg_dict(self):
+    """Patched DistillationProvider.to_cfg_dict method for heterogeneous teacher and student models.
+
+    TODO: Upstream this patch to Megatron-Bridge.
+    """
+    from megatron.bridge.training.utils.config_utils import _ConfigContainerBase
+
+    result = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"}
+    # Use fields from the actual student provider class, not DistillationProvider.
+    # DistillationProvider's __dataclass_fields__ only includes TransformerConfig fields
+    # (set at class definition time), missing GPTModelProvider-level fields like
+    # vocab_size, share_embeddings_and_output_weights, etc.
+    excluded_fields = {"teacher", "kd_config"}
+    for field in fields(self._super_class):
+        if field.name.startswith("_") or field.name in excluded_fields:
+            continue
+        if hasattr(self, field.name):
+            result[field.name] = _ConfigContainerBase._convert_value_to_dict(
+                getattr(self, field.name)
+            )
+    for field in fields(self):
+        if field.name.startswith("_") or field.name in excluded_fields:
+            continue
+        if field.name not in result:
+            result[field.name] = _ConfigContainerBase._convert_value_to_dict(
+                getattr(self, field.name)
+            )
+    return result
+
+
+DistillationProvider.to_cfg_dict = _patched_to_cfg_dict
+
+
 def get_args():
     """Parse command-line arguments."""
     parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.")
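Note on the patch above: Python dataclasses snapshot their field list at class definition time, so a class declared against a base config never sees fields that a richer subclass contributes. The sketch below reproduces that pitfall with stand-in classes (BaseConfig, GPTProvider, and Distill are hypothetical names, not the real Megatron-Bridge types) and shows why the patch walks fields(self._super_class) first.

from dataclasses import dataclass, fields


@dataclass
class BaseConfig:  # stand-in for TransformerConfig
    hidden_size: int = 4096


@dataclass
class GPTProvider(BaseConfig):  # stand-in for GPTModelProvider
    vocab_size: int = 32000


@dataclass
class Distill(BaseConfig):  # declared against the base only, like DistillationProvider
    kd_config: dict | None = None


print([f.name for f in fields(Distill)])      # ['hidden_size', 'kd_config'] -- no vocab_size
print([f.name for f in fields(GPTProvider)])  # ['hidden_size', 'vocab_size']

Iterating the actual student provider class first, then filling in remaining fields from the instance, recovers provider-level settings such as vocab_size that the class-level field list would otherwise drop.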

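For completeness, a compact, self-contained sketch of the module-level monkey-patch idiom the commit relies on (all names here are stand-ins, not the real API): rebinding the method on the class makes every instance, including ones constructed later, pick up the fixed serialization.

from dataclasses import dataclass, fields


@dataclass
class Provider:  # stand-in for DistillationProvider
    hidden_size: int = 4096

    def to_cfg_dict(self):
        return {}  # pretend the stock implementation drops fields


def _patched(self):
    # Serialize every public dataclass field, mirroring the commit's approach.
    return {f.name: getattr(self, f.name) for f in fields(self) if not f.name.startswith("_")}


Provider.to_cfg_dict = _patched  # applied once at import time, as distill.py does

assert Provider().to_cfg_dict() == {"hidden_size": 4096}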