@@ -468,18 +468,11 @@ def check_matchrate(*, ckpt_name, matchrate, assert_matchrate=True):
468468 raise NotImplementedError
469469
470470
471- # @pytest.mark.parametrize(
472- # "ckpt_name,expected_matchpercents",
473- # [
474- # ("evo2/1b-8k-bf16:1.0", [96.27, 67.93, 77.50, 80.30]),
475- # ("evo2/1b-8k:1.0", [96.27, 67.93, 77.50, 80.30]),
476- # ("evo2/7b-8k:1.0", [97.60, 89.63, 80.03, 84.57]),
477- # ("evo2/7b-1m:1.0", [97.60, 89.63, 80.03, 84.57]),
478- # ],
479- # )
480471@pytest .mark .parametrize (
481472 "ckpt_name,expected_matchpercents" ,
482473 [
474+ ("evo2/1b-8k-bf16:1.0" , [96.27 , 67.93 , 77.50 , 80.30 ]),
475+ ("evo2/1b-8k:1.0" , [96.27 , 67.93 , 77.50 , 80.30 ]),
483476 ("evo2/7b-8k:1.0" , [97.60 , 89.63 , 80.03 , 84.57 ]),
484477 ("evo2/7b-1m:1.0" , [97.60 , 89.63 , 80.03 , 84.57 ]),
485478 ],
@@ -538,20 +531,13 @@ def test_forward(sequences: list[str], ckpt_name: str, expected_matchpercents: l
538531@pytest .mark .parametrize (
539532 "ckpt_name,expected_matchpercents,flash_decode" ,
540533 [
534+ # Try flash decode with one and not the other to verify that both paths work.
535+ ("evo2/1b-8k-bf16:1.0" , [96.27 , 67.93 , 77.50 , 80.30 ], True ),
536+ ("evo2/1b-8k:1.0" , [96.27 , 67.93 , 77.50 , 80.30 ], False ),
541537 ("evo2/7b-8k:1.0" , [97.60 , 89.63 , 80.03 , 84.57 ], False ),
542538 ("evo2/7b-1m:1.0" , [97.60 , 89.63 , 80.03 , 84.57 ], False ),
543539 ],
544540)
545- # @pytest.mark.parametrize(
546- # "ckpt_name,expected_matchpercents,flash_decode",
547- # [
548- # # Try flash decode with one and not the other to verify that both paths work.
549- # ("evo2/1b-8k-bf16:1.0", [96.27, 67.93, 77.50, 80.30], True),
550- # ("evo2/1b-8k:1.0", [96.27, 67.93, 77.50, 80.30], False),
551- # ("evo2/7b-8k:1.0", [97.60, 89.63, 80.03, 84.57], False),
552- # ("evo2/7b-1m:1.0", [97.60, 89.63, 80.03, 84.57], False),
553- # ],
554- # )
555541def test_forward_manual (sequences : list [str ], ckpt_name : str , expected_matchpercents : list [float ], flash_decode : bool ):
556542 assert len (sequences ) > 0
557543 seq_len_cap = determine_memory_requirement_and_skip_if_not_met (
@@ -665,7 +651,7 @@ def calculate_sequence_identity(seq1: str, seq2: str) -> float | None:
665651 ("evo2/1b-8k:1.0" , get_model_and_tokenizer , [96.8 , 29.7 , 76.6 , 71.6 ]),
666652 ("evo2_mamba/7b-8k:0.1" , get_model_and_tokenizer_ignore_vortex , [99.2 , 51.0 , 73.0 , 82.6 ]),
667653 ("evo2/7b-8k:1.0" , get_model_and_tokenizer , [97.60 , 89.63 , 80.03 , 84.57 ]),
668- # ("evo2/7b-1m:1.0", get_model_and_tokenizer, [97.60, 89.63, 80.03, 84.57]),
654+ ("evo2/7b-1m:1.0" , get_model_and_tokenizer , [97.60 , 89.63 , 80.03 , 84.57 ]),
669655 ],
670656)
671657def test_batch_generate (
@@ -845,7 +831,7 @@ def test_batch_generate_coding_sequences(
845831 ("evo2/1b-8k:1.0" , get_model_and_tokenizer , 41.0 ),
846832 ("evo2_mamba/7b-8k:0.1" , get_model_and_tokenizer_ignore_vortex , 39.73 ),
847833 ("evo2/7b-8k:1.0" , get_model_and_tokenizer , 32.0 ),
848- # ("evo2/7b-1m:1.0", get_model_and_tokenizer, 32.0),
834+ ("evo2/7b-1m:1.0" , get_model_and_tokenizer , 32.0 ),
849835 ],
850836)
851837def test_generate_speed (
0 commit comments