Skip to content

Commit 5f4ce5d

Browse files
committed
br: uncomment, linter
Signed-off-by: Brian Roland <broland@nvidia.com>
1 parent 34da92b commit 5f4ce5d

1 file changed

Lines changed: 7 additions & 21 deletions

File tree

sub-packages/bionemo-evo2/tests/bionemo/evo2/test_evo2.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -468,18 +468,11 @@ def check_matchrate(*, ckpt_name, matchrate, assert_matchrate=True):
468468
raise NotImplementedError
469469

470470

471-
# @pytest.mark.parametrize(
472-
# "ckpt_name,expected_matchpercents",
473-
# [
474-
# ("evo2/1b-8k-bf16:1.0", [96.27, 67.93, 77.50, 80.30]),
475-
# ("evo2/1b-8k:1.0", [96.27, 67.93, 77.50, 80.30]),
476-
# ("evo2/7b-8k:1.0", [97.60, 89.63, 80.03, 84.57]),
477-
# ("evo2/7b-1m:1.0", [97.60, 89.63, 80.03, 84.57]),
478-
# ],
479-
# )
480471
@pytest.mark.parametrize(
481472
"ckpt_name,expected_matchpercents",
482473
[
474+
("evo2/1b-8k-bf16:1.0", [96.27, 67.93, 77.50, 80.30]),
475+
("evo2/1b-8k:1.0", [96.27, 67.93, 77.50, 80.30]),
483476
("evo2/7b-8k:1.0", [97.60, 89.63, 80.03, 84.57]),
484477
("evo2/7b-1m:1.0", [97.60, 89.63, 80.03, 84.57]),
485478
],
@@ -538,20 +531,13 @@ def test_forward(sequences: list[str], ckpt_name: str, expected_matchpercents: l
538531
@pytest.mark.parametrize(
539532
"ckpt_name,expected_matchpercents,flash_decode",
540533
[
534+
# Try flash decode with one and not the other to verify that both paths work.
535+
("evo2/1b-8k-bf16:1.0", [96.27, 67.93, 77.50, 80.30], True),
536+
("evo2/1b-8k:1.0", [96.27, 67.93, 77.50, 80.30], False),
541537
("evo2/7b-8k:1.0", [97.60, 89.63, 80.03, 84.57], False),
542538
("evo2/7b-1m:1.0", [97.60, 89.63, 80.03, 84.57], False),
543539
],
544540
)
545-
# @pytest.mark.parametrize(
546-
# "ckpt_name,expected_matchpercents,flash_decode",
547-
# [
548-
# # Try flash decode with one and not the other to verify that both paths work.
549-
# ("evo2/1b-8k-bf16:1.0", [96.27, 67.93, 77.50, 80.30], True),
550-
# ("evo2/1b-8k:1.0", [96.27, 67.93, 77.50, 80.30], False),
551-
# ("evo2/7b-8k:1.0", [97.60, 89.63, 80.03, 84.57], False),
552-
# ("evo2/7b-1m:1.0", [97.60, 89.63, 80.03, 84.57], False),
553-
# ],
554-
# )
555541
def test_forward_manual(sequences: list[str], ckpt_name: str, expected_matchpercents: list[float], flash_decode: bool):
556542
assert len(sequences) > 0
557543
seq_len_cap = determine_memory_requirement_and_skip_if_not_met(
@@ -665,7 +651,7 @@ def calculate_sequence_identity(seq1: str, seq2: str) -> float | None:
665651
("evo2/1b-8k:1.0", get_model_and_tokenizer, [96.8, 29.7, 76.6, 71.6]),
666652
("evo2_mamba/7b-8k:0.1", get_model_and_tokenizer_ignore_vortex, [99.2, 51.0, 73.0, 82.6]),
667653
("evo2/7b-8k:1.0", get_model_and_tokenizer, [97.60, 89.63, 80.03, 84.57]),
668-
# ("evo2/7b-1m:1.0", get_model_and_tokenizer, [97.60, 89.63, 80.03, 84.57]),
654+
("evo2/7b-1m:1.0", get_model_and_tokenizer, [97.60, 89.63, 80.03, 84.57]),
669655
],
670656
)
671657
def test_batch_generate(
@@ -845,7 +831,7 @@ def test_batch_generate_coding_sequences(
845831
("evo2/1b-8k:1.0", get_model_and_tokenizer, 41.0),
846832
("evo2_mamba/7b-8k:0.1", get_model_and_tokenizer_ignore_vortex, 39.73),
847833
("evo2/7b-8k:1.0", get_model_and_tokenizer, 32.0),
848-
# ("evo2/7b-1m:1.0", get_model_and_tokenizer, 32.0),
834+
("evo2/7b-1m:1.0", get_model_and_tokenizer, 32.0),
849835
],
850836
)
851837
def test_generate_speed(

0 commit comments

Comments
 (0)