|
14 | 14 | # limitations under the License. |
15 | 15 | """Megatron-Bridge plugins for using with Model-Optimizer.""" |
16 | 16 |
|
17 | | -from collections.abc import Callable |
18 | 17 | from typing import Any |
19 | 18 |
|
20 | | -import torch.nn as nn |
21 | | -from datasets import DatasetDict |
22 | 19 | from megatron.bridge import AutoBridge |
23 | | -from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig |
24 | | -from megatron.bridge.data.loaders import setup_data_iterators |
25 | | -from megatron.bridge.data.utils import get_dataset_provider |
26 | 20 | from megatron.bridge.models.gpt_provider import GPTModelProvider |
27 | 21 | from megatron.bridge.models.hf_pretrained.utils import is_safe_repo |
28 | 22 | from megatron.bridge.models.mamba.mamba_provider import MambaModelProvider |
29 | | -from megatron.bridge.training.config import ( |
30 | | - CheckpointConfig, |
31 | | - ConfigContainer, |
32 | | - LoggerConfig, |
33 | | - OptimizerConfig, |
34 | | - SchedulerConfig, |
35 | | - TrainingConfig, |
36 | | - runtime_config_update, |
37 | | -) |
38 | | -from megatron.bridge.training.eval import evaluate_and_print_results |
39 | | -from megatron.bridge.training.gpt_step import forward_step |
40 | | -from megatron.bridge.training.state import GlobalState |
41 | | -from megatron.bridge.training.tokenizers.config import TokenizerConfig |
42 | 23 | from megatron.core.models.gpt import GPTModel |
43 | 24 | from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec |
44 | 25 | from megatron.core.models.mamba import MambaModel |
45 | | -from megatron.core.parallel_state import get_data_parallel_group |
46 | 26 | from megatron.core.transformer.module import MegatronModule |
47 | 27 | from megatron.core.utils import unwrap_model |
48 | 28 | from transformers import AutoTokenizer |
49 | 29 |
|
50 | 30 | from modelopt.torch.nas.plugins.megatron import get_te_mamba_stack_spec |
51 | | -from modelopt.torch.utils import get_dataset_samples, print_rank_0, warn_rank_0 |
| 31 | +from modelopt.torch.utils import print_rank_0 |
52 | 32 |
|
53 | | -__all__ = ["get_hf_mbridge_calibration_loop", "load_mbridge_model_from_hf"] |
| 33 | +__all__ = ["load_mbridge_model_from_hf"] |
54 | 34 |
|
55 | 35 |
|
56 | 36 | def load_mbridge_model_from_hf( |
@@ -118,134 +98,3 @@ def load_mbridge_model_from_hf( |
118 | 98 | ) |
119 | 99 |
|
120 | 100 | return bridge, provider, model, unwrapped_model, tokenizer |
121 | | - |
122 | | - |
123 | | -def _get_dataset_cfg( |
124 | | - dataset_name: str, |
125 | | - num_samples: int, |
126 | | - seq_length: int, |
127 | | - apply_chat_template: bool = True, |
128 | | - tokenizer: AutoTokenizer | None = None, |
129 | | -) -> HFDatasetConfig: |
130 | | - """Get a dataset config for the dataset.""" |
131 | | - dataset = get_dataset_samples( |
132 | | - dataset_name, num_samples, apply_chat_template=apply_chat_template, tokenizer=tokenizer |
133 | | - ) |
134 | | - dataset_cfg = HFDatasetConfig( |
135 | | - dataset_name=f"{dataset_name}_{num_samples}", |
136 | | - dataset_dict=DatasetDict({"train": dataset}), |
137 | | - process_example_fn=lambda example, tokenizer: {"input": example, "output": ""}, |
138 | | - seq_length=seq_length, |
139 | | - dataloader_type="batch", |
140 | | - num_workers=1, |
141 | | - do_validation=False, |
142 | | - do_test=False, |
143 | | - val_proportion=None, |
144 | | - split_val_from_train=False, |
145 | | - rewrite=True, |
146 | | - ) |
147 | | - |
148 | | - return dataset_cfg |
149 | | - |
150 | | - |
151 | | -def get_hf_mbridge_calibration_loop( |
152 | | - *, |
153 | | - model: list[MegatronModule], |
154 | | - provider: GPTModelProvider | MambaModelProvider, |
155 | | - tokenizer: AutoTokenizer, |
156 | | - hf_model_name_or_path: str, |
157 | | - trust_remote_code: bool = False, |
158 | | - dataset_name: str = "nemotron-post-training-dataset-v2", |
159 | | - num_samples: int = 512, |
160 | | - micro_batch_size: int = 1, |
161 | | - global_batch_size: int = 1, |
162 | | -) -> Callable[[nn.Module], None]: |
163 | | - """Get a modelopt calibration loop for a Megatron-Bridge model. |
164 | | -
|
165 | | - Args: |
166 | | - model: The model to calibrate. |
167 | | - provider: The provider to use for the model. |
168 | | - tokenizer: The tokenizer to use for the model. |
169 | | - hf_model_name_or_path: The name or path of the HF model. |
170 | | - trust_remote_code: Whether to trust remote code. |
171 | | - dataset_name: The name of the dataset to use for evaluation. |
172 | | - num_samples: The number of samples to use for evaluation. |
173 | | - micro_batch_size: The micro batch size to use for evaluation. |
174 | | - global_batch_size: The global batch size to use for evaluation. |
175 | | -
|
176 | | - Returns: |
177 | | - A function that can be used to calibrate the model with a modelopt.torch API. |
178 | | - """ |
179 | | - if global_batch_size < micro_batch_size: |
180 | | - warn_rank_0( |
181 | | - f"{global_batch_size=} is smaller than {micro_batch_size=}. Setting gbs to {micro_batch_size}." |
182 | | - ) |
183 | | - global_batch_size = micro_batch_size |
184 | | - num_iters = num_samples // global_batch_size |
185 | | - |
186 | | - cfg = ConfigContainer( |
187 | | - model=provider, |
188 | | - train=TrainingConfig( |
189 | | - micro_batch_size=micro_batch_size, |
190 | | - global_batch_size=global_batch_size, |
191 | | - train_iters=num_iters, |
192 | | - eval_iters=num_iters, |
193 | | - skip_train=True, |
194 | | - ), |
195 | | - # TODO: Replace validation args in train with validation config in nemo:26.04 |
196 | | - # validation=ValidationConfig(eval_iters=num_iters, eval_interval=1, skip_train=True), |
197 | | - dataset=_get_dataset_cfg( |
198 | | - dataset_name, |
199 | | - num_samples, |
200 | | - provider.seq_length, |
201 | | - apply_chat_template=True, |
202 | | - tokenizer=tokenizer, |
203 | | - ), |
204 | | - tokenizer=TokenizerConfig( |
205 | | - tokenizer_type="HuggingFaceTokenizer", |
206 | | - tokenizer_model=hf_model_name_or_path, |
207 | | - # NOTE: Issue with Nemotron Nano v2 tokenizer returning bool hence using use_fast=True as a WAR |
208 | | - hf_tokenizer_kwargs={ |
209 | | - "trust_remote_code": trust_remote_code, |
210 | | - "use_fast": tokenizer.is_fast, |
211 | | - }, |
212 | | - ), |
213 | | - # Unused |
214 | | - optimizer=OptimizerConfig(optimizer="adam", lr=1e-4, use_distributed_optimizer=False), |
215 | | - scheduler=SchedulerConfig(lr_decay_style="constant"), |
216 | | - logger=LoggerConfig(), |
217 | | - checkpoint=CheckpointConfig(), |
218 | | - ) |
219 | | - runtime_config_update(cfg) |
220 | | - |
221 | | - state = GlobalState() |
222 | | - state.cfg = cfg |
223 | | - |
224 | | - dataset_provider = get_dataset_provider(cfg.dataset) |
225 | | - |
226 | | - def _train_valid_test_datasets_provider( |
227 | | - train_val_test_num_samples: tuple, dataset_cfg: HFDatasetConfig |
228 | | - ): |
229 | | - return dataset_provider(train_val_test_num_samples, dataset_cfg, tokenizer=state.tokenizer) |
230 | | - |
231 | | - train_data_iterator, _, _ = setup_data_iterators( |
232 | | - cfg=cfg, |
233 | | - train_state=state.train_state, |
234 | | - model_length=len(model), |
235 | | - train_valid_test_datasets_provider=_train_valid_test_datasets_provider, |
236 | | - dp_group=get_data_parallel_group(), |
237 | | - ) |
238 | | - |
239 | | - def forward_loop(m): |
240 | | - evaluate_and_print_results( |
241 | | - state, |
242 | | - prefix="iteration 1", |
243 | | - forward_step_func=forward_step, |
244 | | - data_iterator=train_data_iterator, |
245 | | - model=model, |
246 | | - config=cfg, |
247 | | - verbose=True, |
248 | | - write_to_tensorboard=False, |
249 | | - ) |
250 | | - |
251 | | - return forward_loop |
0 commit comments