Skip to content

Commit fa1da14

Browse files
Lizhe Jifacebook-github-bot
authored andcommitted
Add SSD backend integration and metadata API to DramKVEmbeddingCache
Summary: X-link: facebookresearch/FBGEMM#2847 Integrates `DramKVEmbeddingCache` with an SSD backend by exposing metadata retrieval APIs and internal state accessors. This enables the SSD tier to track dirty memory blocks for flushing, manage cross-tier feature eviction, and allows the enrichment process to skip redundant external data source fetches for IDs already present in SSD. Differential Revision: D108959007
1 parent 0b95945 commit fa1da14

2 files changed

Lines changed: 321 additions & 0 deletions

File tree

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
#include <thrift/lib/cpp2/protocol/Serializer.h>
2424
#include <torch/script.h>
2525
#include <cmath>
26+
#include <cstring>
2627
#include <random>
2728
#include <string_view>
29+
#include <unordered_set>
2830
#include "common/time/Time.h"
2931

3032
#include "../ssd_split_embeddings_cache/initializer.h"
@@ -369,6 +371,62 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
369371
return metadata_tensor;
370372
}
371373

374+
at::Tensor get_kv_metadata_rows(
375+
const at::Tensor& indices,
376+
const at::Tensor& count) {
377+
auto numel = indices.size(0);
378+
const int64_t metadata_dim =
379+
static_cast<int64_t>(FixedBlockPool::get_metaheader_dim<weight_type>());
380+
auto metadata_tensor = at::zeros(
381+
{numel, metadata_dim},
382+
at::TensorOptions().dtype(
383+
c10::CppTypeToScalarType<weight_type>::value));
384+
auto shardid_to_indexes = shard_input(indices, count);
385+
std::vector<folly::Future<folly::Unit>> futures;
386+
futures.reserve(shardid_to_indexes.size());
387+
const size_t metadata_bytes = metadata_dim * sizeof(weight_type);
388+
for (const auto& [shard_id, indexes] : shardid_to_indexes) {
389+
futures.emplace_back(
390+
folly::via(executor_.get())
391+
.thenValue([this,
392+
shard_id,
393+
indexes,
394+
&indices,
395+
&metadata_tensor,
396+
metadata_bytes](folly::Unit) {
397+
FBGEMM_DISPATCH_INTEGRAL_TYPES(
398+
indices.scalar_type(),
399+
"dram_kv_metadata_rows",
400+
[this,
401+
shard_id,
402+
indexes,
403+
&indices,
404+
&metadata_tensor,
405+
metadata_bytes] {
406+
using index_t = scalar_t;
407+
CHECK(indices.is_contiguous());
408+
auto* idx_ptr = indices.const_data_ptr<index_t>();
409+
auto* md_ptr =
410+
metadata_tensor
411+
.template mutable_data_ptr<weight_type>();
412+
const int64_t md_stride = metadata_tensor.size(1);
413+
auto rlmap = kv_store_.by(shard_id).rlock();
414+
for (const auto& id_index : indexes) {
415+
auto id = int64_t(idx_ptr[id_index]);
416+
auto it = rlmap->find(id);
417+
CHECK(it != rlmap->end());
418+
std::memcpy(
419+
md_ptr + id_index * md_stride,
420+
reinterpret_cast<const char*>(it->second),
421+
metadata_bytes);
422+
}
423+
});
424+
}));
425+
}
426+
folly::collect(futures).wait();
427+
return metadata_tensor;
428+
}
429+
372430
/// insert embeddings into kvstore.
373431
/// current underlying memory management is done through F14FastMap
374432
/// key value pair will be sharded into multiple shards to increase
@@ -488,6 +546,10 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
488546
weights_data_ptr + id_index * stride,
489547
weights_data_ptr + (id_index + 1) * stride,
490548
data_ptr);
549+
// TODO: skip FixedBlockPool set_dirty here. This
550+
// DRAM_SSD embedding cache path only handles
551+
// backfill, where data already exists in SSD, so
552+
// marking dirty would trigger a redundant flush.
491553
local_write_cache_copy_total_duration +=
492554
facebook::WallClockUtil::NowInUsecFast() -
493555
before_copy_ts;
@@ -635,6 +697,9 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
635697
weights_data_ptr + id_index * stride,
636698
weights_data_ptr + (id_index + 1) * stride,
637699
data_ptr);
700+
if (enable_ssd_backend_) {
701+
pool->set_dirty(block, true);
702+
}
638703
cursor++;
639704
// Check if we should pause and yield lock
640705
if (is_laser_write_interrupted()) {
@@ -735,6 +800,12 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
735800
weights_data_ptr + tensor_offset * stride,
736801
weights_data_ptr + (tensor_offset + 1) * stride,
737802
data_ptr);
803+
804+
// TODO: skip FixedBlockPool set_dirty here. This
805+
// DRAM_SSD embedding cache path only handles
806+
// backfill, where data already exists in SSD, so
807+
// marking dirty would trigger a redundant flush.
808+
738809
// update provided ts for existing blocks
739810
if (feature_evict_config_.has_value() &&
740811
feature_evict_config_.value()->trigger_mode_ !=
@@ -765,6 +836,11 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
765836
weights_data_ptr + (tensor_offset + 1) * stride,
766837
data_ptr);
767838

839+
// TODO: skip FixedBlockPool set_dirty here. This
840+
// DRAM_SSD embedding cache path only handles
841+
// backfill, where data already exists in SSD, so
842+
// marking dirty would trigger a redundant flush.
843+
768844
// update provided ts for new allocated blocks
769845
if (feature_evict_config_.has_value() &&
770846
feature_evict_config_.value()->trigger_mode_ !=
@@ -1781,6 +1857,10 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
17811857

17821858
void compact() override {}
17831859

1860+
void set_ssd_backend() {
1861+
enable_ssd_backend_ = true;
1862+
}
1863+
17841864
void trigger_feature_evict() {
17851865
if (feature_evict_) {
17861866
feature_evict_->trigger_evict();
@@ -1945,6 +2025,25 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
19452025
backend_return_whole_row_ = backend_return_whole_row;
19462026
}
19472027

2028+
/// Get the feature evict object for callback wiring.
2029+
/// Returns nullptr if feature eviction is disabled.
2030+
FeatureEvict<weight_type>* get_feature_evict() {
2031+
return feature_evict_.get();
2032+
}
2033+
2034+
/// Access the internal kv_store for flush iteration.
2035+
auto& get_kv_store() {
2036+
return kv_store_;
2037+
}
2038+
2039+
int64_t get_num_shards() const {
2040+
return num_shards_;
2041+
}
2042+
2043+
int64_t get_block_size() const {
2044+
return block_size_;
2045+
}
2046+
19482047
private:
19492048
int64_t get_dim_from_index(int64_t weight_idx) const {
19502049
if (sub_table_dims_.empty()) {
@@ -2378,6 +2477,9 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
23782477
weights_data_ptr + id_index * stride,
23792478
weights_data_ptr + (id_index + 1) * stride,
23802479
block);
2480+
if (enable_ssd_backend_) {
2481+
pool->set_dirty(block, true);
2482+
}
23812483

23822484
if (new_block) {
23832485
if (feature_evict_config_.has_value() &&
@@ -2501,6 +2603,11 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
25012603

25022604
// OpenTab/Maple reader for ONEFLOW_OPENTAB_SID enrichment (type-erased)
25032605
oneflow_enrichment::ReaderPtr open_tab_reader_;
2606+
2607+
// Optional SSD backend for existence checks during enrichment.
2608+
// When set, enrichment will skip IDs that already exist in SSD,
2609+
// avoiding unnecessary calls to external data sources.
2610+
std::atomic<bool> enable_ssd_backend_{false};
25042611
}; // class DramKVEmbeddingCache
25052612

25062613
} // namespace kv_mem
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "deeplearning/fbgemm/fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h"
10+
11+
#include <fmt/format.h>
12+
#include <glog/logging.h>
13+
#include <gtest/gtest.h>
14+
#include <vector>
15+
16+
namespace kv_mem {
17+
18+
struct MetaHeader {
19+
int64_t key;
20+
uint32_t timestamp;
21+
uint32_t count : 31;
22+
bool used : 1;
23+
};
24+
25+
class DramKVEmbeddingCacheTest : public ::testing::Test {
26+
protected:
27+
static constexpr int EMBEDDING_DIM = 16;
28+
static constexpr int NUM_SHARDS = 4;
29+
30+
void SetUp() override {
31+
FLAGS_logtostderr = true;
32+
FLAGS_minloglevel = 0;
33+
34+
auto hash_size_cumsum = at::tensor({0, 100000}, at::kLong);
35+
36+
dram_cache_ = std::make_shared<DramKVEmbeddingCache<float>>(
37+
EMBEDDING_DIM,
38+
/*uniform_init_lower=*/-0.1,
39+
/*uniform_init_upper=*/0.1,
40+
/*feature_evict_config=*/std::nullopt,
41+
NUM_SHARDS,
42+
/*num_threads=*/4,
43+
/*row_storage_bitwidth=*/32,
44+
/*backend_return_whole_row=*/false,
45+
/*enable_async_update=*/false,
46+
/*table_dims=*/std::nullopt,
47+
hash_size_cumsum,
48+
/*is_training=*/false,
49+
/*disable_random_init=*/true);
50+
}
51+
52+
void TearDown() override {
53+
dram_cache_.reset();
54+
}
55+
56+
void insertEmbedding(int64_t id, float value = 1.0f) {
57+
auto indices = at::tensor({id}, at::kLong);
58+
std::vector<float> emb(EMBEDDING_DIM, value);
59+
auto weights = at::from_blob(
60+
emb.data(), {1, EMBEDDING_DIM}, at::TensorOptions().dtype(at::kFloat));
61+
auto count = at::tensor({1}, at::kLong);
62+
folly::coro::blockingWait(
63+
dram_cache_->set_kv_db_async(indices, weights.clone(), count));
64+
}
65+
66+
void insertEmbeddings(const std::vector<int64_t>& ids, float value = 1.0f) {
67+
auto num = static_cast<int64_t>(ids.size());
68+
auto indices = at::tensor(ids, at::kLong);
69+
auto weights = at::full(
70+
{num, EMBEDDING_DIM}, value, at::TensorOptions().dtype(at::kFloat));
71+
auto count = at::tensor({num}, at::kLong);
72+
folly::coro::blockingWait(
73+
dram_cache_->set_kv_db_async(indices, weights, count));
74+
}
75+
76+
std::shared_ptr<DramKVEmbeddingCache<float>> dram_cache_;
77+
};
78+
79+
// Test: get_kv_metadata_rows returns correct shape and key for single inserted
80+
// id
81+
TEST_F(DramKVEmbeddingCacheTest, SingleKeyMetadata) {
82+
const int64_t test_id = 42;
83+
insertEmbedding(test_id, 2.5f);
84+
85+
auto indices = at::tensor({test_id}, at::kLong);
86+
auto count = at::tensor({1}, at::kLong);
87+
auto metadata = dram_cache_->get_kv_metadata_rows(indices, count);
88+
89+
const int64_t expected_dim =
90+
static_cast<int64_t>(FixedBlockPool::get_metaheader_dim<float>());
91+
EXPECT_EQ(metadata.dim(), 2);
92+
EXPECT_EQ(metadata.size(0), 1);
93+
EXPECT_EQ(metadata.size(1), expected_dim);
94+
EXPECT_EQ(metadata.dtype(), at::kFloat);
95+
static_assert(sizeof(MetaHeader) == 16, "MetaHeader must be 16 bytes");
96+
97+
MetaHeader header{};
98+
std::memcpy(&header, metadata.data_ptr<float>(), sizeof(MetaHeader));
99+
100+
EXPECT_EQ(header.key, test_id);
101+
EXPECT_TRUE(header.used);
102+
EXPECT_GT(header.timestamp, 0u);
103+
// count may be 0 initially or updated depending on implementation
104+
EXPECT_GE(header.count, 0u);
105+
}
106+
107+
// Test: get_kv_metadata_rows returns correct metadata for multiple keys across
108+
// shards
109+
TEST_F(DramKVEmbeddingCacheTest, MultipleKeysMetadata) {
110+
std::vector<int64_t> keys = {1, 2, 3, 10, 100, 1000};
111+
insertEmbeddings(keys, 1.0f);
112+
113+
auto indices = at::tensor(keys, at::kLong);
114+
auto count = at::tensor({static_cast<int64_t>(keys.size())}, at::kLong);
115+
auto metadata = dram_cache_->get_kv_metadata_rows(indices, count);
116+
117+
const int64_t expected_dim =
118+
static_cast<int64_t>(FixedBlockPool::get_metaheader_dim<float>());
119+
EXPECT_EQ(
120+
metadata.sizes(),
121+
at::IntArrayRef({static_cast<int64_t>(keys.size()), expected_dim}));
122+
123+
auto* md_ptr = metadata.data_ptr<float>();
124+
const int64_t stride = expected_dim;
125+
for (size_t i = 0; i < keys.size(); ++i) {
126+
MetaHeader header{};
127+
std::memcpy(&header, md_ptr + i * stride, sizeof(MetaHeader));
128+
EXPECT_EQ(header.key, keys[i]) << "Mismatch at index " << i;
129+
EXPECT_TRUE(header.used) << "Used flag false for key " << keys[i];
130+
EXPECT_GT(header.timestamp, 0u) << "Timestamp not set for key " << keys[i];
131+
}
132+
}
133+
134+
// Test: get_kv_metadata_rows with empty input returns empty tensor with correct
135+
// dim
136+
TEST_F(DramKVEmbeddingCacheTest, EmptyInputReturnsEmpty) {
137+
auto indices = at::empty({0}, at::kLong);
138+
auto count = at::tensor({0}, at::kLong);
139+
auto metadata = dram_cache_->get_kv_metadata_rows(indices, count);
140+
141+
const int64_t expected_dim =
142+
static_cast<int64_t>(FixedBlockPool::get_metaheader_dim<float>());
143+
EXPECT_EQ(metadata.dim(), 2);
144+
EXPECT_EQ(metadata.size(0), 0);
145+
EXPECT_EQ(metadata.size(1), expected_dim);
146+
}
147+
148+
// Test: get_kv_metadata_rows reflects updated timestamp after re-insert
149+
TEST_F(DramKVEmbeddingCacheTest, TimestampUpdatesOnReinsert) {
150+
const int64_t test_id = 7;
151+
insertEmbedding(test_id, 1.0f);
152+
153+
auto indices = at::tensor({test_id}, at::kLong);
154+
auto count = at::tensor({1}, at::kLong);
155+
auto metadata1 = dram_cache_->get_kv_metadata_rows(indices, count);
156+
MetaHeader h1{};
157+
std::memcpy(&h1, metadata1.data_ptr<float>(), sizeof(MetaHeader));
158+
159+
// Sleep to ensure timestamp advances (timestamp is in seconds)
160+
std::this_thread::sleep_for(std::chrono::seconds(2));
161+
162+
// Re-insert same key to update timestamp
163+
insertEmbedding(test_id, 3.0f);
164+
auto metadata2 = dram_cache_->get_kv_metadata_rows(indices, count);
165+
MetaHeader h2{};
166+
std::memcpy(&h2, metadata2.data_ptr<float>(), sizeof(MetaHeader));
167+
168+
EXPECT_EQ(h2.key, test_id);
169+
EXPECT_TRUE(h2.used);
170+
EXPECT_GE(h2.timestamp, h1.timestamp);
171+
}
172+
173+
// Test: get_kv_metadata_rows works with float16 weight type via separate cache
174+
// instance
175+
TEST_F(DramKVEmbeddingCacheTest, HalfPrecisionMetadataDim) {
176+
auto hash_size_cumsum = at::tensor({0, 100000}, at::kLong);
177+
auto dram_cache_half = std::make_shared<DramKVEmbeddingCache<at::Half>>(
178+
EMBEDDING_DIM,
179+
-0.1,
180+
0.1,
181+
std::nullopt,
182+
NUM_SHARDS,
183+
4,
184+
16,
185+
false,
186+
false,
187+
std::nullopt,
188+
hash_size_cumsum,
189+
false,
190+
true);
191+
192+
// Insert one key
193+
auto indices = at::tensor({5}, at::kLong);
194+
auto weights =
195+
at::full({1, EMBEDDING_DIM}, 1.0, at::TensorOptions().dtype(at::kHalf));
196+
auto count = at::tensor({1}, at::kLong);
197+
folly::coro::blockingWait(
198+
dram_cache_half->set_kv_db_async(indices, weights, count));
199+
200+
auto metadata = dram_cache_half->get_kv_metadata_rows(indices, count);
201+
const int64_t expected_dim =
202+
static_cast<int64_t>(FixedBlockPool::get_metaheader_dim<at::Half>());
203+
// 16 bytes / 2 bytes per half = 8
204+
EXPECT_EQ(expected_dim, 8);
205+
EXPECT_EQ(metadata.sizes(), at::IntArrayRef({1, expected_dim}));
206+
EXPECT_EQ(metadata.dtype(), at::kHalf);
207+
208+
// Decode first 8 bytes as int64 key from half tensor raw bytes
209+
int64_t decoded_key = 0;
210+
std::memcpy(&decoded_key, metadata.data_ptr<at::Half>(), sizeof(int64_t));
211+
EXPECT_EQ(decoded_key, 5);
212+
}
213+
214+
} // namespace kv_mem

0 commit comments

Comments
 (0)