3737from torch import Tensor
3838
3939from bionemo .evo2 .data .fasta_dataset import SimpleFastaDataset
40+ from bionemo .evo2 .models .gpt import GPT_MODEL_OPTIONS
4041
4142# Add import for Mamba models
4243from bionemo .evo2 .models .mamba import MAMBA_MODEL_OPTIONS , MambaModel
@@ -73,15 +74,17 @@ def parse_args():
7374 ap .add_argument (
7475 "--model-type" ,
7576 type = str ,
76- choices = ["hyena" , "mamba" ],
77+ choices = ["hyena" , "mamba" , "gpt" ],
7778 default = "hyena" ,
78- help = "Model architecture family to use. Choose between 'hyena' and 'mamba '." ,
79+ help = "Model architecture family to use. Choose between 'hyena', 'mamba', and 'gpt '." ,
7980 )
8081 ap .add_argument (
8182 "--model-size" ,
8283 type = str ,
8384 default = "7b" ,
84- choices = sorted (list (HYENA_MODEL_OPTIONS .keys ()) + list (MAMBA_MODEL_OPTIONS .keys ())),
85+ choices = sorted (
86+ list (HYENA_MODEL_OPTIONS .keys ()) + list (MAMBA_MODEL_OPTIONS .keys ()) + list (GPT_MODEL_OPTIONS .keys ())
87+ ),
8588 help = "Model size to use. Defaults to '7b'." ,
8689 )
8790 # output args:
@@ -416,7 +419,7 @@ def predict(
416419 vortex_style_fp8 = fp8 and not full_fp8 ,
417420 ** config_modifiers_init ,
418421 )
419- else : # mamba
422+ elif model_type == "mamba" : # mamba
420423 if model_size not in MAMBA_MODEL_OPTIONS :
421424 raise ValueError (f"Invalid model size for Mamba: { model_size } " )
422425 config = MAMBA_MODEL_OPTIONS [model_size ](
@@ -425,6 +428,15 @@ def predict(
425428 distribute_saved_activations = False if sequence_parallel and tensor_parallel_size > 1 else True ,
426429 ** config_modifiers_init ,
427430 )
431+ elif model_type == "gpt" :
432+ if model_size not in GPT_MODEL_OPTIONS :
433+ raise ValueError (f"Invalid model size for GPT: { model_size } " )
434+ config = GPT_MODEL_OPTIONS [model_size ](
435+ forward_step_fn = hyena_predict_forward_step ,
436+ data_step_fn = hyena_predict_data_step ,
437+ )
438+ else :
439+ raise ValueError (f"Invalid model type: { model_type } " )
428440
429441 trainer .strategy ._setup_optimizers = False
430442
@@ -451,13 +463,20 @@ def predict(
451463 output_log_prob_seqs = output_log_prob_seqs ,
452464 log_prob_collapse_option = log_prob_collapse_option ,
453465 )
454- else : # mamba
466+ elif model_type == "mamba" : # mamba
455467 model = MambaPredictor (
456468 config ,
457469 tokenizer = tokenizer ,
458470 output_log_prob_seqs = output_log_prob_seqs ,
459471 log_prob_collapse_option = log_prob_collapse_option ,
460472 )
473+ elif model_type == "gpt" :
474+ model = HyenaPredictor (
475+ config ,
476+ tokenizer = tokenizer ,
477+ output_log_prob_seqs = output_log_prob_seqs ,
478+ log_prob_collapse_option = log_prob_collapse_option ,
479+ )
461480
462481 resume .setup (trainer , model ) # this pulls weights from the starting checkpoint.
463482
0 commit comments