|
15 | 15 | """Distillation script for Megatron-Bridge. |
16 | 16 |
|
17 | 17 | Loads student and teacher models directly from HuggingFace checkpoints (local or remote) and saves the distilled model |
18 | | -to `<output_dir>/checkpoints` in megatron distributed checkpoint format. |
| 18 | +to `<output_dir>/checkpoints` in megatron distributed checkpoint or HuggingFace format. |
19 | 19 |
|
20 | 20 | See `README.md` in this directory for example usage and data preparation instructions. |
21 | 21 | """ |
22 | 22 |
|
23 | 23 | import argparse |
24 | 24 | import contextlib |
25 | 25 | import os |
| 26 | +from dataclasses import fields |
26 | 27 |
|
27 | 28 | import torch |
28 | 29 | from megatron.bridge import AutoBridge |
29 | | -from megatron.bridge.models.distillation_provider import convert_to_distillation_provider |
| 30 | +from megatron.bridge.models.distillation_provider import ( |
| 31 | + DistillationProvider, |
| 32 | + convert_to_distillation_provider, |
| 33 | +) |
30 | 34 | from megatron.bridge.recipes.utils.optimizer_utils import ( |
31 | 35 | distributed_fused_adam_with_cosine_annealing, |
32 | 36 | ) |
|
46 | 50 | from megatron.core.distributed import DistributedDataParallelConfig |
47 | 51 | from transformers import AutoConfig |
48 | 52 |
|
49 | | -with contextlib.suppress(ImportError): |
50 | | - import modelopt.torch.puzzletron.plugins.mbridge # noqa: F401 |
51 | | - |
52 | 53 | import modelopt.torch.utils.distributed as dist |
53 | 54 | from modelopt.torch.utils import print_rank_0 |
54 | 55 |
|
| 56 | +with contextlib.suppress(ImportError): |
| 57 | + import modelopt.torch.puzzletron.plugins.mbridge # noqa: F401 |
| 58 | + |
# Fixed RNG seed — presumably passed to the training/data pipeline for
# reproducible runs; the consuming code is outside this chunk, so confirm usage.
SEED = 1234
56 | 60 |
|
57 | 61 |
|
def _patched_to_cfg_dict(self):
    """Patched ``DistillationProvider.to_cfg_dict`` for heterogeneous teacher and student models.

    Serializes using the fields declared on the concrete student provider class
    (``self._super_class``) instead of ``DistillationProvider`` itself, whose
    ``__dataclass_fields__`` (fixed at class-definition time) only carries
    TransformerConfig-level fields and misses GPTModelProvider-level ones such
    as ``vocab_size`` or ``share_embeddings_and_output_weights``.

    TODO: Upstream this patch to Megatron-Bridge.
    """
    from megatron.bridge.training.utils.config_utils import _ConfigContainerBase

    skip = {"teacher", "kd_config"}
    cfg = {"_target_": f"{self._super_class.__module__}.{self._super_class.__qualname__}"}

    # First pass: fields of the actual student provider class. Only copy the
    # ones this instance actually carries.
    for fld in fields(self._super_class):
        name = fld.name
        if name.startswith("_") or name in skip:
            continue
        if hasattr(self, name):
            cfg[name] = _ConfigContainerBase._convert_value_to_dict(getattr(self, name))

    # Second pass: fields declared on this instance's own dataclass that the
    # first pass did not already record.
    for fld in fields(self):
        name = fld.name
        if name.startswith("_") or name in skip:
            continue
        if name not in cfg:
            cfg[name] = _ConfigContainerBase._convert_value_to_dict(getattr(self, name))

    return cfg


# Install the patched serializer on DistillationProvider (see TODO above about
# upstreaming this to Megatron-Bridge).
DistillationProvider.to_cfg_dict = _patched_to_cfg_dict
| 94 | + |
58 | 95 | def get_args(): |
59 | 96 | """Parse command-line arguments.""" |
60 | 97 | parser = argparse.ArgumentParser(description="Distillation for Megatron-Bridge.") |
|
0 commit comments