3030from bionemo .core .data .load import load
3131from bionemo .llm .lightning import batch_collator
3232from bionemo .testing .data .fasta import ALU_SEQUENCE , create_fasta_file
33+ from bionemo .testing .torch import check_fp8_support
3334
3435
3536def is_a6000_gpu () -> bool :
@@ -74,11 +75,18 @@ def checkpoint_7b_1m_path() -> Path:
7475@pytest .mark .parametrize (
7576 "ddp,pp,tp,wi" ,
7677 [
77- (1 , 1 , 1 , "epoch" ),
78- (2 , 1 , 1 , "epoch" ),
79- (2 , 1 , 1 , "batch" ),
80- (1 , 2 , 1 , "epoch" ),
81- (1 , 1 , 2 , "epoch" ),
78+ pytest .param (1 , 1 , 1 , "epoch" , id = "ddp=1,pp=1,tp=1,wi=epoch" ),
79+ pytest .param (2 , 1 , 1 , "epoch" , id = "ddp=2,pp=1,tp=1,wi=epoch" ),
80+ pytest .param (2 , 1 , 1 , "batch" , id = "ddp=2,pp=1,tp=1,wi=batch" ),
81+ pytest .param (
82+ 1 ,
83+ 2 ,
84+ 1 ,
85+ "epoch" ,
86+ id = "ddp=1,pp=2,tp=1,wi=epoch" ,
87+ marks = pytest .mark .skip ("Pipeline parallelism test currently hangs." ),
88+ ),
89+ pytest .param (1 , 1 , 2 , "epoch" , id = "ddp=1,pp=1,tp=2,wi=epoch" ),
8290 ],
8391)
8492def test_predict_evo2_runs (
@@ -177,15 +185,29 @@ def test_predict_evo2_runs(
177185@pytest .mark .parametrize (
178186 "ddp,cp,pp,tp,fp8,wi" ,
179187 [
180- (1 , 1 , 1 , 1 , False , "epoch" ),
181- (2 , 1 , 1 , 1 , False , "epoch" ),
182- (2 , 1 , 1 , 1 , False , "batch" ), # simulate a large prediction run with dp parallelism
183- (1 , 2 , 1 , 1 , False , "epoch" ),
184- (1 , 2 , 1 , 1 , False , "batch" ),
185- (1 , 1 , 2 , 1 , False , "epoch" ),
186- (1 , 1 , 2 , 1 , True , "epoch" ), # Cover case where FP8 was not supported with TP=2
187- (1 , 1 , 1 , 2 , False , "epoch" ),
188+ pytest .param (1 , 1 , 1 , 1 , False , "epoch" , id = "ddp=1,cp=1,pp=1,tp=1,fp8=False,wi=epoch" ),
189+ pytest .param (2 , 1 , 1 , 1 , False , "epoch" , id = "ddp=2,cp=1,pp=1,tp=1,fp8=False,wi=epoch" ),
190+ pytest .param (
191+ 2 , 1 , 1 , 1 , False , "batch" , id = "ddp=2,cp=1,pp=1,tp=1,fp8=False,wi=batch"
192+ ), # simulate a large prediction run with dp parallelism
193+ pytest .param (1 , 2 , 1 , 1 , False , "epoch" , id = "ddp=1,cp=2,pp=1,tp=1,fp8=False,wi=epoch" ),
194+ pytest .param (1 , 2 , 1 , 1 , False , "batch" , id = "ddp=1,cp=2,pp=1,tp=1,fp8=False,wi=batch" ),
195+ pytest .param (
196+ 1 ,
197+ 1 ,
198+ 2 ,
199+ 1 ,
200+ False ,
201+ "epoch" ,
202+ id = "ddp=1,cp=1,pp=2,tp=1,fp8=False,wi=epoch" ,
203+ marks = pytest .mark .skip ("Pipeline parallelism test currently hangs." ),
204+ ),
205+ pytest .param (
206+ 1 , 1 , 1 , 2 , True , "epoch" , id = "ddp=1,cp=1,pp=1,tp=2,fp8=True,wi=epoch"
207+ ), # Cover case where FP8 was not supported with TP=2
208+ pytest .param (1 , 1 , 1 , 2 , False , "epoch" , id = "ddp=1,cp=1,pp=1,tp=2,fp8=False,wi=epoch" ),
188209 ],
210+ ids = lambda x : f"ddp={ x [0 ]} ,cp={ x [1 ]} ,pp={ x [2 ]} ,tp={ x [3 ]} ,fp8={ x [4 ]} ,wi={ x [5 ]} " ,
189211)
190212def test_predict_evo2_runs_with_log_probs (
191213 tmp_path ,
@@ -210,6 +232,9 @@ def test_predict_evo2_runs_with_log_probs(
# Total number of ranks the requested layout needs (product of all parallel sizes).
world_size = ddp * cp * pp * tp
if world_size > torch.cuda.device_count():
    # The layout needs more ranks than there are visible GPUs, so skip the case.
    pytest.skip(f"World size {world_size} is greater than the number of GPUs {torch.cuda.device_count()}")
235+ is_fp8_supported , _ , _ = check_fp8_support (torch .cuda .current_device ())
236+ if not is_fp8_supported and fp8 :
237+ pytest .skip ("FP8 is not supported on this GPU." )
213238
214239 fasta_file_path = tmp_path / "test.fasta"
215240 create_fasta_file (
@@ -221,6 +246,7 @@ def test_predict_evo2_runs_with_log_probs(
221246 if is_a6000_gpu ():
222247 # Fix hanging issue on A6000 GPUs with multi-gpu tests
223248 env ["NCCL_P2P_DISABLE" ] = "1"
249+
224250 fp8_option = "--fp8" if fp8 else ""
225251 # Build the command string.
226252 # Note: The command assumes that `train_evo2` is in your PATH.
0 commit comments