|
8 | 8 | # pyre-strict |
9 | 9 |
|
10 | 10 | import unittest |
| 11 | +from unittest.mock import patch |
11 | 12 |
|
12 | 13 | import torch |
13 | 14 | from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( |
14 | 15 | ComputeDevice, |
15 | 16 | EmbeddingLocation, |
16 | 17 | ) |
17 | 18 | from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( |
| 19 | + RESParams, |
18 | 20 | SplitTableBatchedEmbeddingBagsCodegen, |
19 | 21 | ) |
20 | 22 |
|
@@ -627,6 +629,226 @@ def test_get_prefetched_info_with_neither(self) -> None: |
627 | 629 | self.assertIsNone(prefetched_info.hash_zch_identities) |
628 | 630 | self.assertIsNone(prefetched_info.hash_zch_runtime_meta) |
629 | 631 |
|
| 632 | + @unittest.skipIf(*gpu_unavailable) |
| 633 | + def test_register_res_buffers_default_dim(self) -> None: |
| 634 | + """ |
| 635 | + Test that RES buffers are registered with default dim=1. |
| 636 | + """ |
| 637 | + res_params = RESParams( |
| 638 | + res_store_shards=1, |
| 639 | + table_names=["table_0"], |
| 640 | + table_offsets=[0, 100], |
| 641 | + table_sizes=[100], |
| 642 | + ) |
| 643 | + with patch( |
| 644 | + "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer" |
| 645 | + ): |
| 646 | + tbe = SplitTableBatchedEmbeddingBagsCodegen( |
| 647 | + embedding_specs=[ |
| 648 | + (100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA), |
| 649 | + ], |
| 650 | + enable_raw_embedding_streaming=True, |
| 651 | + res_params=res_params, |
| 652 | + ) |
| 653 | + cache_size = tbe.lxu_cache_weights.size(0) |
| 654 | + self.assertGreater(cache_size, 0) |
| 655 | + self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 1)) |
| 656 | + |
| 657 | + @unittest.skipIf(*gpu_unavailable) |
| 658 | + def test_register_empty_res_buffers_default_dim(self) -> None: |
| 659 | + """ |
| 660 | + Test that empty RES buffers have dim=1 when streaming is disabled. |
| 661 | + """ |
| 662 | + tbe = SplitTableBatchedEmbeddingBagsCodegen( |
| 663 | + embedding_specs=[ |
| 664 | + (100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA), |
| 665 | + ], |
| 666 | + enable_raw_embedding_streaming=False, |
| 667 | + ) |
| 668 | + self.assertEqual(tbe.res_runtime_meta.shape[1], 1) |
| 669 | + |
| 670 | + @unittest.skipIf(*gpu_unavailable) |
| 671 | + def test_lazy_resize_runtime_meta(self) -> None: |
| 672 | + """ |
| 673 | + Test that lazy resize in raw_embedding_stream() resizes res_runtime_meta |
| 674 | + buffer when actual data has a different dim or dtype than the default. |
| 675 | + """ |
| 676 | + res_params = RESParams( |
| 677 | + res_store_shards=1, |
| 678 | + table_names=["table_0"], |
| 679 | + table_offsets=[0, 100], |
| 680 | + table_sizes=[100], |
| 681 | + ) |
| 682 | + with patch( |
| 683 | + "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer" |
| 684 | + ): |
| 685 | + tbe = SplitTableBatchedEmbeddingBagsCodegen( |
| 686 | + embedding_specs=[ |
| 687 | + (100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA), |
| 688 | + ], |
| 689 | + enable_raw_embedding_streaming=True, |
| 690 | + res_params=res_params, |
| 691 | + ) |
| 692 | + cache_size = tbe.lxu_cache_weights.size(0) |
| 693 | + # Initially dim=1 |
| 694 | + self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 1)) |
| 695 | + |
| 696 | + # Simulate runtime_meta with dim=2 arriving via prefetch |
| 697 | + n = 4 |
| 698 | + runtime_meta_data = torch.tensor( |
| 699 | + [[1, 10], [2, 20], [3, 30], [4, 40]], |
| 700 | + device=torch.cuda.current_device(), |
| 701 | + dtype=torch.int64, |
| 702 | + ) |
| 703 | + |
| 704 | + # Manually trigger the resize logic |
| 705 | + data = runtime_meta_data |
| 706 | + if ( |
| 707 | + data.shape[1] != tbe.res_runtime_meta.shape[1] |
| 708 | + or data.dtype != tbe.res_runtime_meta.dtype |
| 709 | + ): |
| 710 | + tbe.register_buffer( |
| 711 | + "res_runtime_meta", |
| 712 | + torch.ops.fbgemm.new_unified_tensor( |
| 713 | + torch.zeros(1, device=tbe.current_device, dtype=data.dtype), |
| 714 | + (tbe.res_runtime_meta.shape[0], data.shape[1]), |
| 715 | + is_host_mapped=tbe.uvm_host_mapped, |
| 716 | + ), |
| 717 | + persistent=False, |
| 718 | + ) |
| 719 | + |
| 720 | + # After resize, dim should be 2 |
| 721 | + self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 2)) |
| 722 | + # Copy should succeed |
| 723 | + tbe.res_runtime_meta[:n].copy_(runtime_meta_data) |
| 724 | + self.assertEqual( |
| 725 | + runtime_meta_data.tolist(), |
| 726 | + tbe.res_runtime_meta[:n].tolist(), |
| 727 | + ) |
| 728 | + |
| 729 | + @unittest.skipIf(*gpu_unavailable) |
| 730 | + def test_res_runtime_meta_not_in_state_dict(self) -> None: |
| 731 | + """ |
| 732 | + Test that res_runtime_meta is registered with persistent=False and |
| 733 | + does not appear in state_dict() (shape changes with runtime_meta_dim). |
| 734 | + """ |
| 735 | + res_params = RESParams( |
| 736 | + res_store_shards=1, |
| 737 | + table_names=["table_0"], |
| 738 | + table_offsets=[0, 100], |
| 739 | + table_sizes=[100], |
| 740 | + ) |
| 741 | + with patch( |
| 742 | + "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer" |
| 743 | + ): |
| 744 | + tbe = SplitTableBatchedEmbeddingBagsCodegen( |
| 745 | + embedding_specs=[ |
| 746 | + (100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA), |
| 747 | + ], |
| 748 | + enable_raw_embedding_streaming=True, |
| 749 | + res_params=res_params, |
| 750 | + ) |
| 751 | + state_dict = tbe.state_dict() |
| 752 | + self.assertNotIn( |
| 753 | + "res_runtime_meta", |
| 754 | + state_dict, |
| 755 | + "res_runtime_meta should not be in state_dict", |
| 756 | + ) |
| 757 | + |
| 758 | + @unittest.skipIf(*gpu_unavailable) |
| 759 | + def test_prefetched_info_with_multi_dim_runtime_meta(self) -> None: |
| 760 | + """ |
| 761 | + Test that _get_prefetched_info preserves multi-dimensional runtime_meta. |
| 762 | + When runtime_meta has shape [N, 2], output should also have dim=2. |
| 763 | + """ |
| 764 | + hash_zch_runtime_meta = torch.tensor( |
| 765 | + [ |
| 766 | + [1, 10], |
| 767 | + [2, 20], |
| 768 | + [3, 30], |
| 769 | + [4, 40], |
| 770 | + ], |
| 771 | + device=torch.cuda.current_device(), |
| 772 | + dtype=torch.int64, |
| 773 | + ) |
| 774 | + total_cache_hash_size = 100 |
| 775 | + linear_cache_indices_merged = torch.tensor( |
| 776 | + [54, 27, 43, 90], |
| 777 | + device=torch.cuda.current_device(), |
| 778 | + dtype=torch.int64, |
| 779 | + ) |
| 780 | + |
| 781 | + prefetched_info = SplitTableBatchedEmbeddingBagsCodegen._get_prefetched_info( |
| 782 | + linear_indices=linear_cache_indices_merged, |
| 783 | + linear_cache_indices_merged=linear_cache_indices_merged, |
| 784 | + total_cache_hash_size=total_cache_hash_size, |
| 785 | + hash_zch_identities=None, |
| 786 | + hash_zch_runtime_meta=hash_zch_runtime_meta, |
| 787 | + max_indices_length=200, |
| 788 | + ) |
| 789 | + |
| 790 | + assert prefetched_info.hash_zch_runtime_meta is not None |
| 791 | + self.assertEqual(prefetched_info.hash_zch_runtime_meta.shape[1], 2) |
| 792 | + self.assertEqual(prefetched_info.hash_zch_runtime_meta.shape[0], 4) |
| 793 | + # Verify sorted order (by cache index: 27, 43, 54, 90) |
| 794 | + self.assertEqual( |
| 795 | + [ |
| 796 | + [2, 20], # runtime meta for index 27 |
| 797 | + [3, 30], # runtime meta for index 43 |
| 798 | + [1, 10], # runtime meta for index 54 |
| 799 | + [4, 40], # runtime meta for index 90 |
| 800 | + ], |
| 801 | + prefetched_info.hash_zch_runtime_meta.tolist(), |
| 802 | + ) |
| 803 | + |
| 804 | + @unittest.skipIf(*gpu_unavailable) |
| 805 | + def test_copy_runtime_meta_none_skipped(self) -> None: |
| 806 | + """ |
| 807 | + Test that when hash_zch_runtime_meta is None in prefetched_info, |
| 808 | + the copy to res_runtime_meta is skipped without crashing. |
| 809 | + """ |
| 810 | + res_params = RESParams( |
| 811 | + res_store_shards=1, |
| 812 | + table_names=["table_0"], |
| 813 | + table_offsets=[0, 100], |
| 814 | + table_sizes=[100], |
| 815 | + ) |
| 816 | + with patch( |
| 817 | + "fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer" |
| 818 | + ): |
| 819 | + tbe = SplitTableBatchedEmbeddingBagsCodegen( |
| 820 | + embedding_specs=[ |
| 821 | + (100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA), |
| 822 | + ], |
| 823 | + enable_raw_embedding_streaming=True, |
| 824 | + res_params=res_params, |
| 825 | + ) |
| 826 | + |
| 827 | + # Store a prefetched_info with runtime_meta=None |
| 828 | + indices = torch.tensor( |
| 829 | + [1, 2, 3], device=torch.cuda.current_device(), dtype=torch.int64 |
| 830 | + ) |
| 831 | + offsets = torch.tensor( |
| 832 | + [0, 3], device=torch.cuda.current_device(), dtype=torch.int64 |
| 833 | + ) |
| 834 | + linear_cache_indices_merged = torch.tensor( |
| 835 | + [1, 2, 3], device=torch.cuda.current_device(), dtype=torch.int64 |
| 836 | + ) |
| 837 | + |
| 838 | + # This should not crash even though runtime_meta is None |
| 839 | + tbe._store_prefetched_tensors( |
| 840 | + indices=indices, |
| 841 | + offsets=offsets, |
| 842 | + vbe_metadata=None, |
| 843 | + linear_cache_indices_merged=linear_cache_indices_merged, |
| 844 | + final_lxu_cache_locations=torch.ones_like(linear_cache_indices_merged), |
| 845 | + hash_zch_identities=None, |
| 846 | + hash_zch_runtime_meta=None, |
| 847 | + ) |
| 848 | + |
| 849 | + self.assertEqual(len(tbe.prefetched_info_list), 1) |
| 850 | + self.assertIsNone(tbe.prefetched_info_list[0].hash_zch_runtime_meta) |
| 851 | + |
630 | 852 |
|
631 | 853 | if __name__ == "__main__": |
632 | 854 | unittest.main() |
0 commit comments