Skip to content

Commit 8d22da6

Browse files
committed
add gpt to predict.py and infer.py
Signed-off-by: Yang Zhang <yangzhang@nvidia.com>
1 parent bde26b9 commit 8d22da6

3 files changed

Lines changed: 34 additions & 19 deletions

File tree

sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -91,15 +91,13 @@ def get_inference_wrapper(
9191
self, params_dtype, inference_batch_times_seqlen_threshold, inference_max_seq_length=8192
9292
) -> GPTInferenceWrapper:
9393
"""Gets the inference wrapper for the GPT model."""
94-
# Find MCoreMambaModel instance
95-
mcore_model = self.module
96-
while mcore_model:
97-
if isinstance(mcore_model, ()):
94+
model = self
95+
while model is not None:
96+
if getattr(model, "module", None) is not None:
97+
model = model.module
98+
else:
9899
break
99-
mcore_model = getattr(mcore_model, "module", None)
100-
if mcore_model is None or not isinstance(
101-
mcore_model, (megatron.core.models.gpt.gpt_model.GPTModel, Evo2StyleMCoreGPTModel)
102-
):
100+
if not isinstance(model, megatron.core.models.gpt.gpt_model.GPTModel):
103101
raise ValueError("GPT model instance not found in the model structure.")
104102

105103
vocab_size = None
@@ -111,14 +109,14 @@ def get_inference_wrapper(
111109
raise ValueError("Unable to find vocab size.")
112110

113111
inference_wrapper_config = InferenceWrapperConfig(
114-
hidden_size=mcore_model.config.hidden_size,
112+
hidden_size=model.config.hidden_size,
115113
params_dtype=params_dtype,
116114
inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
117115
padded_vocab_size=vocab_size,
118116
inference_max_seq_length=inference_max_seq_length,
119117
)
120118

121-
model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config)
119+
model_inference_wrapper = GPTInferenceWrapper(model, inference_wrapper_config)
122120
return model_inference_wrapper
123121

124122
@override

sub-packages/bionemo-evo2/src/bionemo/evo2/run/infer.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -86,14 +86,12 @@ def parse_args():
8686
)
8787
ap.add_argument(
8888
"--fp8",
89-
type=bool,
9089
action="store_true",
9190
default=False,
9291
help="Whether to use vortex style FP8. Defaults to False.",
9392
)
9493
ap.add_argument(
9594
"--flash-decode",
96-
type=bool,
9795
action="store_true",
9896
default=False,
9997
help="Whether to use flash decode. Defaults to False.",
@@ -173,8 +171,8 @@ def infer(
173171
path=ckpt_dir,
174172
trainer=trainer,
175173
params_dtype=torch.bfloat16,
176-
inference_batch_times_seqlen_threshold=8192, # TODO
177-
inference_max_seq_length=8192, # TODO
174+
inference_batch_times_seqlen_threshold=len(prompt) + max_new_tokens, # TODO
175+
inference_max_seq_length=len(prompt) + max_new_tokens, # TODO
178176
recompute_granularity=None,
179177
recompute_num_layers=None,
180178
recompute_method=None,

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 24 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
from torch import Tensor
3838

3939
from bionemo.evo2.data.fasta_dataset import SimpleFastaDataset
40+
from bionemo.evo2.models.gpt import GPT_MODEL_OPTIONS
4041

4142
# Add import for Mamba models
4243
from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel
@@ -73,15 +74,17 @@ def parse_args():
7374
ap.add_argument(
7475
"--model-type",
7576
type=str,
76-
choices=["hyena", "mamba"],
77+
choices=["hyena", "mamba", "gpt"],
7778
default="hyena",
78-
help="Model architecture family to use. Choose between 'hyena' and 'mamba'.",
79+
help="Model architecture family to use. Choose between 'hyena', 'mamba', and 'gpt'.",
7980
)
8081
ap.add_argument(
8182
"--model-size",
8283
type=str,
8384
default="7b",
84-
choices=sorted(list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys())),
85+
choices=sorted(
86+
list(HYENA_MODEL_OPTIONS.keys()) + list(MAMBA_MODEL_OPTIONS.keys()) + list(GPT_MODEL_OPTIONS.keys())
87+
),
8588
help="Model size to use. Defaults to '7b'.",
8689
)
8790
# output args:
@@ -416,7 +419,7 @@ def predict(
416419
vortex_style_fp8=fp8 and not full_fp8,
417420
**config_modifiers_init,
418421
)
419-
else: # mamba
422+
elif model_type == "mamba": # mamba
420423
if model_size not in MAMBA_MODEL_OPTIONS:
421424
raise ValueError(f"Invalid model size for Mamba: {model_size}")
422425
config = MAMBA_MODEL_OPTIONS[model_size](
@@ -425,6 +428,15 @@ def predict(
425428
distribute_saved_activations=False if sequence_parallel and tensor_parallel_size > 1 else True,
426429
**config_modifiers_init,
427430
)
431+
elif model_type == "gpt":
432+
if model_size not in GPT_MODEL_OPTIONS:
433+
raise ValueError(f"Invalid model size for GPT: {model_size}")
434+
config = GPT_MODEL_OPTIONS[model_size](
435+
forward_step_fn=hyena_predict_forward_step,
436+
data_step_fn=hyena_predict_data_step,
437+
)
438+
else:
439+
raise ValueError(f"Invalid model type: {model_type}")
428440

429441
trainer.strategy._setup_optimizers = False
430442

@@ -451,13 +463,20 @@ def predict(
451463
output_log_prob_seqs=output_log_prob_seqs,
452464
log_prob_collapse_option=log_prob_collapse_option,
453465
)
454-
else: # mamba
466+
elif model_type == "mamba": # mamba
455467
model = MambaPredictor(
456468
config,
457469
tokenizer=tokenizer,
458470
output_log_prob_seqs=output_log_prob_seqs,
459471
log_prob_collapse_option=log_prob_collapse_option,
460472
)
473+
elif model_type == "gpt":
474+
model = HyenaPredictor(
475+
config,
476+
tokenizer=tokenizer,
477+
output_log_prob_seqs=output_log_prob_seqs,
478+
log_prob_collapse_option=log_prob_collapse_option,
479+
)
461480

462481
resume.setup(trainer, model) # this pulls weights from the starting checkpoint.
463482

0 commit comments

Comments
 (0)