Skip to content

Commit 53fdc45

Browse files
committed
Fix changed import in infer.py
Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent e8fb3c3 commit 53fdc45

1 file changed

Lines changed: 82 additions & 5 deletions

File tree

  • bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959

6060
import argparse
6161
import gc
62+
import inspect
6263
import json
6364
import logging
6465
import os
@@ -91,16 +92,40 @@
9192
read_run_config,
9293
)
9394
from megatron.bridge.utils.common_utils import get_world_size_safe
94-
from megatron.bridge.utils.instantiate_utils import instantiate
95+
from megatron.bridge.utils.instantiate_utils import instantiate, register_allowed_target_prefix
9596
from megatron.core import dist_checkpointing, parallel_state
9697
from megatron.core.inference.contexts import StaticInferenceContext
9798
from megatron.core.inference.engines.static_engine import StaticInferenceEngine
9899
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
99100
AbstractModelInferenceWrapper,
100101
)
101-
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
102-
InferenceWrapperConfig,
103-
)
102+
103+
104+
try:
105+
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
106+
InferenceWrapperConfig,
107+
)
108+
except ImportError:
109+
110+
@dataclass
111+
class InferenceWrapperConfig:
112+
"""Compatibility shim for MCore versions that removed InferenceWrapperConfig."""
113+
114+
hidden_size: int
115+
inference_max_requests: int
116+
inference_max_seq_length: int
117+
inference_batch_times_seqlen_threshold: int
118+
params_dtype: torch.dtype
119+
padded_vocab_size: int
120+
nccl_all_reduce_for_prefill: bool = False
121+
moe_pad_experts_for_cuda_graph_inference: bool = False
122+
123+
def add_attributes(self, attributes: dict[str, Any]) -> None:
124+
"""Match the old MCore config helper used by Evo2TextGenerationController."""
125+
for name, value in attributes.items():
126+
setattr(self, name, value)
127+
128+
104129
from megatron.core.inference.sampling_params import SamplingParams
105130
from megatron.core.transformer.module import Float16Module
106131
from megatron.core.utils import get_model_config
@@ -115,6 +140,52 @@
115140
logger: logging.Logger = logging.getLogger(__name__)
116141
logger.setLevel(logging.INFO)
117142

143+
_WRAPPER_INIT_ACCEPTS_CONFIG = (
144+
"inference_wrapper_config" in inspect.signature(AbstractModelInferenceWrapper.__init__).parameters
145+
)
146+
147+
148+
class _TextGenerationTokenizerAdapter:
149+
"""Expose the tokenizer methods expected by MCore's static text-generation path."""
150+
151+
def __init__(self, tokenizer: _HuggingFaceTokenizer):
152+
self._tokenizer = tokenizer
153+
154+
def __getattr__(self, name: str) -> Any:
155+
return getattr(self._tokenizer, name)
156+
157+
@property
158+
def vocab_size(self) -> int:
159+
return self._tokenizer.vocab_size
160+
161+
@property
162+
def bos(self) -> Optional[int]:
163+
return getattr(self._tokenizer, "bos", None)
164+
165+
@property
166+
def eod(self) -> Optional[int]:
167+
return getattr(self._tokenizer, "eod", None)
168+
169+
def tokenize(self, text: str) -> list[int]:
170+
if hasattr(self._tokenizer, "tokenize"):
171+
return self._tokenizer.tokenize(text)
172+
return self._tokenizer.text_to_ids(text)
173+
174+
def detokenize(self, tokens: list[int], skip_special_tokens: bool = True) -> str:
175+
if hasattr(self._tokenizer, "detokenize"):
176+
return self._tokenizer.detokenize(tokens, skip_special_tokens=skip_special_tokens)
177+
return self._tokenizer.ids_to_text(tokens)
178+
179+
def offsets(self, tokens: list[int], text: str) -> list[int]:
180+
if hasattr(self._tokenizer, "offsets"):
181+
return self._tokenizer.offsets(tokens, text)
182+
offsets = []
183+
position = 0
184+
for token in tokens:
185+
offsets.append(position)
186+
position += len(self.detokenize([token], skip_special_tokens=False))
187+
return offsets
188+
118189

119190
# =============================================================================
120191
# Hardware-Aware Defaults
@@ -243,7 +314,11 @@ def __init__(
243314
inference_wrapper_config: Configuration with hidden size, vocab size, etc.
244315
inference_context: Context for managing state and sequence offsets.
245316
"""
246-
super().__init__(model, inference_wrapper_config, inference_context)
317+
self.inference_wrapper_config = inference_wrapper_config
318+
if _WRAPPER_INIT_ACCEPTS_CONFIG:
319+
super().__init__(model, inference_wrapper_config, inference_context)
320+
else:
321+
super().__init__(model, inference_context)
247322

248323
def prep_inference_input(self, prompts_tokens: torch.Tensor) -> Dict[str, Any]:
249324
"""Prepare the inference input data.
@@ -410,6 +485,7 @@ def setup_inference_engine(
410485
raise FileNotFoundError(f"run_config.yaml not found at {run_config_filename}")
411486

412487
run_config = read_run_config(run_config_filename)
488+
register_allowed_target_prefix("bionemo.")
413489
model_provider = instantiate(run_config["model"])
414490
logger.info(f"Instantiated model provider: {type(model_provider).__name__}")
415491

@@ -446,6 +522,7 @@ def setup_inference_engine(
446522
tokenizer = _HuggingFaceTokenizer(tokenizer_dir)
447523
else:
448524
tokenizer = _HuggingFaceTokenizer(DEFAULT_HF_TOKENIZER_MODEL_PATH)
525+
tokenizer = _TextGenerationTokenizerAdapter(tokenizer)
449526

450527
model_provider.vocab_size = tokenizer.vocab_size
451528
model_provider.should_pad_vocab = True

0 commit comments

Comments
 (0)