|
7 | 7 |
|
8 | 8 | import pytest |
9 | 9 |
|
10 | | -from fastsafetensors import SafeTensorsFileLoader, SafeTensorsMetadata, SingleGroup |
| 10 | +from fastsafetensors import ( |
| 11 | + ParallelLoader, |
| 12 | + SafeTensorsFileLoader, |
| 13 | + SafeTensorsMetadata, |
| 14 | + SingleGroup, |
| 15 | +) |
11 | 16 | from fastsafetensors import cpp as fstcpp |
12 | | -from fastsafetensors import fastsafe_open |
| 17 | +from fastsafetensors import ( |
| 18 | + fastsafe_open, |
| 19 | +) |
13 | 20 | from fastsafetensors.common import get_device_numa_node, is_gpu_found |
14 | 21 | from fastsafetensors.copier.gds import GdsFileCopier |
15 | 22 | from fastsafetensors.copier.nogds import NoGdsFileCopier |
@@ -531,7 +538,8 @@ def test_SafeTensorsFileLoader(fstcpp_log, input_files, framework) -> None: |
531 | 538 | assert bufs.get_filename(last_key) == input_files[0] |
532 | 539 | assert bufs.get_shape(last_key) == last_shape |
533 | 540 | assert loader.get_shape(last_key) == last_shape |
534 | | - assert bufs.get_filename("aaaaaaaaaaaaa") == "" |
| 541 | + with pytest.raises(ValueError): |
| 542 | + bufs.get_filename("aaaaaaaaaaaaa") |
535 | 543 | bufs.close() |
536 | 544 | loader.close() |
537 | 545 | assert framework.get_mem_used() == 0 |
@@ -560,6 +568,62 @@ def test_SafeTensorsFileLoaderNoGds(fstcpp_log, input_files, framework) -> None: |
560 | 568 | assert fstcpp.get_cpp_metrics().bounce_buffer_bytes == 0 |
561 | 569 |
|
562 | 570 |
|
| 571 | +def test_tensor_filter_hides_skipped_tensors(fstcpp_log, input_files, framework): |
| 572 | + device, _ = get_and_check_device(framework) |
| 573 | + meta = SafeTensorsMetadata.from_file(input_files[0], framework) |
| 574 | + |
| 575 | + kept = set(sorted(meta.tensors.keys())[::2]) |
| 576 | + keep = lambda name: name in kept # noqa: E731 |
| 577 | + skipped = next(name for name in meta.tensors if name not in kept) |
| 578 | + |
| 579 | + loader = SafeTensorsFileLoader( |
| 580 | + pg=SingleGroup(), |
| 581 | + device=device.as_str(), |
| 582 | + framework=framework.get_name(), |
| 583 | + nogds=True, |
| 584 | + ) |
| 585 | + loader.set_tensor_filter(keep) |
| 586 | + loader.add_filenames({0: [input_files[0]]}) |
| 587 | + bufs = loader.copy_files_to_device() |
| 588 | + |
| 589 | + assert set(loader.get_keys()) == kept |
| 590 | + assert skipped not in bufs.key_to_rank_lidx |
| 591 | + with pytest.raises(ValueError): |
| 592 | + bufs.get_tensor(skipped) |
| 593 | + with pytest.raises(ValueError): |
| 594 | + bufs.get_filename(skipped) |
| 595 | + with pytest.raises(ValueError): |
| 596 | + loader.get_shape(skipped) |
| 597 | + |
| 598 | + bufs.close() |
| 599 | + loader.close() |
| 600 | + |
| 601 | + |
| 602 | +def test_tensor_filter_iterate_weights_hides_skipped( |
| 603 | + fstcpp_log, input_files, framework |
| 604 | +): |
| 605 | + device, _ = get_and_check_device(framework) |
| 606 | + meta = SafeTensorsMetadata.from_file(input_files[0], framework) |
| 607 | + |
| 608 | + kept = set(sorted(meta.tensors.keys())[::2]) |
| 609 | + keep = lambda name: name in kept # noqa: E731 |
| 610 | + |
| 611 | + loader = ParallelLoader( |
| 612 | + pg=SingleGroup(), |
| 613 | + hf_weights_files=[input_files[0]], |
| 614 | + device=device.as_str(), |
| 615 | + nogds=True, |
| 616 | + framework=framework.get_name(), |
| 617 | + tensor_filter=keep, |
| 618 | + all_local=True, |
| 619 | + ) |
| 620 | + yielded = {key for key, _t in loader.iterate_weights()} |
| 621 | + assert yielded == kept |
| 622 | + |
| 623 | + loader.close() |
| 624 | + assert framework.get_mem_used() == 0 |
| 625 | + |
| 626 | + |
563 | 627 | def test_fastsafe_open(fstcpp_log, input_files, framework) -> None: |
564 | 628 | device, _ = get_and_check_device(framework) |
565 | 629 |
|
|
0 commit comments