@@ -429,13 +429,64 @@ def test_predict_evo2_equivalent_with_log_probs(
429429 assert log_probs .item () == pytest .approx (baseline_predictions_7b_1m_results [original_idx .item ()], rel = rel )
430430
431431
432- # Note: The PEFT/LoRA test is commented out as it requires training infrastructure and LoRA support
433- # which may need additional updates for the Megatron Bridge API
434- # @pytest.mark.timeout(512)
435- # @pytest.mark.slow
436- # def test_different_results_with_without_peft(tmp_path):
437- # """Test that predictions differ when using PEFT/LoRA adapters."""
438- # pass
@pytest.mark.timeout(512)
@pytest.mark.slow
def test_different_results_with_without_peft(tmp_path, mbridge_checkpoint_1b_8k_bf16_path):
    """Check that a LoRA-finetuned checkpoint predicts different logits than its base checkpoint.

    The test runs in three stages, each in its own ``torchrun`` subprocess:

    1. LoRA-finetune for ``num_steps`` steps on mock data, starting from the base checkpoint.
    2. Run prediction on one shared FASTA file, once with the base checkpoint and once with the
       checkpoint written by the finetune run.
    3. Load both prediction files and assert that the sequence indices and logit shapes agree while
       at least one logit value differs.

    Args:
        tmp_path: pytest temporary directory; holds the finetune results, the FASTA input and both
            prediction output directories, and is used as the working directory of every subprocess.
        mbridge_checkpoint_1b_8k_bf16_path: path of the base checkpoint. Its parent directory is
            passed to ``--finetune-ckpt-dir`` and the path itself to ``--ckpt-dir``.
    """
    num_steps = 2
    result_dir = tmp_path / "lora_finetune"

    # Work on a private copy so the shared PRETEST_ENV mapping is never mutated by this test.
    env = copy.deepcopy(PRETEST_ENV)
    if is_a6000_gpu():
        env["NCCL_P2P_DISABLE"] = "1"

    # Stage 1: short LoRA finetune. --max-steps and --eval-interval are both num_steps.
    ft_port = find_free_network_port()
    ft_cmd = (
        f"torchrun --nproc-per-node 1 --no-python --master_port { ft_port } "
        f"train_evo2 --finetune-ckpt-dir { mbridge_checkpoint_1b_8k_bf16_path .parent } "
        f"--lora-finetune --lora-dim 8 --lora-alpha 16 "
        f"--lora-target-modules linear_qkv,linear_proj,linear_fc1,linear_fc2 "
        f"--hf-tokenizer-model-path { DEFAULT_HF_TOKENIZER_MODEL_PATH_512 } "
        f"--model-size evo2_1b_base --max-steps { num_steps } --eval-interval { num_steps } --eval-iters 1 "
        f"--mock-data --result-dir { result_dir } --mixed-precision-recipe bf16_mixed "
        f"--micro-batch-size 1 --global-batch-size 1 --seq-length 512 "
        f"--ckpt-format torch_dist --log-interval 1 --decay-steps 100 --warmup-steps 1 "
        f"--seed 42 --dataset-seed 33"
    )
    # check=False so that a failure is reported through the assert below, with the captured output.
    ft_result = subprocess.run(
        shlex.split(ft_cmd), check=False, capture_output=True, text=True, cwd=tmp_path, env=env
    )
    assert ft_result.returncode == 0, (
        f"LoRA finetune failed:\n STDOUT:\n { ft_result .stdout } \n STDERR:\n { ft_result .stderr } "
    )

    # The directory name must not carry trailing whitespace, otherwise the path can never exist.
    lora_ckpt = result_dir / "evo2" / "checkpoints" / f"iter_{ num_steps :07d}"
    assert lora_ckpt.exists(), f"Expected LoRA checkpoint at { lora_ckpt } "

    # Stage 2: predict on the same three sequences with both checkpoints.
    fasta_file_path = tmp_path / "test.fasta"
    create_fasta_file(fasta_file_path, 3, sequence_lengths=[32, 65, 129], repeating_dna_pattern=ALU_SEQUENCE)

    def _run_predict(ckpt: Path, output_dir: Path) -> None:
        """Run the predict entry point on ``fasta_file_path`` with ``ckpt``, writing to ``output_dir``."""
        # A fresh port per launch, like the finetune run above.
        port = find_free_network_port()
        cmd = (
            f"torchrun --nproc_per_node 1 --nnodes 1 --master_port { port } "
            f"-m bionemo.evo2.run.predict --fasta { fasta_file_path } --ckpt-dir { ckpt } "
            f"--output-dir { output_dir } --micro-batch-size 3 --write-interval epoch "
            f"--pipeline-model-parallel-size 1 --num-nodes 1 --devices 1"
        )
        r = subprocess.run(shlex.split(cmd), check=False, cwd=tmp_path, capture_output=True, text=True, env=env)
        assert r.returncode == 0, f"predict_evo2 failed:\n STDOUT:\n { r .stdout } \n STDERR:\n { r .stderr } "

    out_base = tmp_path / "out_base"
    out_lora = tmp_path / "out_lora"
    _run_predict(mbridge_checkpoint_1b_8k_bf16_path, out_base)
    _run_predict(lora_ckpt, out_lora)

    # Stage 3: compare the prediction files. Exactly one file is expected per run.
    base_files = glob.glob(str(out_base / "predictions__rank_*__dp_rank_*.pt"))
    lora_files = glob.glob(str(out_lora / "predictions__rank_*__dp_rank_*.pt"))
    assert len(base_files) == 1, f"Expected exactly one base prediction file in {out_base}, found: {base_files}"
    assert len(lora_files) == 1, f"Expected exactly one LoRA prediction file in {out_lora}, found: {lora_files}"

    # weights_only=False performs full unpickling; both files were written by the predict runs above.
    base = torch.load(base_files[0], weights_only=False)
    lora = torch.load(lora_files[0], weights_only=False)
    assert torch.equal(base["seq_idx"], lora["seq_idx"]), "Base and LoRA predictions list different sequence indices"
    assert base["token_logits"].shape == lora["token_logits"].shape, (
        f"Logit shapes differ: base {base['token_logits'].shape} vs LoRA {lora['token_logits'].shape}"
    )
    # NOTE(review): NaN compares unequal to everything, so NaN logits would also satisfy this check.
    assert (base["token_logits"] != lora["token_logits"]).any(), "LoRA adapter had no effect on logits"
439490
440491
441492@pytest .mark .parametrize (
0 commit comments