|
59 | 59 |
|
60 | 60 | import argparse |
61 | 61 | import gc |
| 62 | +import inspect |
62 | 63 | import json |
63 | 64 | import logging |
64 | 65 | import os |
|
91 | 92 | read_run_config, |
92 | 93 | ) |
93 | 94 | from megatron.bridge.utils.common_utils import get_world_size_safe |
94 | | -from megatron.bridge.utils.instantiate_utils import instantiate |
| 95 | +from megatron.bridge.utils.instantiate_utils import instantiate, register_allowed_target_prefix |
95 | 96 | from megatron.core import dist_checkpointing, parallel_state |
96 | 97 | from megatron.core.inference.contexts import StaticInferenceContext |
97 | 98 | from megatron.core.inference.engines.static_engine import StaticInferenceEngine |
98 | 99 | from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( |
99 | 100 | AbstractModelInferenceWrapper, |
100 | 101 | ) |
101 | | -from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( |
102 | | - InferenceWrapperConfig, |
103 | | -) |
| 102 | + |
| 103 | + |
try:
    from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
        InferenceWrapperConfig,
    )
except ImportError:
    # The import above is unavailable on this installation, so define a local
    # stand-in under the same name for the rest of this module to use.

    @dataclass
    class InferenceWrapperConfig:
        """Local replacement for MCore's InferenceWrapperConfig when MCore no longer ships it.

        The first six fields are required (no defaults); the two boolean
        flags default to ``False``.
        """

        hidden_size: int
        inference_max_requests: int
        inference_max_seq_length: int
        inference_batch_times_seqlen_threshold: int
        params_dtype: torch.dtype
        padded_vocab_size: int
        nccl_all_reduce_for_prefill: bool = False
        moe_pad_experts_for_cuda_graph_inference: bool = False

        def add_attributes(self, attributes: dict[str, Any]) -> None:
            """Set every ``name -> value`` pair in ``attributes`` on this config.

            Mirrors the helper of the old MCore config that
            Evo2TextGenerationController relies on. Existing fields are
            overwritten and unknown names are attached as new attributes.
            """
            for attr_name, attr_value in attributes.items():
                setattr(self, attr_name, attr_value)
| 127 | + |
| 128 | + |
104 | 129 | from megatron.core.inference.sampling_params import SamplingParams |
105 | 130 | from megatron.core.transformer.module import Float16Module |
106 | 131 | from megatron.core.utils import get_model_config |
|
115 | 140 | logger: logging.Logger = logging.getLogger(__name__) |
116 | 141 | logger.setLevel(logging.INFO) |
117 | 142 |
|
# Evaluated once at import time: True when the installed
# AbstractModelInferenceWrapper.__init__ declares an `inference_wrapper_config`
# parameter. The wrapper subclass in this module uses the flag to pick which
# super().__init__ call form to use.
_WRAPPER_INIT_ACCEPTS_CONFIG = "inference_wrapper_config" in inspect.signature(
    AbstractModelInferenceWrapper.__init__
).parameters
| 146 | + |
| 147 | + |
class _TextGenerationTokenizerAdapter:
    """Expose the tokenizer methods expected by MCore's static text-generation path.

    Wraps a tokenizer and provides ``vocab_size``, ``bos``, ``eod``,
    ``tokenize``, ``detokenize`` and ``offsets``. Each method prefers the
    wrapped tokenizer's own implementation and otherwise falls back to
    ``text_to_ids`` / ``ids_to_text``. Any other attribute is forwarded to the
    wrapped tokenizer unchanged.
    """

    def __init__(self, tokenizer: "_HuggingFaceTokenizer"):
        """Store the tokenizer to delegate to.

        Args:
            tokenizer: The tokenizer instance being adapted.
        """
        self._tokenizer = tokenizer

    def __getattr__(self, name: str) -> Any:
        """Forward attributes not defined on the adapter to the wrapped tokenizer.

        Raises:
            AttributeError: If ``name`` is ``_tokenizer`` (the adapter has not
                been initialised), or the wrapped tokenizer lacks ``name``.
        """
        # __getattr__ only runs after normal lookup has failed. If `_tokenizer`
        # itself is missing (e.g. on the bare instance that copy/pickle create
        # before restoring state), reading `self._tokenizer` below would
        # re-enter this method without end and raise RecursionError. Fail with
        # a regular AttributeError instead so hasattr()/getattr() probes work.
        if name == "_tokenizer":
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return getattr(self._tokenizer, name)

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the wrapped tokenizer."""
        return self._tokenizer.vocab_size

    @property
    def bos(self) -> Optional[int]:
        """The wrapped tokenizer's ``bos`` attribute, or ``None`` if it has none."""
        return getattr(self._tokenizer, "bos", None)

    @property
    def eod(self) -> Optional[int]:
        """The wrapped tokenizer's ``eod`` attribute, or ``None`` if it has none."""
        return getattr(self._tokenizer, "eod", None)

    def tokenize(self, text: str) -> list[int]:
        """Convert ``text`` to token ids.

        Uses the wrapped tokenizer's ``tokenize`` when present, otherwise its
        ``text_to_ids``.
        """
        if hasattr(self._tokenizer, "tokenize"):
            return self._tokenizer.tokenize(text)
        return self._tokenizer.text_to_ids(text)

    def detokenize(self, tokens: list[int], skip_special_tokens: bool = True) -> str:
        """Convert token ids back to text.

        Uses the wrapped tokenizer's ``detokenize`` (forwarding
        ``skip_special_tokens``) when present. The ``ids_to_text`` fallback is
        called without that flag, so ``skip_special_tokens`` has no effect there.
        """
        if hasattr(self._tokenizer, "detokenize"):
            return self._tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
        return self._tokenizer.ids_to_text(tokens)

    def offsets(self, tokens: list[int], text: str) -> list[int]:
        """Return the starting character offset of each token.

        Uses the wrapped tokenizer's ``offsets`` when present. The fallback
        ignores ``text`` and accumulates the length of each token decoded on
        its own, starting from 0.
        """
        if hasattr(self._tokenizer, "offsets"):
            return self._tokenizer.offsets(tokens, text)
        offsets = []
        position = 0
        for token in tokens:
            offsets.append(position)
            # Decode one token at a time, keeping special tokens so they still
            # contribute to the running offset.
            position += len(self.detokenize([token], skip_special_tokens=False))
        return offsets
| 188 | + |
118 | 189 |
|
119 | 190 | # ============================================================================= |
120 | 191 | # Hardware-Aware Defaults |
@@ -243,7 +314,11 @@ def __init__( |
243 | 314 | inference_wrapper_config: Configuration with hidden size, vocab size, etc. |
244 | 315 | inference_context: Context for managing state and sequence offsets. |
245 | 316 | """ |
246 | | - super().__init__(model, inference_wrapper_config, inference_context) |
| 317 | + self.inference_wrapper_config = inference_wrapper_config |
| 318 | + if _WRAPPER_INIT_ACCEPTS_CONFIG: |
| 319 | + super().__init__(model, inference_wrapper_config, inference_context) |
| 320 | + else: |
| 321 | + super().__init__(model, inference_context) |
247 | 322 |
|
248 | 323 | def prep_inference_input(self, prompts_tokens: torch.Tensor) -> Dict[str, Any]: |
249 | 324 | """Prepare the inference input data. |
@@ -410,6 +485,7 @@ def setup_inference_engine( |
410 | 485 | raise FileNotFoundError(f"run_config.yaml not found at {run_config_filename}") |
411 | 486 |
|
412 | 487 | run_config = read_run_config(run_config_filename) |
| 488 | + register_allowed_target_prefix("bionemo.") |
413 | 489 | model_provider = instantiate(run_config["model"]) |
414 | 490 | logger.info(f"Instantiated model provider: {type(model_provider).__name__}") |
415 | 491 |
|
@@ -446,6 +522,7 @@ def setup_inference_engine( |
446 | 522 | tokenizer = _HuggingFaceTokenizer(tokenizer_dir) |
447 | 523 | else: |
448 | 524 | tokenizer = _HuggingFaceTokenizer(DEFAULT_HF_TOKENIZER_MODEL_PATH) |
| 525 | + tokenizer = _TextGenerationTokenizerAdapter(tokenizer) |
449 | 526 |
|
450 | 527 | model_provider.vocab_size = tokenizer.vocab_size |
451 | 528 | model_provider.should_pad_vocab = True |
|
0 commit comments