Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,409 changes: 775 additions & 634 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,14 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
return impl_->is_evicting();
}

int64_t get_bg_thread_error_count() const {
return impl_->get_bg_thread_error_count();
}

std::string get_bg_thread_error_message() const {
return impl_->get_bg_thread_error_message();
}

void set_feature_score_metadata_cuda(
at::Tensor indices,
at::Tensor count,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,14 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
return impl_->is_evicting();
}

int64_t get_bg_thread_error_count() const {
return impl_->get_bg_thread_error_count();
}

std::string get_bg_thread_error_message() const {
return impl_->get_bg_thread_error_message();
}

private:
friend class KVTensorWrapper;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@

namespace kv_db {

// Queue depth threshold for warning logs. When the background write queue
// exceeds this depth, a rate-limited warning is emitted to indicate the
// consumer thread may be falling behind or dead.
// 1000 is chosen as a conservative threshold: typical queue depth during
// healthy operation is <10. At 1000 items, each holding a tensor copy of
// indices + weights, host DDR consumption from the queue alone approaches
// several GB, signaling a clear anomaly worth alerting on.
// NOTE(review): the "several GB" estimate depends on batch size and embedding
// dimension — confirm against real workloads before tuning this value.
constexpr int64_t kQueueDepthWarningThreshold = 1000;

namespace {

/// Read a scalar value from a tensor that is maybe a UVM tensor
Expand Down Expand Up @@ -132,8 +141,39 @@ EmbeddingKVDB::EmbeddingKVDB(
auto& count = filling_item_ptr->count;
auto& rocksdb_wmode = filling_item_ptr->mode;

update_cache_and_storage(indices, weights, count, rocksdb_wmode);
try {
update_cache_and_storage(indices, weights, count, rocksdb_wmode);
} catch (const std::exception& e) {
bg_thread_error_count_++;
{
// Store only the first error message for diagnostics.
// Subsequent errors are still logged via XLOG(ERR) below,
// but only the root cause (first failure) is surfaced via
// get_bg_thread_error_message() to avoid masking it.
std::lock_guard<std::mutex> lock(bg_error_mutex_);
if (bg_thread_error_message_.empty()) {
bg_thread_error_message_ = e.what();
}
}
XLOG(ERR)
<< "[SSD Offloading] Background worker thread caught exception: "
<< e.what()
<< ". Queue item dequeued to prevent infinite retry.";
} catch (...) {
bg_thread_error_count_++;
{
std::lock_guard<std::mutex> lock(bg_error_mutex_);
if (bg_thread_error_message_.empty()) {
bg_thread_error_message_ = "unknown non-std::exception";
}
}
XLOG(ERR)
<< "[SSD Offloading] Background worker thread caught unknown "
<< "non-std::exception. Queue item dequeued.";
}

// Dequeue even on failure to prevent infinite retry and host OOM.
// The write is lost — bg_thread_error_count_ tracks these events.
weights_to_fill_queue_.dequeue();
}
// Queue drained — notify waiters (e.g. wait_util_filling_work_done).
Expand Down Expand Up @@ -444,6 +484,14 @@ void EmbeddingKVDB::set(
auto tensor_copy_start_ts = facebook::WallClockUtil::NowInUsecFast();
auto new_item = tensor_copy(indices, weights, count, write_mode);
weights_to_fill_queue_.enqueue(new_item);
auto cur_depth = static_cast<int64_t>(weights_to_fill_queue_.size());
if (cur_depth > kQueueDepthWarningThreshold) {
XLOG_EVERY_MS(WARNING, 30000)
<< "[SSD Offloading] Background write queue depth is " << cur_depth
<< " (>" << kQueueDepthWarningThreshold
<< "). Background thread may be falling behind or dead. "
<< "Risk of host DDR OOM.";
}
// Lock barrier ensures the background thread's CV wait has completed
// its atomic unlock-and-block before we notify, preventing lost wakeups.
{
Expand Down Expand Up @@ -506,6 +554,14 @@ void EmbeddingKVDB::get(
auto new_item = tensor_copy(
indices, weights, count, kv_db::RocksdbWriteMode::FWD_ROCKSDB_READ);
weights_to_fill_queue_.enqueue(new_item);
auto cur_depth = static_cast<int64_t>(weights_to_fill_queue_.size());
if (cur_depth > kQueueDepthWarningThreshold) {
XLOG_EVERY_MS(WARNING, 30000)
<< "[SSD Offloading] Background write queue depth is "
<< cur_depth << " (>" << kQueueDepthWarningThreshold
<< "). Background thread may be falling behind or dead. "
<< "Risk of host DDR OOM.";
}
// Lock barrier ensures the background thread's CV wait has completed
// its atomic unlock-and-block before we notify, preventing lost
// wakeups.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,17 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
FBEXCEPTION("Not implemented");
}

/// Get background thread error count. Non-zero means writes were dropped.
int64_t get_bg_thread_error_count() const {
return bg_thread_error_count_.load(std::memory_order_relaxed);
}

/// Get first background thread error message (empty if no errors).
std::string get_bg_thread_error_message() const {
std::lock_guard<std::mutex> lock(bg_error_mutex_);
return bg_thread_error_message_;
}

virtual void trigger_feature_evict();

virtual bool is_evicting();
Expand Down Expand Up @@ -561,6 +572,11 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
// -- common path
std::atomic<int64_t> total_cache_update_duration_{0};

// -- background thread error tracking
std::atomic<int64_t> bg_thread_error_count_{0};
mutable std::mutex bg_error_mutex_;
std::string bg_thread_error_message_;

protected:
std::unique_ptr<fbgemm_gpu::RawEmbeddingStreamer> raw_embedding_streamer_;
}; // class EmbeddingKVDB
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1058,7 +1058,13 @@ static auto embedding_rocks_db_wrapper =
&EmbeddingRocksDBWrapper::create_rocksdb_hard_link_snapshot)
.def(
"get_active_checkpoint_uuid",
&EmbeddingRocksDBWrapper::get_active_checkpoint_uuid);
&EmbeddingRocksDBWrapper::get_active_checkpoint_uuid)
.def(
"get_bg_thread_error_count",
&EmbeddingRocksDBWrapper::get_bg_thread_error_count)
.def(
"get_bg_thread_error_message",
&EmbeddingRocksDBWrapper::get_bg_thread_error_message);

auto enrichment_config =
torch::class_<kv_mem::EnrichmentConfig>("fbgemm", "EnrichmentConfig")
Expand Down Expand Up @@ -1246,9 +1252,13 @@ static auto dram_kv_embedding_cache_wrapper =
.def(
"get_feature_evict_metric",
&DramKVEmbeddingCacheWrapper::get_feature_evict_metric)
.def("get_dram_kv_perf", &DramKVEmbeddingCacheWrapper::get_dram_kv_perf)
.def(
"get_bg_thread_error_count",
&DramKVEmbeddingCacheWrapper::get_bg_thread_error_count)
.def(
"get_dram_kv_perf",
&DramKVEmbeddingCacheWrapper::get_dram_kv_perf);
"get_bg_thread_error_message",
&DramKVEmbeddingCacheWrapper::get_bg_thread_error_message);
static auto embedding_rocks_db_read_only_wrapper =
torch::class_<ReadOnlyEmbeddingKVDB>("fbgemm", "ReadOnlyEmbeddingKVDB")
.def(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,13 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
db = db_ptr.release();
#endif
}
if (!s.ok()) {
XLOG(ERR)
<< "[SSD Offloading] RocksDB Open FAILED for shard path '"
<< shard_path << "': " << s.ToString()
<< ". Possible causes: disk full, wrong permissions, corrupted DB, "
<< "incompatible options.";
}
CHECK(s.ok()) << s.ToString();
dbs_.emplace_back(db);
}
Expand Down Expand Up @@ -597,6 +604,13 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
D * sizeof(value_t)));
}
auto s = dbs_[shard]->Write(wo_, &batch);
if (!s.ok()) {
XLOG(ERR)
<< "[SSD Offloading] RocksDB WriteBatch FAILED on shard "
<< shard << ": " << s.ToString()
<< ". Process will crash (CHECK). "
<< "Possible causes: disk full, I/O error, hardware failure.";
}
CHECK(s.ok())
<< "Failed to write batch to db, error: "
<< s.ToString();
Expand Down Expand Up @@ -966,14 +980,23 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {

void compact() override {
for (auto& db : dbs_) {
db->CompactRange(rocksdb::CompactRangeOptions(), nullptr, nullptr);
auto s =
db->CompactRange(rocksdb::CompactRangeOptions(), nullptr, nullptr);
if (!s.ok()) {
XLOG_EVERY_MS(ERR, 60000)
<< "[SSD Offloading] CompactRange failed: " << s.ToString();
}
}
}

void flush() {
kv_db::EmbeddingKVDB::flush();
for (auto& db : dbs_) {
db->Flush(rocksdb::FlushOptions());
auto s = db->Flush(rocksdb::FlushOptions());
if (!s.ok()) {
XLOG_EVERY_MS(ERR, 60000)
<< "[SSD Offloading] Flush failed: " << s.ToString();
}
}
}

Expand Down Expand Up @@ -1006,7 +1029,12 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
rocksdb::FlushOptions fo;
fo.wait = false;
fo.allow_write_stall = false;
dbs_[i]->Flush(fo);
auto s = dbs_[i]->Flush(fo);
if (!s.ok()) {
XLOG_EVERY_MS(ERR, 60000)
<< "[SSD Offloading] Staggered flush failed on shard " << i
<< ": " << s.ToString();
}
if (i == dbs_.size() - 1) {
done_staggered_flushes_ = true;
int64_t period_per_shard = compaction_period_ / dbs_.size();
Expand All @@ -1027,8 +1055,13 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
dbs_[i]->GetColumnFamilyMetaData(&meta);
int32_t num_level0 = meta.levels[0].files.size();
if (num_level0 >= l0_files_per_compact_) {
dbs_[i]->CompactRange(
auto compact_status = dbs_[i]->CompactRange(
rocksdb::CompactRangeOptions(), nullptr, nullptr);
if (!compact_status.ok()) {
XLOG_EVERY_MS(ERR, 60000)
<< "[SSD Offloading] Manual compaction failed on shard " << i
<< ": " << compact_status.ToString();
}
}
shard_flush_compaction_deadlines_[i] += compaction_period_;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,100 @@ TEST(SSDTableBatchedEmbeddingsTest, TestToggleCompactionFailOnThronw) {
{ mock_embedding_rocks->toggle_compaction(true); },
"Failed to toggle compaction to 1 with exception std::runtime_error: some error message");
}

// Note: This test verifies the counter is initialized correctly but does not
// exercise the actual try/catch path because getMockEmbeddingRocksDB uses
// enable_async_update=false by default. The try/catch is validated by the
// Python-side CrashingStatsReporter tests and by production monitoring.
TEST(SSDTableBatchedEmbeddingsTest, TestBackgroundThreadErrorCountInit) {
  int num_shards = 1;
  auto db = getMockEmbeddingRocksDB(num_shards, "bg_error_count");

  // bg_thread_error_count_ should be initialized to 0, and no error message
  // should be recorded yet.
  EXPECT_EQ(db->get_bg_thread_error_count(), 0);
  EXPECT_TRUE(db->get_bg_thread_error_message().empty());

  // Do some normal operations to verify the counter stays at 0.
  // (The previously-declared `count` tensor was unused — set_kv_to_storage
  // takes only indices and weights — so it has been removed.)
  auto indices = at::arange(0, 5, at::TensorOptions().dtype(at::kLong));
  auto weights = at::randn(
      {5, EMBEDDING_DIMENSION}, at::TensorOptions().dtype(at::kFloat));
  db->set_kv_to_storage(indices, weights);
  db->wait_util_filling_work_done();

  // After normal operations, error count should still be 0 and the error
  // message should remain empty.
  EXPECT_EQ(db->get_bg_thread_error_count(), 0);
  EXPECT_TRUE(db->get_bg_thread_error_message().empty());
}

TEST(SSDTableBatchedEmbeddingsTest, TestFlushAndCompactWithoutCrash) {
  int num_shards = 2;
  auto db = getMockEmbeddingRocksDB(num_shards, "flush_compact_test");

  // Populate the store with one batch of random embeddings.
  auto ids = at::arange(0, 10, at::TensorOptions().dtype(at::kLong));
  auto values = at::randn(
      {10, EMBEDDING_DIMENSION}, at::TensorOptions().dtype(at::kFloat));
  db->set_kv_to_storage(ids, values);

  // flush() and compact() now check the RocksDB status and log on failure
  // instead of ignoring it; neither should crash here.
  db->flush();
  db->compact();

  // The DB should still report memory usage, i.e. remain healthy.
  auto usage = db->get_mem_usage();
  EXPECT_GT(usage.size(), 0);
}

// Note: This test exercises the queue tracking code path but does not reach
// the >1000 warning threshold. The warning path is validated by log inspection
// in production monitoring.
TEST(SSDTableBatchedEmbeddingsTest, TestQueueDepthWarningPath) {
  int num_shards = 1;
  auto db = getMockEmbeddingRocksDB(num_shards, "queue_depth");

  // Issue several writes so the queue-depth check runs repeatedly.
  for (int batch = 0; batch < 5; batch++) {
    auto ids = at::arange(
        batch * 10, batch * 10 + 10, at::TensorOptions().dtype(at::kLong));
    auto values = at::randn(
        {10, EMBEDDING_DIMENSION}, at::TensorOptions().dtype(at::kFloat));
    db->set_kv_to_storage(ids, values);
  }

  // Block until the background thread has drained the fill queue.
  db->wait_util_filling_work_done();

  // No write should have failed, so the error counter must remain 0.
  EXPECT_EQ(db->get_bg_thread_error_count(), 0);
}

TEST(SSDTableBatchedEmbeddingsTest, TestCompactionAfterMultipleFlushes) {
  int num_shards = 2;
  auto db = getMockEmbeddingRocksDB(num_shards, "compaction_after_flushes");

  // Alternate writes and flushes so several SST files accumulate on disk.
  for (int round = 0; round < 3; round++) {
    auto ids = at::arange(
        round * 100, round * 100 + 100, at::TensorOptions().dtype(at::kLong));
    auto values = at::randn(
        {100, EMBEDDING_DIMENSION}, at::TensorOptions().dtype(at::kFloat));
    db->set_kv_to_storage(ids, values);
    // flush() now checks the returned RocksDB status; it must not crash.
    db->flush();
  }

  // compact() likewise checks its status and must not crash.
  db->compact();

  // Read a slice back to confirm the DB is still serving data.
  auto query_ids = at::arange(0, 10, at::TensorOptions().dtype(at::kLong));
  auto out_weights = at::zeros(
      {10, EMBEDDING_DIMENSION}, at::TensorOptions().dtype(at::kFloat));
  auto query_count = at::tensor({10}, at::ScalarType::Long);
  db->get_kv_db_async(query_ids, out_weights, query_count).wait();

  // Non-zero total magnitude means previously written weights came back.
  EXPECT_GT(out_weights.abs().sum().item<float>(), 0);
}
Loading
Loading