|
48 | 48 | logger.setLevel(logging.DEBUG) # Capture all levels in the logger itself |
49 | 49 |
|
50 | 50 |
|
| 51 | +MEM_REQUIREMENT_1B_GB = 18 # max memory reserved plus 0.6 GB, rounded up to a whole GB |
| 52 | +MEM_REQUIREMENT_7B_GB = 48 |
| 53 | + |
| 54 | + |
| 55 | + |
51 | 56 | def load_weights_sharded_inplace_nemo2_to_mcore( |
52 | 57 | model: MegatronModelType, |
53 | 58 | distributed_checkpoint_dir: str | Path, |
@@ -365,7 +370,7 @@ def check_matchrate(*, ckpt_name, matchrate, assert_matchrate=True): |
365 | 370 | def test_forward(sequences: list[str], ckpt_name: str, expected_matchpercents: list[float]): |
366 | 371 | assert len(sequences) > 0 |
367 | 372 | gb_available = torch.cuda.mem_get_info()[0] / 1024**3 |
368 | | - if (gb_available < 20 and "1b" in ckpt_name) or (gb_available < 40 and "7b" in ckpt_name): |
| 373 | + if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name): |
369 | 374 | pytest.skip( |
370 | 375 | f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}" |
371 | 376 | ) |
@@ -429,7 +434,7 @@ def test_forward_manual(sequences: list[str], ckpt_name: str, expected_matchperc |
429 | 434 | is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device()) |
430 | 435 | skip = "evo2/1b-8k:" in ckpt_name and not is_fp8_supported |
431 | 436 | gb_available = torch.cuda.mem_get_info()[0] / 1024**3 |
432 | | - if (gb_available < 20 and flash_decode) or (gb_available < 40 and flash_decode and "7b" in ckpt_name): |
| 437 | + if (gb_available < MEM_REQUIREMENT_1B_GB and flash_decode) or (gb_available < MEM_REQUIREMENT_7B_GB and flash_decode and "7b" in ckpt_name): |
433 | 438 | pytest.skip( |
434 | 439 | f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}" |
435 | 440 | ) |
@@ -544,7 +549,7 @@ def test_batch_generate( |
544 | 549 | assert len(sequences) > 0 |
545 | 550 | is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device()) |
546 | 551 | gb_available = torch.cuda.mem_get_info()[0] / 1024**3 |
547 | | - if (gb_available < 20 and "1b" in ckpt_name) or (gb_available < 40 and "7b" in ckpt_name): |
| 552 | + if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name): |
548 | 553 | pytest.skip( |
549 | 554 | f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}" |
550 | 555 | ) |
@@ -615,7 +620,7 @@ def test_batch_generate_coding_sequences( |
615 | 620 | ): |
616 | 621 | assert len(coding_sequences) > 0 |
617 | 622 | gb_available = torch.cuda.mem_get_info()[0] / 1024**3 |
618 | | - if (gb_available < 20 and "1b" in ckpt_name) or (gb_available < 40 and "7b" in ckpt_name): |
| 623 | + if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name): |
619 | 624 | pytest.skip( |
620 | 625 | f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}" |
621 | 626 | ) |
@@ -724,7 +729,7 @@ def test_generate_speed( |
724 | 729 | ): |
725 | 730 | is_fp8_supported, compute_capability, device_info = check_fp8_support(torch.cuda.current_device()) |
726 | 731 | gb_available = torch.cuda.mem_get_info()[0] / 1024**3 |
727 | | - if (gb_available < 20 and "1b" in ckpt_name) or (gb_available < 40 and "7b" in ckpt_name): |
| 732 | + if (gb_available < MEM_REQUIREMENT_1B_GB and "1b" in ckpt_name) or (gb_available < MEM_REQUIREMENT_7B_GB and "7b" in ckpt_name): |
728 | 733 | pytest.skip( |
729 | 734 | f"Inference API requires more than 38GB of memory for 1b models, or 50GB for 7b models. {gb_available=}" |
730 | 735 | ) |
|
0 commit comments