diff --git a/.gitignore b/.gitignore index e090a2c4c19..00a8b69bbfd 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,6 @@ /build/ /cmake-build-*/ /build-*/ + +# Go module cache +/pkg/ diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc new file mode 100644 index 00000000000..121f050caf2 --- /dev/null +++ b/src/commands/cmd_cms.cc @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#include <strings.h> + +#include "commander.h" +#include "commands/command_parser.h" +#include "server/redis_reply.h" +#include "server/server.h" + +namespace redis { + +/// CMS.INITBYDIM - Initialize a Count-Min Sketch with specified dimensions +/// +/// Redis command: CMS.INITBYDIM key width depth +/// Documentation: https://redis.io/docs/latest/commands/cms.initbydim/ +/// +/// Parameters: +/// - key: The name of the sketch +/// - width: Number of counters in each array (reduces error size) +/// - depth: Number of counter-arrays (reduces error probability) +/// +/// Time complexity: O(1) +/// ACL categories: @cms, @write, @fast +class CommandCMSInitByDim final : public Commander { + public: + Status Parse(const std::vector<std::string> &args) override { + auto parse_width = ParseInt(args[2], 10); + if (!parse_width) { + return {Status::RedisParseErr, "invalid width"}; + } + width_ = *parse_width; + + auto parse_depth = ParseInt(args[3], 10); + if (!parse_depth) { + return {Status::RedisParseErr, "invalid depth"}; + } + depth_ = *parse_depth; + + return Commander::Parse(args); + } + + Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + + auto s = cms.InitByDim(ctx, args_[1], width_, depth_); + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; + + *output = redis::RESP_OK; + return Status::OK(); + } + + private: + uint32_t width_; + uint32_t depth_; +}; + +/// CMS.INITBYPROB - Initialize a Count-Min Sketch with specified error rate and probability +/// +/// Redis command: CMS.INITBYPROB key error probability +/// Documentation: https://redis.io/docs/latest/commands/cms.initbyprob/ +/// +/// Parameters: +/// - key: The name of the sketch +/// - error: Estimate size of error (as percent of total counted items) +/// - probability: Desired probability for inflated count (failure probability) +/// +/// Time complexity: O(1) +/// ACL categories: @cms, @write, @fast +class 
CommandCMSInitByProb final : public Commander { + public: + Status Parse(const std::vector &args) override { + auto parse_error = ParseFloat(args[2]); + if (!parse_error) { + return {Status::RedisParseErr, "invalid error rate"}; + } + error_rate_ = *parse_error; + + auto parse_prob = ParseFloat(args[3]); + if (!parse_prob) { + return {Status::RedisParseErr, "invalid probability"}; + } + probability_ = *parse_prob; + + return Commander::Parse(args); + } + + Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + + auto s = cms.InitByProb(ctx, args_[1], error_rate_, probability_); + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; + + *output = redis::RESP_OK; + return Status::OK(); + } + + private: + double error_rate_; + double probability_; +}; + +/// CMS.INCRBY - Increment the count of one or more items +/// +/// Redis command: CMS.INCRBY key item increment [item increment ...] +/// Documentation: https://redis.io/docs/latest/commands/cms.incrby/ +/// +/// Parameters: +/// - key: The name of the sketch +/// - item: The item to increment +/// - increment: Amount to increment (must be non-negative) +/// +/// Time complexity: O(n) where n is the number of items +/// ACL categories: @cms, @write, @fast +/// +/// Returns: Array of estimated counts for each item after increment +/// Errors: invalid arguments, missing key, overflow (saturates at UINT32_MAX), wrong key type +class CommandCMSIncrBy final : public Commander { + public: + Status Parse(const std::vector &args) override { + if (args.size() < 4 || (args.size() - 2) % 2 != 0) { + return {Status::RedisParseErr, "wrong number of arguments"}; + } + + for (size_t i = 2; i < args.size(); i += 2) { + auto parse_increment = ParseInt(args[i + 1], 10); + if (!parse_increment) { + return {Status::RedisParseErr, "invalid increment"}; + } + if (*parse_increment < 0) { + return {Status::RedisParseErr, "increment must 
be non-negative"}; + } + items_.emplace_back(args[i], *parse_increment); + } + + return Commander::Parse(args); + } + + Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + std::vector counts; + + auto s = cms.IncrBy(ctx, args_[1], items_, &counts); + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; + + *output = redis::MultiLen(counts.size()); + for (auto count : counts) { + *output += redis::Integer(count); + } + return Status::OK(); + } + + private: + std::vector> items_; +}; + +/// CMS.QUERY - Return the estimated count of one or more items +/// +/// Redis command: CMS.QUERY key item [item ...] +/// Documentation: https://redis.io/docs/latest/commands/cms.query/ +/// +/// Parameters: +/// - key: The name of the sketch +/// - item: One or more items to query +/// +/// Time complexity: O(n) where n is the number of items +/// ACL categories: @cms, @read, @fast +/// +/// Returns: Array of estimated counts (min-counts across all layers) +/// Errors: invalid arguments, missing key, wrong key type +class CommandCMSQuery final : public Commander { + public: + Status Parse(const std::vector &args) override { + items_.reserve(args.size() - 2); + for (size_t i = 2; i < args.size(); ++i) { + items_.push_back(args[i]); + } + return Commander::Parse(args); + } + + Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + std::vector counts; + + auto s = cms.Query(ctx, args_[1], items_, &counts); + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; + + *output = redis::MultiLen(counts.size()); + for (auto count : counts) { + *output += redis::Integer(count); + } + return Status::OK(); + } + + private: + std::vector items_; +}; + +/// CMS.MERGE - Merge multiple Count-Min Sketches into one +/// +/// Redis command: CMS.MERGE destination numKeys source [source ...] 
[WEIGHTS weight [weight ...]] +/// Documentation: https://redis.io/docs/latest/commands/cms.merge/ +/// +/// Parameters: +/// - destination: Name of destination sketch (must be initialized) +/// - numKeys: Number of sketches to merge +/// - source: Names of source sketches +/// - weight: Multiplier for each sketch (can be negative, default = 1) +/// +/// Time complexity: O(n) where n is the number of sketches +/// ACL categories: @cms, @write +/// +/// Requirements: +/// - All sketches must have identical width and depth +/// - Destination must already exist +/// +/// Returns: OK on success +/// Errors: invalid arguments, overflow, dimension mismatch, missing key +class CommandCMSMerge final : public Commander { + public: + Status Parse(const std::vector &args) override { + auto parse_numkeys = ParseInt(args[2], 10); + if (!parse_numkeys) { + return {Status::RedisParseErr, "invalid numkeys"}; + } + numkeys_ = *parse_numkeys; + + if (args.size() < 3 + numkeys_) { + return {Status::RedisParseErr, "wrong number of arguments"}; + } + + // Parse source keys + for (size_t i = 0; i < numkeys_; ++i) { + src_keys_.push_back(args[3 + i]); + } + + // Parse optional WEIGHTS + size_t next_arg = 3 + numkeys_; + if (next_arg < args.size() && strcasecmp(args[next_arg].c_str(), "WEIGHTS") == 0) { + next_arg++; + if (args.size() < next_arg + numkeys_) { + return {Status::RedisParseErr, "wrong number of weights"}; + } + for (size_t i = 0; i < numkeys_; ++i) { + auto parse_weight = ParseInt(args[next_arg + i], 10); + if (!parse_weight) { + return {Status::RedisParseErr, "invalid weight"}; + } + weights_.push_back(*parse_weight); + } + } + + return Commander::Parse(args); + } + + Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + + auto s = cms.Merge(ctx, args_[1], src_keys_, weights_); + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; + + *output = redis::RESP_OK; + 
return Status::OK(); + } + + private: + size_t numkeys_; + std::vector<std::string> src_keys_; + std::vector<int64_t> weights_; +}; + +/// CMS.INFO - Return information about a Count-Min Sketch +/// +/// Redis command: CMS.INFO key +/// Documentation: https://redis.io/docs/latest/commands/cms.info/ +/// +/// Parameters: +/// - key: The name of the sketch +/// +/// Time complexity: O(1) +/// ACL categories: @cms, @read, @fast +/// +/// Returns: Array of key-value pairs: +/// - width: Number of counters per layer +/// - depth: Number of layers +/// - count: Total count of all items +/// - size: Total number of buckets (Kvrocks extension) +/// +/// Errors: missing key, wrong key type +class CommandCMSInfo final : public Commander { + public: + Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + CMSInfo info; + + auto s = cms.Info(ctx, args_[1], &info); + if (s.IsNotFound()) return {Status::RedisExecErr, "key not found"}; + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; + + *output = redis::MultiLen(8); + *output += redis::SimpleString("width"); + *output += redis::Integer(info.width); + *output += redis::SimpleString("depth"); + *output += redis::Integer(info.depth); + *output += redis::SimpleString("count"); + *output += redis::Integer(info.total_count); + *output += redis::SimpleString("size"); + *output += redis::Integer(info.size); + return Status::OK(); + } +}; + +REDIS_REGISTER_COMMANDS(CMS, MakeCmdAttr<CommandCMSInitByDim>("cms.initbydim", 4, "write", 1, 1, 1), + MakeCmdAttr<CommandCMSInitByProb>("cms.initbyprob", 4, "write", 1, 1, 1), + MakeCmdAttr<CommandCMSIncrBy>("cms.incrby", -4, "write", 1, 1, 1), + MakeCmdAttr<CommandCMSQuery>("cms.query", -3, "read-only", 1, 1, 1), + MakeCmdAttr<CommandCMSMerge>("cms.merge", -4, "write", 1, 1, 1), + MakeCmdAttr<CommandCMSInfo>("cms.info", 2, "read-only", 1, 1, 1), ) + +} // namespace redis \ No newline at end of file diff --git a/src/commands/commander.h b/src/commands/commander.h index 3f38db02580..612774e5fa0 100644 --- 
a/src/commands/commander.h +++ b/src/commands/commander.h @@ -95,6 +95,7 @@ enum class CommandCategory : uint8_t { Unknown = 0, Bit, BloomFilter, + CMS, Cluster, Function, Geo, diff --git a/src/storage/redis_metadata.cc b/src/storage/redis_metadata.cc index 692f8804db1..99e6f09fe92 100644 --- a/src/storage/redis_metadata.cc +++ b/src/storage/redis_metadata.cc @@ -334,7 +334,7 @@ bool Metadata::IsSingleKVType() const { return Type() == kRedisString || Type() bool Metadata::IsEmptyableType() const { return IsSingleKVType() || Type() == kRedisStream || Type() == kRedisBloomFilter || Type() == kRedisHyperLogLog || - Type() == kRedisTDigest || Type() == kRedisTimeSeries; + Type() == kRedisTDigest || Type() == kRedisTimeSeries || Type() == kRedisCMS; } bool Metadata::Expired() const { return ExpireAt(util::GetTimeStampMS()); } @@ -569,3 +569,32 @@ rocksdb::Status TimeSeriesMetadata::Decode(Slice *input) { return rocksdb::Status::OK(); } + +void CMSMetadata::Encode(std::string *dst) const { + Metadata::Encode(dst); + + PutFixed32(dst, width); + PutFixed32(dst, depth); + PutFixed64(dst, total_count); + PutFixed8(dst, static_cast(storage_mode)); +} + +rocksdb::Status CMSMetadata::Decode(Slice *input) { + if (auto s = Metadata::Decode(input); !s.ok()) { + return s; + } + + if (input->size() < 4 + 4 + 8 + 1) { + return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); + } + + GetFixed32(input, &width); + GetFixed32(input, &depth); + GetFixed64(input, &total_count); + + uint8_t mode = 0; + GetFixed8(input, &mode); + storage_mode = static_cast(mode); + + return rocksdb::Status::OK(); +} diff --git a/src/storage/redis_metadata.h b/src/storage/redis_metadata.h index fd80e5a5ba0..376dcc60e43 100644 --- a/src/storage/redis_metadata.h +++ b/src/storage/redis_metadata.h @@ -54,12 +54,13 @@ enum RedisType : uint8_t { kRedisHyperLogLog = 11, kRedisTDigest = 12, kRedisTimeSeries = 13, + kRedisCMS = 14, kRedisTypeMax }; inline constexpr const std::array RedisTypeNames = { "none", 
"string", "hash", "list", "set", "zset", "bitmap", - "sortedint", "stream", "MBbloom--", "ReJSON-RL", "hyperloglog", "TDIS-TYPE", "timeseries"}; + "sortedint", "stream", "MBbloom--", "ReJSON-RL", "hyperloglog", "TDIS-TYPE", "timeseries", "cms"}; struct RedisTypes { RedisTypes(std::initializer_list list) { @@ -409,3 +410,32 @@ class TimeSeriesMetadata : public Metadata { void Encode(std::string *dst) const override; rocksdb::Status Decode(Slice *input) override; }; + +class CMSMetadata : public Metadata { + public: + enum class StorageMode : uint8_t { + PER_BUCKET = 0, // 按桶存储(默认) + SINGLE_KEY = 1, // 单 Key 存储 + }; + + /// Width of the count matrix (number of buckets per layer) + uint32_t width; + + /// Depth of the count matrix (number of layers) + uint32_t depth; + + /// Total count of all INCRBY operations + uint64_t total_count; + + /// Storage mode + StorageMode storage_mode; + + explicit CMSMetadata(bool generate_version = true) + : Metadata(kRedisCMS, generate_version), width(0), depth(0), total_count(0), storage_mode(StorageMode::PER_BUCKET) {} + + CMSMetadata(uint32_t width, uint32_t depth, bool generate_version = true) + : Metadata(kRedisCMS, generate_version), width(width), depth(depth), total_count(0), storage_mode(StorageMode::PER_BUCKET) {} + + void Encode(std::string *dst) const override; + rocksdb::Status Decode(Slice *input) override; +}; diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc new file mode 100644 index 00000000000..08bfb889a78 --- /dev/null +++ b/src/types/redis_cms.cc @@ -0,0 +1,542 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "redis_cms.h" + +#include +#include + +#include "vendor/murmurhash2.h" + +namespace redis { + +// ============================================================================ +// Private helper methods +// ============================================================================ + +rocksdb::Status CMS::getCMSMetadata(engine::Context &ctx, const Slice &ns_key, CMSMetadata *metadata) { + return Database::GetMetadata(ctx, {kRedisCMS}, ns_key, metadata); +} + +std::string CMS::getBucketKey(const Slice &ns_key, uint64_t version, uint32_t bucket_id) { + std::string sub_key; + PutFixed32(&sub_key, bucket_id); + return InternalKey(ns_key, sub_key, version, storage_->IsSlotIdEncoded()).Encode(); +} + +uint64_t CMS::hashItem(const Slice &item, uint32_t layer) { + // Use MurmurHash64 with layer as seed + // This ensures different hash functions for each layer + return HllMurMurHash64A(item.data(), static_cast(item.size()), layer); +} + +uint32_t CMS::getCol(const Slice &item, uint32_t layer, uint32_t width) { + uint64_t hash = hashItem(item, layer); + return static_cast(hash % width); +} + +std::pair CMS::calcDimFromProb(double error_rate, double probability) { + // Formula from RedisBloom implementation: + // width = ceil(2 / error_rate) + // depth = ceil(log10(probability) / log10(0.5)) + // + // Note: RedisBloom uses simplified approximations of the CMS paper formulas. + // This differs from the theoretical formulas (width = e/error, depth = ln(1/probability)). 
+ // + // Per Redis documentation (https://redis.io/docs/latest/commands/cms.initbyprob/): + // - error: Estimate size of error, as a percent of total counted items + // - probability: The desired probability for inflated count (i.e., the probability + // that the estimate exceeds the true count by more than error_rate). + auto width = static_cast(std::ceil(2.0 / error_rate)); + auto depth = static_cast(std::ceil(std::log10(probability) / std::log10(0.5))); + + // Clamp to allowed range + width = std::max(1u, std::min(width, kCMSMaxWidth)); + depth = std::max(1u, std::min(depth, kCMSMaxDepth)); + + return {width, depth}; +} + +// ============================================================================ +// Public API methods +// ============================================================================ + +/// CMS.INITBYDIM - Initialize CMS with given dimensions +/// +/// Redis command: CMS.INITBYDIM key width depth +/// Documentation: https://redis.io/docs/latest/commands/cms.initbydim/ +/// +/// Parameters: +/// - key: The name of the sketch +/// - width: Number of counters in each array. Reduces the error size. +/// - depth: Number of counter-arrays. Reduces the probability for an error. +/// +/// Time complexity: O(1) +/// Returns: OK on success +/// +/// Note: Buckets are lazily initialized (not pre-allocated). Missing buckets +/// are treated as 0, reducing write amplification on CMS creation. 
+rocksdb::Status CMS::InitByDim(engine::Context &ctx, const Slice &key, uint32_t width, uint32_t depth) { + if (width == 0) { + return rocksdb::Status::InvalidArgument("width must be positive"); + } + if (depth == 0) { + return rocksdb::Status::InvalidArgument("depth must be positive"); + } + if (width > kCMSMaxWidth) { + return rocksdb::Status::InvalidArgument("width exceeds maximum limit (" + std::to_string(kCMSMaxWidth) + ")"); + } + if (depth > kCMSMaxDepth) { + return rocksdb::Status::InvalidArgument("depth exceeds maximum limit (" + std::to_string(kCMSMaxDepth) + ")"); + } + + // Check total size + uint64_t total_size = static_cast(width) * depth * 4; + if (total_size > kCMSMaxSize) { + return rocksdb::Status::InvalidArgument("matrix size exceeds maximum limit (max " + + std::to_string(kCMSMaxSize / 1024 / 1024) + "MB)"); + } + + std::string ns_key = ComposeNamespaceKey(namespace_, key, storage_->IsSlotIdEncoded()); + + // Check if key already exists + CMSMetadata existing_metadata; + rocksdb::Status s = getCMSMetadata(ctx, ns_key, &existing_metadata); + if (!s.IsNotFound()) { + if (s.ok()) { + return rocksdb::Status::InvalidArgument("key already exists"); + } + return s; + } + + // Create new CMS metadata + CMSMetadata metadata; + metadata.width = width; + metadata.depth = depth; + metadata.total_count = 0; + metadata.storage_mode = CMSMetadata::StorageMode::PER_BUCKET; + + // Write metadata and initialize all buckets to 0 + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisCMS, {"InitByDim"}); + s = batch->PutLogData(log_data.Encode()); + if (!s.ok()) return s; + + // Write metadata + std::string metadata_bytes; + metadata.Encode(&metadata_bytes); + s = batch->Put(metadata_cf_handle_, ns_key, metadata_bytes); + if (!s.ok()) return s; + + // Note: Buckets are not pre-initialized. Missing buckets are treated as 0. + // This lazy initialization reduces write amplification on CMS creation. 
+ // Query and IncrBy correctly handle missing buckets by treating them as 0. + + return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} + +/// CMS.INITBYPROB - Initialize CMS with given error rate and probability +/// +/// Redis command: CMS.INITBYPROB key error probability +/// Documentation: https://redis.io/docs/latest/commands/cms.initbyprob/ +/// +/// Parameters: +/// - key: The name of the sketch +/// - error: Estimate size of error, as a percent of total counted items +/// - probability: The desired probability for inflated count (failure probability) +/// For example, for 0.1% failure rate, set probability = 0.001 +/// +/// Time complexity: O(1) +/// Returns: OK on success +/// +/// Formula (from RedisBloom): +/// width = ceil(2 / error) +/// depth = ceil(log10(probability) / log10(0.5)) +rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &key, double error_rate, double probability) { + if (error_rate <= 0 || error_rate >= 1) { + return rocksdb::Status::InvalidArgument("error rate must be between 0 and 1"); + } + if (probability <= 0 || probability >= 1) { + return rocksdb::Status::InvalidArgument("probability must be between 0 and 1"); + } + + auto [width, depth] = calcDimFromProb(error_rate, probability); + return InitByDim(ctx, key, width, depth); +} + +/// CMS.INCRBY - Increment counters for given items +/// +/// Redis command: CMS.INCRBY key item increment [item increment ...] +/// Documentation: https://redis.io/docs/latest/commands/cms.incrby/ +/// +/// Parameters: +/// - key: The name of the sketch +/// - item: The item to increment +/// - increment: Amount to increment (must be non-negative) +/// +/// Time complexity: O(depth) for each item +/// Returns: Array of estimated counts for each item after increment +/// +/// Overflow behavior: Counters saturate at UINT32_MAX (~4.3 billion). +/// This matches RedisBloom behavior (silent saturation, no error). 
+rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &key, + const std::vector> &items, std::vector *counts) { + if (items.empty()) { + return rocksdb::Status::InvalidArgument("no items provided"); + } + + std::string ns_key = ComposeNamespaceKey(namespace_, key, storage_->IsSlotIdEncoded()); + + CMSMetadata metadata; + rocksdb::Status s = getCMSMetadata(ctx, ns_key, &metadata); + if (!s.ok()) return s; + + counts->resize(items.size(), 0); + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisCMS, {"IncrBy"}); + s = batch->PutLogData(log_data.Encode()); + if (!s.ok()) return s; + + uint64_t total_increment = 0; + + for (size_t i = 0; i < items.size(); ++i) { + const auto &[item, increment] = items[i]; + Slice item_slice(item); + + if (increment < 0) { + return rocksdb::Status::InvalidArgument("increment must be non-negative"); + } + if (increment > UINT32_MAX) { + return rocksdb::Status::InvalidArgument("increment exceeds maximum value"); + } + + uint64_t min_count = UINT64_MAX; + std::vector> bucket_updates; // (bucket_id, new_count) + + // Read and update depth buckets (one per layer) for this item + for (uint32_t layer = 0; layer < metadata.depth; ++layer) { + uint32_t col = getCol(item_slice, layer, metadata.width); + uint32_t bucket_id = layer * metadata.width + col; + std::string bucket_key = getBucketKey(ns_key, metadata.version, bucket_id); + + // Read current count + std::string count_str; + s = storage_->Get(ctx, ctx.GetReadOptions(), bucket_key, &count_str); + uint32_t current_count = 0; + if (s.ok() && count_str.size() >= 4) { + Slice count_slice(count_str); + GetFixed32(&count_slice, ¤t_count); + } else if (!s.IsNotFound()) { + return s; + } + + // Check for overflow + auto inc_val = static_cast(increment); + uint32_t new_count = 0; + if (__builtin_add_overflow(current_count, inc_val, &new_count)) { + new_count = UINT32_MAX; // Saturate at max value + } + bucket_updates.emplace_back(bucket_id, new_count); + + if 
(new_count < min_count) { + min_count = new_count; + } + } + + // Write all updated buckets + for (const auto &[bucket_id, new_count] : bucket_updates) { + std::string bucket_key = getBucketKey(ns_key, metadata.version, bucket_id); + std::string count_value; + PutFixed32(&count_value, new_count); + s = batch->Put(bucket_key, count_value); + if (!s.ok()) return s; + } + + (*counts)[i] = min_count; + + // Check total_increment overflow + if (__builtin_add_overflow(total_increment, static_cast(increment), &total_increment)) { + total_increment = UINT64_MAX; // Saturate at max value + } + } + + // Update metadata: increment total_count with overflow check + uint64_t new_total = 0; + if (__builtin_add_overflow(metadata.total_count, total_increment, &new_total)) { + metadata.total_count = UINT64_MAX; + } else { + metadata.total_count = new_total; + } + std::string metadata_bytes; + metadata.Encode(&metadata_bytes); + s = batch->Put(metadata_cf_handle_, ns_key, metadata_bytes); + if (!s.ok()) return s; + + return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} + +/// CMS.QUERY - Query estimated counts for given items +/// +/// Redis command: CMS.QUERY key item [item ...] +/// Documentation: https://redis.io/docs/latest/commands/cms.query/ +/// +/// Parameters: +/// - key: The name of the sketch +/// - item: One or more items to query +/// +/// Time complexity: O(depth) for each item +/// Returns: Array of estimated counts (minimum count across all layers) +/// +/// Note: CMS never underestimates the true count (returns count >= actual). 
+rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &key, const std::vector &items, + std::vector *counts) { + if (items.empty()) { + return rocksdb::Status::InvalidArgument("no items provided"); + } + + std::string ns_key = ComposeNamespaceKey(namespace_, key, storage_->IsSlotIdEncoded()); + + CMSMetadata metadata; + rocksdb::Status s = getCMSMetadata(ctx, ns_key, &metadata); + if (!s.ok()) return s; + + counts->resize(items.size(), 0); + + for (size_t i = 0; i < items.size(); ++i) { + Slice item_slice(items[i]); + uint64_t min_count = UINT64_MAX; + + for (uint32_t layer = 0; layer < metadata.depth; ++layer) { + uint32_t col = getCol(item_slice, layer, metadata.width); + uint32_t bucket_id = layer * metadata.width + col; + std::string bucket_key = getBucketKey(ns_key, metadata.version, bucket_id); + + std::string count_str; + s = storage_->Get(ctx, ctx.GetReadOptions(), bucket_key, &count_str); + + uint32_t count = 0; + if (s.ok() && count_str.size() >= 4) { + Slice count_slice(count_str); + GetFixed32(&count_slice, &count); + } else if (!s.IsNotFound()) { + return s; + } + + if (count < min_count) { + min_count = count; + } + } + + (*counts)[i] = min_count; + } + + return rocksdb::Status::OK(); +} + +/// CMS.MERGE - Merge multiple CMS sketches into one +/// +/// Redis command: CMS.MERGE destination numKeys source [source ...] 
[WEIGHTS weight [weight ...]] +/// Documentation: https://redis.io/docs/latest/commands/cms.merge/ +/// +/// Parameters: +/// - destination: The name of destination sketch (must be initialized) +/// - numKeys: Number of sketches to merge +/// - source: Names of source sketches to merge +/// - weight: Multiplier for each sketch (can be negative, default = 1) +/// +/// Time complexity: O(width * depth * numKeys) +/// Returns: OK on success, error if overflow detected +/// +/// Requirements: +/// - All sketches must have identical width and depth +/// - Destination must already exist +/// +/// Overflow behavior: Returns error if any bucket would overflow UINT32_MAX or +/// become negative after applying weights. This matches RedisBloom behavior. +rocksdb::Status CMS::Merge(engine::Context &ctx, const Slice &dest_key, const std::vector &src_keys, + const std::vector &weights) { + if (src_keys.empty()) { + return rocksdb::Status::InvalidArgument("no source keys provided"); + } + if (!weights.empty() && weights.size() != src_keys.size()) { + return rocksdb::Status::InvalidArgument("number of weights must match number of source keys"); + } + + // Get all source metadata and validate dimensions match + std::vector src_metadata(src_keys.size()); + std::vector src_ns_keys(src_keys.size()); + + for (size_t i = 0; i < src_keys.size(); ++i) { + src_ns_keys[i] = ComposeNamespaceKey(namespace_, src_keys[i], storage_->IsSlotIdEncoded()); + rocksdb::Status s = getCMSMetadata(ctx, src_ns_keys[i], &src_metadata[i]); + if (!s.ok()) { + if (s.IsNotFound()) { + return rocksdb::Status::InvalidArgument("source key not found: " + src_keys[i]); + } + return s; + } + + // Validate dimensions match + if (i > 0) { + if (src_metadata[i].width != src_metadata[0].width || src_metadata[i].depth != src_metadata[0].depth) { + return rocksdb::Status::InvalidArgument("CMS dimensions do not match"); + } + } + } + + // Create destination CMS + std::string dest_ns_key = ComposeNamespaceKey(namespace_, 
dest_key, storage_->IsSlotIdEncoded()); + + // Check if destination exists + // Per Redis documentation: destination must be initialized + CMSMetadata dest_metadata; + rocksdb::Status s = getCMSMetadata(ctx, dest_ns_key, &dest_metadata); + if (s.IsNotFound()) { + return rocksdb::Status::InvalidArgument("destination key not found: " + dest_key.ToString()); + } + if (!s.ok()) { + return s; + } + + // Check if destination dimensions match source + if (dest_metadata.width != src_metadata[0].width || dest_metadata.depth != src_metadata[0].depth) { + return rocksdb::Status::InvalidArgument("destination dimensions do not match source"); + } + + uint32_t width = dest_metadata.width; + uint32_t depth = dest_metadata.depth; + uint32_t total_buckets = width * depth; + + // Phase 1: Pre-check for overflow (following RedisBloom behavior) + // Read all buckets and validate no overflow will occur + std::vector> src_buckets(src_keys.size()); + for (size_t k = 0; k < src_keys.size(); ++k) { + src_buckets[k].resize(total_buckets, 0); + for (uint32_t bucket_id = 0; bucket_id < total_buckets; ++bucket_id) { + std::string bucket_key = getBucketKey(src_ns_keys[k], src_metadata[k].version, bucket_id); + std::string count_str; + s = storage_->Get(ctx, ctx.GetReadOptions(), bucket_key, &count_str); + if (s.ok() && count_str.size() >= 4) { + Slice count_slice(count_str); + GetFixed32(&count_slice, &src_buckets[k][bucket_id]); + } else if (!s.IsNotFound()) { + return s; + } + } + } + + // Check for overflow in all buckets + for (uint32_t bucket_id = 0; bucket_id < total_buckets; ++bucket_id) { + int64_t item_count = 0; + for (size_t k = 0; k < src_keys.size(); ++k) { + int64_t weight = weights.empty() ? 
1 : weights[k]; + int64_t count = src_buckets[k][bucket_id]; + int64_t mul = 0; + + // Check for multiplication and addition overflow + if (__builtin_mul_overflow(count, weight, &mul) || __builtin_add_overflow(item_count, mul, &item_count)) { + return rocksdb::Status::InvalidArgument("overflow detected in merge operation"); + } + } + + // Validate result is within valid range for uint32_t + if (item_count < 0 || item_count > UINT32_MAX) { + return rocksdb::Status::InvalidArgument("overflow detected in merge operation"); + } + } + + // Check total_count overflow + int64_t cms_count = 0; + for (size_t k = 0; k < src_keys.size(); ++k) { + int64_t weight = weights.empty() ? 1 : weights[k]; + int64_t mul = 0; + if (__builtin_mul_overflow(static_cast(src_metadata[k].total_count), weight, &mul) || + __builtin_add_overflow(cms_count, mul, &cms_count)) { + return rocksdb::Status::InvalidArgument("overflow detected in merge operation"); + } + } + if (cms_count < 0) { + return rocksdb::Status::InvalidArgument("overflow detected in merge operation"); + } + + // Phase 2: Execute merge (pre-check passed, no overflow will occur) + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisCMS, {"Merge"}); + s = batch->PutLogData(log_data.Encode()); + if (!s.ok()) return s; + + dest_metadata.total_count = 0; + + for (uint32_t bucket_id = 0; bucket_id < total_buckets; ++bucket_id) { + int64_t item_count = 0; + + for (size_t k = 0; k < src_keys.size(); ++k) { + int64_t weight = weights.empty() ? 
1 : weights[k]; + item_count += static_cast(src_buckets[k][bucket_id]) * weight; + } + + // Write merged bucket + std::string dest_bucket_key = getBucketKey(dest_ns_key, dest_metadata.version, bucket_id); + std::string count_value; + PutFixed32(&count_value, static_cast(item_count)); + s = batch->Put(dest_bucket_key, count_value); + if (!s.ok()) return s; + + dest_metadata.total_count += item_count; + } + + // Write destination metadata + std::string metadata_bytes; + dest_metadata.Encode(&metadata_bytes); + s = batch->Put(metadata_cf_handle_, dest_ns_key, metadata_bytes); + if (!s.ok()) return s; + + return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} + +/// CMS.INFO - Get CMS information +/// +/// Redis command: CMS.INFO key +/// Documentation: https://redis.io/docs/latest/commands/cms.info/ +/// +/// Parameters: +/// - key: The name of the sketch +/// +/// Time complexity: O(1) +/// Returns: width, depth, count (total count), size (number of buckets) +/// +/// Note: 'size' field is a Kvrocks extension (not in Redis). +rocksdb::Status CMS::Info(engine::Context &ctx, const Slice &key, CMSInfo *info) { + std::string ns_key = ComposeNamespaceKey(namespace_, key, storage_->IsSlotIdEncoded()); + + CMSMetadata metadata; + rocksdb::Status s = getCMSMetadata(ctx, ns_key, &metadata); + if (!s.ok()) return s; + + info->width = metadata.width; + info->depth = metadata.depth; + info->total_count = metadata.total_count; + info->size = static_cast(metadata.width) * metadata.depth; + + return rocksdb::Status::OK(); +} + +} // namespace redis \ No newline at end of file diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h new file mode 100644 index 00000000000..a6fb68c3ad3 --- /dev/null +++ b/src/types/redis_cms.h @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include "storage/redis_db.h" +#include "storage/redis_metadata.h" + +namespace redis { + +/// Default width for CMS (number of buckets per layer) +constexpr uint32_t kCMSDefaultWidth = 2000; + +/// Default depth for CMS (number of layers) +constexpr uint32_t kCMSDefaultDepth = 9; + +/// Maximum width allowed +constexpr uint32_t kCMSMaxWidth = 100000; + +/// Maximum depth allowed +constexpr uint32_t kCMSMaxDepth = 100; + +/// Maximum total size in bytes (16MB) +constexpr uint64_t kCMSMaxSize = 16 * 1024 * 1024; + +/// Info fields for CMS.INFO command +enum class CMSInfoField { + kAll, + kWidth, + kDepth, + kTotalCount, +}; + +/// CMS information structure +struct CMSInfo { + uint32_t width; + uint32_t depth; + uint64_t total_count; + uint64_t size; // Number of buckets (width * depth) +}; + +/// Count-Min Sketch probabilistic data structure +class CMS : public Database { + public: + explicit CMS(engine::Storage *storage, const std::string &ns) : Database(storage, ns) {} + + /// Initialize CMS with given dimensions + /// @param ctx Engine context + /// @param key User key + /// @param width Number of buckets per layer + /// @param depth Number of layers + rocksdb::Status InitByDim(engine::Context &ctx, const Slice &key, uint32_t width, uint32_t 
depth); + + /// Initialize CMS with given error rate and probability + /// @param ctx Engine context + /// @param key User key + /// @param error_rate Desired error rate (0 < error_rate < 1) + /// @param probability Desired probability (0 < probability < 1) + rocksdb::Status InitByProb(engine::Context &ctx, const Slice &key, double error_rate, double probability); + + /// Increment counters for given items + /// @param ctx Engine context + /// @param key User key + /// @param items Vector of (item, increment) pairs + /// @param counts Output: estimated counts for each item after increment + rocksdb::Status IncrBy(engine::Context &ctx, const Slice &key, + const std::vector> &items, + std::vector *counts); + + /// Query estimated counts for given items + /// @param ctx Engine context + /// @param key User key + /// @param items Vector of items to query + /// @param counts Output: estimated counts for each item + rocksdb::Status Query(engine::Context &ctx, const Slice &key, const std::vector &items, + std::vector *counts); + + /// Merge multiple CMS sketches into one + /// @param ctx Engine context + /// @param dest_key Destination key + /// @param src_keys Source CMS keys + /// @param weights Weights for each source (can be negative, optional, default all 1) + rocksdb::Status Merge(engine::Context &ctx, const Slice &dest_key, const std::vector &src_keys, + const std::vector &weights); + + /// Get CMS information + /// @param ctx Engine context + /// @param key User key + /// @param info Output: CMS information + rocksdb::Status Info(engine::Context &ctx, const Slice &key, CMSInfo *info); + + private: + /// Get CMS metadata from storage + rocksdb::Status getCMSMetadata(engine::Context &ctx, const Slice &ns_key, CMSMetadata *metadata); + + /// Build bucket key for given bucket_id + /// @param ns_key Namespace key + /// @param version Metadata version + /// @param bucket_id Bucket identifier (layer * width + col) + /// @return Encoded InternalKey for the bucket + 
std::string getBucketKey(const Slice &ns_key, uint64_t version, uint32_t bucket_id); + + /// Hash an item for a specific layer + /// @param item Item to hash + /// @param layer Layer index (0 to depth-1) + /// @return Hash value for the layer + static uint64_t hashItem(const Slice &item, uint32_t layer); + + /// Calculate bucket_id for an item at a given layer + /// @param item Item to calculate bucket for + /// @param layer Layer index + /// @param width Width of the CMS + /// @return Column index (hash % width) + static uint32_t getCol(const Slice &item, uint32_t layer, uint32_t width); + + /// Calculate width and depth from error rate and probability + /// @param error_rate Desired error rate + /// @param probability Desired probability + /// @return Pair of (width, depth) + static std::pair calcDimFromProb(double error_rate, double probability); +}; + +} // namespace redis \ No newline at end of file diff --git a/tests/cppunit/types/cms_test.cc b/tests/cppunit/types/cms_test.cc new file mode 100644 index 00000000000..d7c9bd5e59c --- /dev/null +++ b/tests/cppunit/types/cms_test.cc @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#include + +#include +#include + +#include "test_base.h" +#include "types/redis_cms.h" + +class RedisCMSTest : public TestBase { + protected: + explicit RedisCMSTest() : TestBase() { + cms_ = std::make_unique(storage_.get(), "cms_ns"); + } + ~RedisCMSTest() override = default; + + void SetUp() override { + TestBase::SetUp(); + [[maybe_unused]] auto s = cms_->Del(*ctx_, "cms"); + for (int x = 1; x <= 3; x++) { + s = cms_->Del(*ctx_, "cms" + std::to_string(x)); + } + } + + void TearDown() override { + TestBase::TearDown(); + [[maybe_unused]] auto s = cms_->Del(*ctx_, "cms"); + for (int x = 1; x <= 3; x++) { + s = cms_->Del(*ctx_, "cms" + std::to_string(x)); + } + } + + std::unique_ptr cms_; +}; + +TEST_F(RedisCMSTest, InitByDim) { + // Test basic initialization + auto s = cms_->InitByDim(*ctx_, "cms", 1000, 5); + ASSERT_TRUE(s.ok()); + + // Test get info + redis::CMSInfo info; + s = cms_->Info(*ctx_, "cms", &info); + ASSERT_TRUE(s.ok()); + EXPECT_EQ(1000, info.width); + EXPECT_EQ(5, info.depth); + EXPECT_EQ(0, info.total_count); + EXPECT_EQ(5000, info.size); + + // Test duplicate key + s = cms_->InitByDim(*ctx_, "cms", 100, 3); + EXPECT_TRUE(s.IsInvalidArgument()); +} + +TEST_F(RedisCMSTest, InitByProb) { + // Test initialization with error rate and probability + // Per Redis documentation, 'probability' is the failure probability (probability of inflated count) + // For a 1% failure rate, set probability = 0.01 + auto s = cms_->InitByProb(*ctx_, "cms", 0.01, 0.01); + ASSERT_TRUE(s.ok()); + + redis::CMSInfo info; + s = cms_->Info(*ctx_, "cms", &info); + ASSERT_TRUE(s.ok()); + // RedisBloom formula: + // width = ceil(2 / error_rate) = ceil(2 / 0.01) = 200 + // depth = ceil(log10(probability) / log10(0.5)) = ceil(log10(0.01) / log10(0.5)) = 7 + EXPECT_EQ(200, info.width); + EXPECT_EQ(7, info.depth); +} + +TEST_F(RedisCMSTest, IncrBy) { + // Initialize CMS + auto s = cms_->InitByDim(*ctx_, "cms", 100, 5); + ASSERT_TRUE(s.ok()); + + // Test single item 
increment + std::vector counts; + std::vector> items = {{"foo", 10}}; + s = cms_->IncrBy(*ctx_, "cms", items, &counts); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(1, counts.size()); + EXPECT_EQ(10, counts[0]); + + // Test multiple items + items = {{"foo", 5}, {"bar", 20}}; + s = cms_->IncrBy(*ctx_, "cms", items, &counts); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(2, counts.size()); + EXPECT_EQ(15, counts[0]); // 10 + 5 + EXPECT_EQ(20, counts[1]); + + // Check total count + redis::CMSInfo info; + s = cms_->Info(*ctx_, "cms", &info); + ASSERT_TRUE(s.ok()); + EXPECT_EQ(35, info.total_count); // 10 + 5 + 20 +} + +TEST_F(RedisCMSTest, Query) { + // Initialize CMS + auto s = cms_->InitByDim(*ctx_, "cms", 100, 5); + ASSERT_TRUE(s.ok()); + + // Query non-existent item + std::vector counts; + std::vector items = {"foo"}; + s = cms_->Query(*ctx_, "cms", items, &counts); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(1, counts.size()); + EXPECT_EQ(0, counts[0]); + + // Increment and query + std::vector> incr_items = {{"foo", 10}, {"bar", 20}}; + s = cms_->IncrBy(*ctx_, "cms", incr_items, &counts); + ASSERT_TRUE(s.ok()); + + // Query multiple items + items = {"foo", "bar", "baz"}; + s = cms_->Query(*ctx_, "cms", items, &counts); + ASSERT_TRUE(s.ok()); + ASSERT_EQ(3, counts.size()); + EXPECT_EQ(10, counts[0]); + EXPECT_EQ(20, counts[1]); + EXPECT_EQ(0, counts[2]); // baz was never incremented +} + +TEST_F(RedisCMSTest, Merge) { + // Initialize three CMS with same dimensions (destination must be initialized) + auto s = cms_->InitByDim(*ctx_, "cms1", 100, 5); + ASSERT_TRUE(s.ok()); + s = cms_->InitByDim(*ctx_, "cms2", 100, 5); + ASSERT_TRUE(s.ok()); + s = cms_->InitByDim(*ctx_, "cms_merge", 100, 5); + ASSERT_TRUE(s.ok()); + + // Add data to both + std::vector counts; + std::vector> items1 = {{"item1", 100}}; + s = cms_->IncrBy(*ctx_, "cms1", items1, &counts); + ASSERT_TRUE(s.ok()); + + std::vector> items2 = {{"item1", 50}}; + s = cms_->IncrBy(*ctx_, "cms2", items2, &counts); + ASSERT_TRUE(s.ok()); + + // 
Merge without weights (SUM) + s = cms_->Merge(*ctx_, "cms_merge", {"cms1", "cms2"}, {}, redis::CMSMergeMethod::SUM); + ASSERT_TRUE(s.ok()); + + // Query merged CMS + std::vector query_items = {"item1"}; + s = cms_->Query(*ctx_, "cms_merge", query_items, &counts); + ASSERT_TRUE(s.ok()); + EXPECT_EQ(150, counts[0]); // 100 + 50 +} + +TEST_F(RedisCMSTest, MergeWithWeights) { + // Initialize three CMS with same dimensions (destination must be initialized) + auto s = cms_->InitByDim(*ctx_, "cms1", 100, 5); + ASSERT_TRUE(s.ok()); + s = cms_->InitByDim(*ctx_, "cms2", 100, 5); + ASSERT_TRUE(s.ok()); + s = cms_->InitByDim(*ctx_, "cms_merge", 100, 5); + ASSERT_TRUE(s.ok()); + + // Add data + std::vector counts; + std::vector> items1 = {{"item1", 100}}; + s = cms_->IncrBy(*ctx_, "cms1", items1, &counts); + ASSERT_TRUE(s.ok()); + + std::vector> items2 = {{"item1", 50}}; + s = cms_->IncrBy(*ctx_, "cms2", items2, &counts); + ASSERT_TRUE(s.ok()); + + // Merge with weights + s = cms_->Merge(*ctx_, "cms_merge", {"cms1", "cms2"}, {2, 3}, redis::CMSMergeMethod::SUM); + ASSERT_TRUE(s.ok()); + + // Query merged CMS + std::vector query_items = {"item1"}; + s = cms_->Query(*ctx_, "cms_merge", query_items, &counts); + ASSERT_TRUE(s.ok()); + EXPECT_EQ(350, counts[0]); // 100 * 2 + 50 * 3 = 350 +} + +TEST_F(RedisCMSTest, MergeInvalidDimensions) { + // Initialize CMS with different dimensions + auto s = cms_->InitByDim(*ctx_, "cms1", 100, 5); + ASSERT_TRUE(s.ok()); + s = cms_->InitByDim(*ctx_, "cms2", 200, 5); // Different width + ASSERT_TRUE(s.ok()); + + // Merge should fail + s = cms_->Merge(*ctx_, "cms_merge", {"cms1", "cms2"}, {}, redis::CMSMergeMethod::SUM); + EXPECT_TRUE(s.IsInvalidArgument()); +} + +TEST_F(RedisCMSTest, NonExistentKey) { + // Query non-existent key + std::vector counts; + std::vector items = {"foo"}; + auto s = cms_->Query(*ctx_, "nonexistent", items, &counts); + EXPECT_TRUE(s.IsNotFound()); + + // IncrBy non-existent key + std::vector> incr_items = {{"foo", 10}}; + s 
= cms_->IncrBy(*ctx_, "nonexistent", incr_items, &counts); + EXPECT_TRUE(s.IsNotFound()); + + // Info non-existent key + redis::CMSInfo info; + s = cms_->Info(*ctx_, "nonexistent", &info); + EXPECT_TRUE(s.IsNotFound()); +} + +TEST_F(RedisCMSTest, AccuracyTest) { + // Test CMS accuracy with known values + const uint32_t width = 1000; + const uint32_t depth = 10; + auto s = cms_->InitByDim(*ctx_, "cms", width, depth); + ASSERT_TRUE(s.ok()); + + // Add items with known counts + std::vector counts; + for (int i = 0; i < 100; i++) { + std::vector> items = {{"item" + std::to_string(i), i + 1}}; + s = cms_->IncrBy(*ctx_, "cms", items, &counts); + ASSERT_TRUE(s.ok()); + } + + // Query and verify accuracy (CMS may overestimate but never underestimate) + for (int i = 0; i < 100; i++) { + std::vector items = {"item" + std::to_string(i)}; + s = cms_->Query(*ctx_, "cms", items, &counts); + ASSERT_TRUE(s.ok()); + // CMS should return count >= actual count (never underestimate) + EXPECT_GE(counts[0], static_cast(i + 1)); + } +} \ No newline at end of file diff --git a/tests/gocase/unit/type/cms/cms_test.go b/tests/gocase/unit/type/cms/cms_test.go new file mode 100644 index 00000000000..1af467a74da --- /dev/null +++ b/tests/gocase/unit/type/cms/cms_test.go @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package cms + +import ( + "context" + "sync" + "testing" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/stretchr/testify/require" +) + +func TestCMSCommands(t *testing.T) { + configOptions := []util.ConfigOptions{ + { + Name: "txn-context-enabled", + Options: []string{"yes", "no"}, + ConfigType: util.YesNo, + }, + } + + configsMatrix, err := util.GenerateConfigsMatrix(configOptions) + require.NoError(t, err) + + for _, configs := range configsMatrix { + testCMS(t, configs) + } +} + +func testCMS(t *testing.T, configs util.KvrocksServerConfigs) { + srv := util.StartServer(t, configs) + defer srv.Close() + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + key := "test_cms_key" + + t.Run("InitByDim basic test", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", key, "1000", "5").Err()) + + // Check info + info := rdb.Do(ctx, "cms.info", key).Val() + infoSlice := info.([]interface{}) + require.Equal(t, "width", infoSlice[0]) + require.Equal(t, int64(1000), infoSlice[1]) + require.Equal(t, "depth", infoSlice[2]) + require.Equal(t, int64(5), infoSlice[3]) + }) + + t.Run("InitByDim duplicate key", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", key, "1000", "5").Err()) + require.ErrorContains(t, rdb.Do(ctx, "cms.initbydim", key, "100", "3").Err(), "key already exists") + }) + + t.Run("InitByDim invalid arguments", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, key).Err()) + require.Error(t, rdb.Do(ctx, "cms.initbydim", key, "0", "5").Err()) + require.Error(t, rdb.Do(ctx, "cms.initbydim", key, "1000", "0").Err()) + }) + + t.Run("InitByProb basic test", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, key).Err()) + // Per Redis 
documentation, 'probability' is the failure probability (probability of inflated count) + // For a 1% failure rate, set probability = 0.01 + require.NoError(t, rdb.Do(ctx, "cms.initbyprob", key, "0.01", "0.01").Err()) + + // Check info - width and depth are calculated from error rate and probability + // RedisBloom formula: + // width = ceil(2 / error_rate) = ceil(2 / 0.01) = 200 + // depth = ceil(log10(probability) / log10(0.5)) = ceil(log10(0.01) / log10(0.5)) = ceil(6.64) = 7 + info := rdb.Do(ctx, "cms.info", key).Val() + infoSlice := info.([]interface{}) + require.Equal(t, "width", infoSlice[0]) + width := infoSlice[1].(int64) + require.Equal(t, int64(200), width) + depth := infoSlice[3].(int64) + require.Equal(t, int64(7), depth) + }) + + t.Run("InitByProb invalid arguments", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, key).Err()) + require.Error(t, rdb.Do(ctx, "cms.initbyprob", key, "0", "0.99").Err()) + require.Error(t, rdb.Do(ctx, "cms.initbyprob", key, "1", "0.99").Err()) + require.Error(t, rdb.Do(ctx, "cms.initbyprob", key, "0.01", "0").Err()) + require.Error(t, rdb.Do(ctx, "cms.initbyprob", key, "0.01", "1").Err()) + }) + + t.Run("Lazy initialization - query returns 0 without explicit init", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", key, "1000", "5").Err()) + + // Query without any IncrBy - all buckets are missing but should return 0 + queryResult := rdb.Do(ctx, "cms.query", key, "nonexistent_item").Val() + querySlice := queryResult.([]interface{}) + require.Equal(t, int64(0), querySlice[0]) + + // Query multiple items + queryResult = rdb.Do(ctx, "cms.query", key, "item1", "item2", "item3").Val() + querySlice = queryResult.([]interface{}) + require.Equal(t, int64(0), querySlice[0]) + require.Equal(t, int64(0), querySlice[1]) + require.Equal(t, int64(0), querySlice[2]) + }) + + t.Run("Lazy initialization - incrby works after lazy init", func(t *testing.T) { + 
require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", key, "1000", "5").Err()) + + // Increment items + result := rdb.Do(ctx, "cms.incrby", key, "foo", "10", "bar", "20").Val() + resultSlice := result.([]interface{}) + require.Equal(t, int64(10), resultSlice[0]) + require.Equal(t, int64(20), resultSlice[1]) + + // Query items + queryResult := rdb.Do(ctx, "cms.query", key, "foo", "bar", "baz").Val() + querySlice := queryResult.([]interface{}) + require.Equal(t, int64(10), querySlice[0]) + require.Equal(t, int64(20), querySlice[1]) + require.Equal(t, int64(0), querySlice[2]) // baz was never incremented + }) + + t.Run("IncrBy multiple times", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", key, "1000", "5").Err()) + + require.NoError(t, rdb.Do(ctx, "cms.incrby", key, "item1", "100").Err()) + require.NoError(t, rdb.Do(ctx, "cms.incrby", key, "item1", "50").Err()) + + queryResult := rdb.Do(ctx, "cms.query", key, "item1").Val() + querySlice := queryResult.([]interface{}) + require.Equal(t, int64(150), querySlice[0]) + }) + + t.Run("IncrBy non-existent key", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "nonexistent").Err()) + require.Error(t, rdb.Do(ctx, "cms.incrby", "nonexistent", "foo", "10").Err()) + }) + + t.Run("Query non-existent key", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "nonexistent").Err()) + require.Error(t, rdb.Do(ctx, "cms.query", "nonexistent", "foo").Err()) + }) + + t.Run("Info returns correct total count", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", key, "1000", "5").Err()) + require.NoError(t, rdb.Do(ctx, "cms.incrby", key, "a", "10", "b", "20", "c", "30").Err()) + + info := rdb.Do(ctx, "cms.info", key).Val() + infoSlice := info.([]interface{}) + // Find count in the result + for i := 0; i < len(infoSlice); i += 2 { + if infoSlice[i] == "count" { + 
require.Equal(t, int64(60), infoSlice[i+1]) // 10 + 20 + 30 + break + } + } + }) + + t.Run("Merge basic test", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "cms1", "cms2", "cms_merge").Err()) + + // Create three CMS with same dimensions (destination must be initialized) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms1", "1000", "5").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms2", "1000", "5").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms_merge", "1000", "5").Err()) + + // Add data + require.NoError(t, rdb.Do(ctx, "cms.incrby", "cms1", "item1", "100").Err()) + require.NoError(t, rdb.Do(ctx, "cms.incrby", "cms2", "item1", "50").Err()) + + // Merge + require.NoError(t, rdb.Do(ctx, "cms.merge", "cms_merge", "2", "cms1", "cms2").Err()) + + // Query merged result + queryResult := rdb.Do(ctx, "cms.query", "cms_merge", "item1").Val() + querySlice := queryResult.([]interface{}) + require.Equal(t, int64(150), querySlice[0]) // 100 + 50 + }) + + t.Run("Merge with weights", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "cms3", "cms4", "cms_merge_w").Err()) + + // Create three CMS with same dimensions (destination must be initialized) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms3", "1000", "5").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms4", "1000", "5").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms_merge_w", "1000", "5").Err()) + + // Add data + require.NoError(t, rdb.Do(ctx, "cms.incrby", "cms3", "item1", "100").Err()) + require.NoError(t, rdb.Do(ctx, "cms.incrby", "cms4", "item1", "50").Err()) + + // Merge with weights + require.NoError(t, rdb.Do(ctx, "cms.merge", "cms_merge_w", "2", "cms3", "cms4", "weights", "2", "3").Err()) + + // Query merged result + queryResult := rdb.Do(ctx, "cms.query", "cms_merge_w", "item1").Val() + querySlice := queryResult.([]interface{}) + require.Equal(t, int64(350), querySlice[0]) // 100*2 + 50*3 = 350 + }) + + t.Run("Merge invalid 
dimensions", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "cms5", "cms6", "cms_merge_err").Err()) + + // Create two CMS with different dimensions + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms5", "1000", "5").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms6", "2000", "5").Err()) + + // Merge should fail + require.Error(t, rdb.Do(ctx, "cms.merge", "cms_merge_err", "2", "cms5", "cms6").Err()) + }) + + t.Run("Merge destination must exist", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "cms_src", "cms_dest_not_exist").Err()) + + // Create source CMS + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms_src", "1000", "5").Err()) + + // Merge to non-existent destination should fail + require.ErrorContains(t, rdb.Do(ctx, "cms.merge", "cms_dest_not_exist", "1", "cms_src").Err(), "not found") + }) + + t.Run("Merge with negative weights", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "cms_neg1", "cms_neg2", "cms_neg_dest").Err()) + + // Create three CMS with same dimensions + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms_neg1", "1000", "5").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms_neg2", "1000", "5").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "cms_neg_dest", "1000", "5").Err()) + + // Add data + require.NoError(t, rdb.Do(ctx, "cms.incrby", "cms_neg1", "item1", "100").Err()) + require.NoError(t, rdb.Do(ctx, "cms.incrby", "cms_neg2", "item1", "50").Err()) + + // Merge with negative weight (100 - 50 = 50) + require.NoError(t, rdb.Do(ctx, "cms.merge", "cms_neg_dest", "2", "cms_neg1", "cms_neg2", "weights", "1", "-1").Err()) + + // Query merged result + queryResult := rdb.Do(ctx, "cms.query", "cms_neg_dest", "item1").Val() + querySlice := queryResult.([]interface{}) + require.Equal(t, int64(50), querySlice[0]) // 100 - 50 = 50 + }) + + t.Run("Info non-existent key", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "nonexistent").Err()) + require.Error(t, rdb.Do(ctx, "cms.info", 
"nonexistent").Err()) + }) + + t.Run("CMS overestimates but never underestimates", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", key, "1000", "10").Err()) + + // Add items with known counts + for i := 0; i < 100; i++ { + require.NoError(t, rdb.Do(ctx, "cms.incrby", key, "item"+string(rune('0'+i%10)), "1").Err()) + } + + // Query - CMS should return count >= actual (never underestimate) + for i := 0; i < 10; i++ { + queryResult := rdb.Do(ctx, "cms.query", key, "item"+string(rune('0'+i))).Val() + querySlice := queryResult.([]interface{}) + require.GreaterOrEqual(t, querySlice[0].(int64), int64(10)) // Each item was incremented 10 times + } + }) + + t.Run("Type command returns cms", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", key, "1000", "5").Err()) + require.Equal(t, "cms", rdb.Type(ctx, key).Val()) + }) + + // Concurrent tests - Stress tests for race conditions + t.Run("Stress Concurrent INCRBY with independent connections", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "stress_cms_indep").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "stress_cms_indep", "10", "2").Err()) + + numGoroutines := 100 + incrementsPerGoroutine := 500 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + // Each goroutine creates its own client connection + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + // Create independent connection for this goroutine + client := srv.NewClient() + defer client.Close() + + for j := 0; j < incrementsPerGoroutine; j++ { + _ = client.Do(ctx, "cms.incrby", "stress_cms_indep", "hot_item", "1").Err() + } + }() + } + + wg.Wait() + + // Verify final count + queryResult := rdb.Do(ctx, "cms.query", "stress_cms_indep", "hot_item").Val() + querySlice := queryResult.([]interface{}) + expected := int64(numGoroutines * incrementsPerGoroutine) + actual := querySlice[0].(int64) + + if 
actual != expected { + t.Errorf("RACE DETECTED: Expected %d, got %d (lost %d updates = %.2f%%)", + expected, actual, expected-actual, float64(expected-actual)/float64(expected)*100) + } + require.Equal(t, expected, actual, "Race condition: lost updates") + }) + + t.Run("Stress Concurrent INCRBY multiple keys", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "stress_multi_cms").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "stress_multi_cms", "20", "3").Err()) + + numGoroutines := 200 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + // Each goroutine increments a random item + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < 100; j++ { + item := "item" + string(rune('0'+(j%5))) + _ = rdb.Do(ctx, "cms.incrby", "stress_multi_cms", item, "1").Err() + } + }(i) + } + + wg.Wait() + + // Verify total count + info := rdb.Do(ctx, "cms.info", "stress_multi_cms").Val() + infoSlice := info.([]interface{}) + var totalCount int64 + for i := 0; i < len(infoSlice); i += 2 { + if infoSlice[i] == "count" { + totalCount = infoSlice[i+1].(int64) + break + } + } + expected := int64(numGoroutines * 100) + require.Equal(t, expected, totalCount, "Total count mismatch") + }) + + t.Run("Stress Concurrent Read/Write", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "stress_rw_cms").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "stress_rw_cms", "50", "3").Err()) + + numWriters := 50 + numReaders := 100 + writesPerWriter := 200 + + var wg sync.WaitGroup + wg.Add(numWriters + numReaders) + + // Writers - aggressive writes + for i := 0; i < numWriters; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < writesPerWriter; j++ { + item := "item" + string(rune('0'+(j%10))) + _ = rdb.Do(ctx, "cms.incrby", "stress_rw_cms", item, "1").Err() + } + }(i) + } + + // Readers - aggressive reads + for i := 0; i < numReaders; i++ { + go func() { + defer wg.Done() + for j := 0; j < 500; j++ { + _ = rdb.Do(ctx, 
"cms.query", "stress_rw_cms", "item0").Err() + } + }() + } + + wg.Wait() + + // Verify consistency + info := rdb.Do(ctx, "cms.info", "stress_rw_cms").Val() + require.NotNil(t, info) + }) + + t.Run("Stress Concurrent INCRBY and MERGE", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "stress_merge_src", "stress_merge_dest").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "stress_merge_src", "30", "2").Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", "stress_merge_dest", "30", "2").Err()) + + numWriters := 30 + numMergers := 20 + var wg sync.WaitGroup + wg.Add(numWriters + numMergers) + + // Writers to source + for i := 0; i < numWriters; i++ { + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + _ = rdb.Do(ctx, "cms.incrby", "stress_merge_src", "item1", "1").Err() + } + }() + } + + // Concurrent merges + for i := 0; i < numMergers; i++ { + go func() { + defer wg.Done() + _ = rdb.Do(ctx, "cms.merge", "stress_merge_dest", "1", "stress_merge_src").Err() + }() + } + + wg.Wait() + + // Check no corruption + info := rdb.Do(ctx, "cms.info", "stress_merge_dest").Val() + require.NotNil(t, info) + }) + + t.Run("Atomic verification - repeated runs", func(t *testing.T) { + // Run multiple times to catch intermittent races + for run := 0; run < 5; run++ { + key := "atomic_test_" + string(rune('0'+run)) + require.NoError(t, rdb.Del(ctx, key).Err()) + require.NoError(t, rdb.Do(ctx, "cms.initbydim", key, "5", "2").Err()) + + numGoroutines := 50 + incrementsPerGoroutine := 100 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < incrementsPerGoroutine; j++ { + _ = rdb.Do(ctx, "cms.incrby", key, "x", "1").Err() + } + }() + } + + wg.Wait() + + queryResult := rdb.Do(ctx, "cms.query", key, "x").Val() + querySlice := queryResult.([]interface{}) + expected := int64(numGoroutines * incrementsPerGoroutine) + actual := querySlice[0].(int64) + + if actual != expected { + 
t.Errorf("Run %d: RACE! Expected %d, got %d", run, expected, actual) + } + require.Equal(t, expected, actual, "Run %d failed", run) + } + }) +}