Skip to content

Commit ffdfb0e

Browse files
Joey Yang and meta-codesync[bot]
authored and committed
Support multi-dimensional runtime_meta in RES streaming buffers by lazy init (pytorch#5643)
Summary: Pull Request resolved: pytorch#5643 X-link: https://github.com/facebookresearch/FBGEMM/pull/2591 The `res_runtime_meta` buffer in `SplitTableBatchedEmbeddingBagsCodegen` was hardcoded to shape `(cache_size, 1)`. When `_hash_zch_runtime_meta` has dim > 1 (e.g., a feature cache storing 2 cached features via `zch_custom_runtime_meta_dim=2`), the `.copy_()` in `raw_embedding_stream()` crashes with: `RuntimeError: output with shape [N, 1] doesn't match the broadcast shape [N, 2]` Full output: P2274437578. This diff resizes the buffer lazily: it still defaults to shape `(cache_size, 1)` with dtype `torch.long`, and is auto-corrected on the first iteration in which runtime_meta data arrives with a different shape or dtype. This is a one-time operation: after the first resize, the dims match and no further reallocation occurs. Reviewed By: chouxi Differential Revision: D100944325 fbshipit-source-id: 06157e6a7b2a72c4ced847bc9f322e1222924524
1 parent 505cb4d commit ffdfb0e

3 files changed

Lines changed: 258 additions & 6 deletions

File tree

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 35 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1544,8 +1544,8 @@ def _register_res_buffers(self) -> None:
15441544
self.enable_raw_embedding_streaming
15451545
), "Should not register res buffers when raw embedding streaming is not enabled"
15461546
cache_size = self.lxu_cache_weights.size(0)
1547+
self.log(f"[RES] registering buffers: cache_size={cache_size}")
15471548
if cache_size == 0:
1548-
self.log("Registering empty res buffers when there is no cache")
15491549
self._register_empty_res_buffers()
15501550
return
15511551
self.register_buffer(
@@ -1596,6 +1596,7 @@ def _register_res_buffers(self) -> None:
15961596
(cache_size, 1),
15971597
is_host_mapped=self.uvm_host_mapped,
15981598
),
1599+
persistent=False, # shape may change via lazy resize, exclude from checkpoints
15991600
)
16001601
self.register_buffer(
16011602
"res_count",
@@ -1645,6 +1646,7 @@ def _register_empty_res_buffers(self) -> None:
16451646
self.register_buffer(
16461647
"res_runtime_meta",
16471648
torch.zeros(0, 1, device=self.current_device, dtype=torch.long),
1649+
persistent=False, # shape may change via lazy resize, exclude from checkpoints
16481650
)
16491651
self.register_buffer(
16501652
"res_count",
@@ -4404,10 +4406,38 @@ def raw_embedding_stream(self) -> None:
44044406
: prefetched_info.hash_zch_identities.size(0)
44054407
].copy_(prefetched_info.hash_zch_identities)
44064408
if prefetched_info.hash_zch_runtime_meta is not None:
4407-
# pyre-ignore[29]: `Union[...]` is not a function.
4408-
self.res_runtime_meta[
4409-
: prefetched_info.hash_zch_runtime_meta.size(0)
4410-
].copy_(prefetched_info.hash_zch_runtime_meta)
4409+
runtime_meta = prefetched_info.hash_zch_runtime_meta
4410+
if runtime_meta.dim() != 2:
4411+
self.log(
4412+
f"[RES] unexpected runtime_meta rank: {runtime_meta.dim()}, expected 2, skipping"
4413+
)
4414+
else:
4415+
if (
4416+
runtime_meta.shape[1] != self.res_runtime_meta.shape[1]
4417+
or runtime_meta.dtype != self.res_runtime_meta.dtype
4418+
):
4419+
self.log(
4420+
f"[RES] lazy resize runtime_meta: {self.res_runtime_meta.shape} -> ({self.res_runtime_meta.shape[0]}, {runtime_meta.shape[1]}), dtype {self.res_runtime_meta.dtype} -> {runtime_meta.dtype}"
4421+
)
4422+
# Lazy resize: runtime_meta shape/dtype is not known until
4423+
# the first data arrives from the MC module. Must use UVM
4424+
# (new_unified_tensor) because the C++ RawEmbeddingStreamer
4425+
# reads this buffer via raw CPU pointers in tensor_copy().
4426+
self.register_buffer(
4427+
"res_runtime_meta",
4428+
torch.ops.fbgemm.new_unified_tensor(
4429+
torch.zeros(
4430+
1,
4431+
device=self.current_device,
4432+
dtype=runtime_meta.dtype,
4433+
),
4434+
(self.res_runtime_meta.shape[0], runtime_meta.shape[1]),
4435+
is_host_mapped=self.uvm_host_mapped,
4436+
),
4437+
persistent=False, # shape may change via lazy resize, exclude from checkpoints
4438+
)
4439+
# pyre-ignore[29]: `Union[...]` is not a function.
4440+
self.res_runtime_meta[: runtime_meta.size(0)].copy_(runtime_meta)
44114441

44124442
self.res_copy_done.fill_(1)
44134443

fbgemm_gpu/src/split_embeddings_cache/raw_embedding_streamer.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,7 @@ fbgemm_gpu::StreamQueueItem tensor_copy(
122122
});
123123
}
124124
if (runtime_meta.has_value()) {
125-
FBGEMM_DISPATCH_INTEGRAL_TYPES(
125+
FBGEMM_DISPATCH_ALL_TYPES(
126126
runtime_meta->scalar_type(), "tensor_copy", [&] {
127127
using runtime_meta_t = scalar_t;
128128
auto runtime_meta_addr =

fbgemm_gpu/test/tbe/training/store_prefetched_tensors_test.py

Lines changed: 222 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,15 @@
88
# pyre-strict
99

1010
import unittest
11+
from unittest.mock import patch
1112

1213
import torch
1314
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
1415
ComputeDevice,
1516
EmbeddingLocation,
1617
)
1718
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
19+
RESParams,
1820
SplitTableBatchedEmbeddingBagsCodegen,
1921
)
2022

@@ -627,6 +629,226 @@ def test_get_prefetched_info_with_neither(self) -> None:
627629
self.assertIsNone(prefetched_info.hash_zch_identities)
628630
self.assertIsNone(prefetched_info.hash_zch_runtime_meta)
629631

632+
@unittest.skipIf(*gpu_unavailable)
633+
def test_register_res_buffers_default_dim(self) -> None:
634+
"""
635+
Test that RES buffers are registered with default dim=1.
636+
"""
637+
res_params = RESParams(
638+
res_store_shards=1,
639+
table_names=["table_0"],
640+
table_offsets=[0, 100],
641+
table_sizes=[100],
642+
)
643+
with patch(
644+
"fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"
645+
):
646+
tbe = SplitTableBatchedEmbeddingBagsCodegen(
647+
embedding_specs=[
648+
(100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
649+
],
650+
enable_raw_embedding_streaming=True,
651+
res_params=res_params,
652+
)
653+
cache_size = tbe.lxu_cache_weights.size(0)
654+
self.assertGreater(cache_size, 0)
655+
self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 1))
656+
657+
@unittest.skipIf(*gpu_unavailable)
658+
def test_register_empty_res_buffers_default_dim(self) -> None:
659+
"""
660+
Test that empty RES buffers have dim=1 when streaming is disabled.
661+
"""
662+
tbe = SplitTableBatchedEmbeddingBagsCodegen(
663+
embedding_specs=[
664+
(100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
665+
],
666+
enable_raw_embedding_streaming=False,
667+
)
668+
self.assertEqual(tbe.res_runtime_meta.shape[1], 1)
669+
670+
@unittest.skipIf(*gpu_unavailable)
671+
def test_lazy_resize_runtime_meta(self) -> None:
672+
"""
673+
Test that lazy resize in raw_embedding_stream() resizes res_runtime_meta
674+
buffer when actual data has a different dim or dtype than the default.
675+
"""
676+
res_params = RESParams(
677+
res_store_shards=1,
678+
table_names=["table_0"],
679+
table_offsets=[0, 100],
680+
table_sizes=[100],
681+
)
682+
with patch(
683+
"fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"
684+
):
685+
tbe = SplitTableBatchedEmbeddingBagsCodegen(
686+
embedding_specs=[
687+
(100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
688+
],
689+
enable_raw_embedding_streaming=True,
690+
res_params=res_params,
691+
)
692+
cache_size = tbe.lxu_cache_weights.size(0)
693+
# Initially dim=1
694+
self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 1))
695+
696+
# Simulate runtime_meta with dim=2 arriving via prefetch
697+
n = 4
698+
runtime_meta_data = torch.tensor(
699+
[[1, 10], [2, 20], [3, 30], [4, 40]],
700+
device=torch.cuda.current_device(),
701+
dtype=torch.int64,
702+
)
703+
704+
# Manually trigger the resize logic
705+
data = runtime_meta_data
706+
if (
707+
data.shape[1] != tbe.res_runtime_meta.shape[1]
708+
or data.dtype != tbe.res_runtime_meta.dtype
709+
):
710+
tbe.register_buffer(
711+
"res_runtime_meta",
712+
torch.ops.fbgemm.new_unified_tensor(
713+
torch.zeros(1, device=tbe.current_device, dtype=data.dtype),
714+
(tbe.res_runtime_meta.shape[0], data.shape[1]),
715+
is_host_mapped=tbe.uvm_host_mapped,
716+
),
717+
persistent=False,
718+
)
719+
720+
# After resize, dim should be 2
721+
self.assertEqual(tbe.res_runtime_meta.shape, (cache_size, 2))
722+
# Copy should succeed
723+
tbe.res_runtime_meta[:n].copy_(runtime_meta_data)
724+
self.assertEqual(
725+
runtime_meta_data.tolist(),
726+
tbe.res_runtime_meta[:n].tolist(),
727+
)
728+
729+
@unittest.skipIf(*gpu_unavailable)
730+
def test_res_runtime_meta_not_in_state_dict(self) -> None:
731+
"""
732+
Test that res_runtime_meta is registered with persistent=False and
733+
does not appear in state_dict() (shape changes with runtime_meta_dim).
734+
"""
735+
res_params = RESParams(
736+
res_store_shards=1,
737+
table_names=["table_0"],
738+
table_offsets=[0, 100],
739+
table_sizes=[100],
740+
)
741+
with patch(
742+
"fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"
743+
):
744+
tbe = SplitTableBatchedEmbeddingBagsCodegen(
745+
embedding_specs=[
746+
(100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
747+
],
748+
enable_raw_embedding_streaming=True,
749+
res_params=res_params,
750+
)
751+
state_dict = tbe.state_dict()
752+
self.assertNotIn(
753+
"res_runtime_meta",
754+
state_dict,
755+
"res_runtime_meta should not be in state_dict",
756+
)
757+
758+
@unittest.skipIf(*gpu_unavailable)
759+
def test_prefetched_info_with_multi_dim_runtime_meta(self) -> None:
760+
"""
761+
Test that _get_prefetched_info preserves multi-dimensional runtime_meta.
762+
When runtime_meta has shape [N, 2], output should also have dim=2.
763+
"""
764+
hash_zch_runtime_meta = torch.tensor(
765+
[
766+
[1, 10],
767+
[2, 20],
768+
[3, 30],
769+
[4, 40],
770+
],
771+
device=torch.cuda.current_device(),
772+
dtype=torch.int64,
773+
)
774+
total_cache_hash_size = 100
775+
linear_cache_indices_merged = torch.tensor(
776+
[54, 27, 43, 90],
777+
device=torch.cuda.current_device(),
778+
dtype=torch.int64,
779+
)
780+
781+
prefetched_info = SplitTableBatchedEmbeddingBagsCodegen._get_prefetched_info(
782+
linear_indices=linear_cache_indices_merged,
783+
linear_cache_indices_merged=linear_cache_indices_merged,
784+
total_cache_hash_size=total_cache_hash_size,
785+
hash_zch_identities=None,
786+
hash_zch_runtime_meta=hash_zch_runtime_meta,
787+
max_indices_length=200,
788+
)
789+
790+
assert prefetched_info.hash_zch_runtime_meta is not None
791+
self.assertEqual(prefetched_info.hash_zch_runtime_meta.shape[1], 2)
792+
self.assertEqual(prefetched_info.hash_zch_runtime_meta.shape[0], 4)
793+
# Verify sorted order (by cache index: 27, 43, 54, 90)
794+
self.assertEqual(
795+
[
796+
[2, 20], # runtime meta for index 27
797+
[3, 30], # runtime meta for index 43
798+
[1, 10], # runtime meta for index 54
799+
[4, 40], # runtime meta for index 90
800+
],
801+
prefetched_info.hash_zch_runtime_meta.tolist(),
802+
)
803+
804+
@unittest.skipIf(*gpu_unavailable)
805+
def test_copy_runtime_meta_none_skipped(self) -> None:
806+
"""
807+
Test that when hash_zch_runtime_meta is None in prefetched_info,
808+
the copy to res_runtime_meta is skipped without crashing.
809+
"""
810+
res_params = RESParams(
811+
res_store_shards=1,
812+
table_names=["table_0"],
813+
table_offsets=[0, 100],
814+
table_sizes=[100],
815+
)
816+
with patch(
817+
"fbgemm_gpu.split_table_batched_embeddings_ops_training.torch.classes.fbgemm.RawEmbeddingStreamer"
818+
):
819+
tbe = SplitTableBatchedEmbeddingBagsCodegen(
820+
embedding_specs=[
821+
(100, 16, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA),
822+
],
823+
enable_raw_embedding_streaming=True,
824+
res_params=res_params,
825+
)
826+
827+
# Store a prefetched_info with runtime_meta=None
828+
indices = torch.tensor(
829+
[1, 2, 3], device=torch.cuda.current_device(), dtype=torch.int64
830+
)
831+
offsets = torch.tensor(
832+
[0, 3], device=torch.cuda.current_device(), dtype=torch.int64
833+
)
834+
linear_cache_indices_merged = torch.tensor(
835+
[1, 2, 3], device=torch.cuda.current_device(), dtype=torch.int64
836+
)
837+
838+
# This should not crash even though runtime_meta is None
839+
tbe._store_prefetched_tensors(
840+
indices=indices,
841+
offsets=offsets,
842+
vbe_metadata=None,
843+
linear_cache_indices_merged=linear_cache_indices_merged,
844+
final_lxu_cache_locations=torch.ones_like(linear_cache_indices_merged),
845+
hash_zch_identities=None,
846+
hash_zch_runtime_meta=None,
847+
)
848+
849+
self.assertEqual(len(tbe.prefetched_info_list), 1)
850+
self.assertIsNone(tbe.prefetched_info_list[0].hash_zch_runtime_meta)
851+
630852

631853
if __name__ == "__main__":
632854
unittest.main()

0 commit comments

Comments
 (0)