Skip to content

Commit ce091bf

Browse files
committed
feat(cms): add Count-Min Sketch data structure
Implement CMS (Count-Min Sketch) probabilistic data structure with the following commands: - CMS.INITBYDIM key width depth - CMS.INITBYPROB key error probability - CMS.INCRBY key item increment [...] - CMS.QUERY key item [...] - CMS.MERGE dest numkeys src... [WEIGHTS ...] - CMS.INFO key Storage design: - Metadata stored in Metadata CF - Count Matrix stored in PrimarySubkey CF (per-bucket storage) - Uses MurmurHash64 for layer-specific hashing Related: #2425
1 parent 96dd8cb commit ce091bf

6 files changed

Lines changed: 886 additions & 2 deletions

File tree

src/commands/cmd_cms.cc

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*
19+
*/
20+
21+
#include <types/redis_cms.h>
22+
23+
#include "commander.h"
24+
#include "commands/command_parser.h"
25+
#include "server/redis_reply.h"
26+
#include "server/server.h"
27+
28+
namespace redis {
29+
30+
/// CMS.INITBYDIM key width depth
31+
/// Initialize a Count-Min Sketch with specified dimensions.
32+
/// Complexity: O(width * depth) to initialize all buckets.
33+
class CommandCMSInitByDim final : public Commander {
34+
public:
35+
Status Parse(const std::vector<std::string> &args) override {
36+
auto parse_width = ParseInt<uint32_t>(args[2], 10);
37+
if (!parse_width) {
38+
return {Status::RedisParseErr, "invalid width"};
39+
}
40+
width_ = *parse_width;
41+
42+
auto parse_depth = ParseInt<uint32_t>(args[3], 10);
43+
if (!parse_depth) {
44+
return {Status::RedisParseErr, "invalid depth"};
45+
}
46+
depth_ = *parse_depth;
47+
48+
return Commander::Parse(args);
49+
}
50+
51+
Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
52+
redis::CMS cms(srv->storage, conn->GetNamespace());
53+
54+
auto s = cms.InitByDim(ctx, args_[1], width_, depth_);
55+
if (!s.ok()) return {Status::RedisExecErr, s.ToString()};
56+
57+
*output = redis::RESP_OK;
58+
return Status::OK();
59+
}
60+
61+
private:
62+
uint32_t width_;
63+
uint32_t depth_;
64+
};
65+
66+
/// CMS.INITBYPROB key error probability
67+
/// Initialize a Count-Min Sketch with specified error rate and probability.
68+
/// Complexity: O(width * depth) where width/depth are calculated from error/probability.
69+
class CommandCMSInitByProb final : public Commander {
70+
public:
71+
Status Parse(const std::vector<std::string> &args) override {
72+
auto parse_error = ParseFloat<double>(args[2]);
73+
if (!parse_error) {
74+
return {Status::RedisParseErr, "invalid error rate"};
75+
}
76+
error_rate_ = *parse_error;
77+
78+
auto parse_prob = ParseFloat<double>(args[3]);
79+
if (!parse_prob) {
80+
return {Status::RedisParseErr, "invalid probability"};
81+
}
82+
probability_ = *parse_prob;
83+
84+
return Commander::Parse(args);
85+
}
86+
87+
Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
88+
redis::CMS cms(srv->storage, conn->GetNamespace());
89+
90+
auto s = cms.InitByProb(ctx, args_[1], error_rate_, probability_);
91+
if (!s.ok()) return {Status::RedisExecErr, s.ToString()};
92+
93+
*output = redis::RESP_OK;
94+
return Status::OK();
95+
}
96+
97+
private:
98+
double error_rate_;
99+
double probability_;
100+
};
101+
102+
/// CMS.INCRBY key item increment [item increment ...]
103+
/// Increment the count of one or more items.
104+
/// Complexity: O(depth) for each item.
105+
/// Returns: Array of estimated counts for each item after increment.
106+
class CommandCMSIncrBy final : public Commander {
107+
public:
108+
Status Parse(const std::vector<std::string> &args) override {
109+
if (args.size() < 4 || (args.size() - 2) % 2 != 0) {
110+
return {Status::RedisParseErr, "wrong number of arguments"};
111+
}
112+
113+
for (size_t i = 2; i < args.size(); i += 2) {
114+
auto parse_increment = ParseInt<int64_t>(args[i + 1], 10);
115+
if (!parse_increment) {
116+
return {Status::RedisParseErr, "invalid increment"};
117+
}
118+
if (*parse_increment < 0) {
119+
return {Status::RedisParseErr, "increment must be non-negative"};
120+
}
121+
items_.emplace_back(args[i], *parse_increment);
122+
}
123+
124+
return Commander::Parse(args);
125+
}
126+
127+
Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
128+
redis::CMS cms(srv->storage, conn->GetNamespace());
129+
std::vector<uint64_t> counts;
130+
131+
auto s = cms.IncrBy(ctx, args_[1], items_, &counts);
132+
if (!s.ok()) return {Status::RedisExecErr, s.ToString()};
133+
134+
*output = redis::MultiLen(counts.size());
135+
for (auto count : counts) {
136+
*output += redis::Integer(count);
137+
}
138+
return Status::OK();
139+
}
140+
141+
private:
142+
std::vector<std::pair<std::string, int64_t>> items_;
143+
};
144+
145+
/// CMS.QUERY key item [item ...]
146+
/// Return the estimated count of one or more items.
147+
/// Complexity: O(depth) for each item.
148+
/// Returns: Array of estimated counts for each item.
149+
class CommandCMSQuery final : public Commander {
150+
public:
151+
Status Parse(const std::vector<std::string> &args) override {
152+
items_.reserve(args.size() - 2);
153+
for (size_t i = 2; i < args.size(); ++i) {
154+
items_.push_back(args[i]);
155+
}
156+
return Commander::Parse(args);
157+
}
158+
159+
Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
160+
redis::CMS cms(srv->storage, conn->GetNamespace());
161+
std::vector<uint64_t> counts;
162+
163+
auto s = cms.Query(ctx, args_[1], items_, &counts);
164+
if (!s.ok()) return {Status::RedisExecErr, s.ToString()};
165+
166+
*output = redis::MultiLen(counts.size());
167+
for (auto count : counts) {
168+
*output += redis::Integer(count);
169+
}
170+
return Status::OK();
171+
}
172+
173+
private:
174+
std::vector<std::string> items_;
175+
};
176+
177+
/// CMS.MERGE destkey numkeys sourcekey [sourcekey ...] [WEIGHTS weight [weight ...]]
178+
/// Merge multiple Count-Min Sketches into one.
179+
/// Complexity: O(width * depth * numkeys).
180+
///
181+
/// Kvrocks extension:
182+
/// The underlying implementation supports SUM/MAX/MIN merge methods via CMSMergeMethod enum.
183+
/// Currently only SUM is exposed for Redis compatibility. Future versions may add:
184+
/// [METHOD SUM|MAX|MIN] parameter to allow specifying merge strategy.
185+
/// This makes Kvrocks a superset of Redis CMS functionality.
186+
class CommandCMSMerge final : public Commander {
187+
public:
188+
Status Parse(const std::vector<std::string> &args) override {
189+
auto parse_numkeys = ParseInt<size_t>(args[2], 10);
190+
if (!parse_numkeys) {
191+
return {Status::RedisParseErr, "invalid numkeys"};
192+
}
193+
numkeys_ = *parse_numkeys;
194+
195+
if (args.size() < 3 + numkeys_) {
196+
return {Status::RedisParseErr, "wrong number of arguments"};
197+
}
198+
199+
// Parse source keys
200+
for (size_t i = 0; i < numkeys_; ++i) {
201+
src_keys_.push_back(args[3 + i]);
202+
}
203+
204+
// Parse optional WEIGHTS
205+
size_t next_arg = 3 + numkeys_;
206+
if (next_arg < args.size() && strcasecmp(args[next_arg].c_str(), "WEIGHTS") == 0) {
207+
next_arg++;
208+
if (args.size() < next_arg + numkeys_) {
209+
return {Status::RedisParseErr, "wrong number of weights"};
210+
}
211+
for (size_t i = 0; i < numkeys_; ++i) {
212+
auto parse_weight = ParseInt<uint64_t>(args[next_arg + i], 10);
213+
if (!parse_weight) {
214+
return {Status::RedisParseErr, "invalid weight"};
215+
}
216+
weights_.push_back(*parse_weight);
217+
}
218+
}
219+
220+
return Commander::Parse(args);
221+
}
222+
223+
Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
224+
redis::CMS cms(srv->storage, conn->GetNamespace());
225+
226+
auto s = cms.Merge(ctx, args_[1], src_keys_, weights_, CMSMergeMethod::SUM);
227+
if (!s.ok()) return {Status::RedisExecErr, s.ToString()};
228+
229+
*output = redis::RESP_OK;
230+
return Status::OK();
231+
}
232+
233+
private:
234+
size_t numkeys_;
235+
std::vector<std::string> src_keys_;
236+
std::vector<uint64_t> weights_;
237+
};
238+
239+
/// CMS.INFO key
240+
/// Return information about a Count-Min Sketch.
241+
/// Returns: Array of key-value pairs (width, depth, total count, size).
242+
class CommandCMSInfo final : public Commander {
243+
public:
244+
Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
245+
redis::CMS cms(srv->storage, conn->GetNamespace());
246+
CMSInfo info;
247+
248+
auto s = cms.Info(ctx, args_[1], &info);
249+
if (s.IsNotFound()) return {Status::RedisExecErr, "key not found"};
250+
if (!s.ok()) return {Status::RedisExecErr, s.ToString()};
251+
252+
*output = redis::MultiLen(8);
253+
*output += redis::SimpleString("width");
254+
*output += redis::Integer(info.width);
255+
*output += redis::SimpleString("depth");
256+
*output += redis::Integer(info.depth);
257+
*output += redis::SimpleString("count");
258+
*output += redis::Integer(info.total_count);
259+
*output += redis::SimpleString("size");
260+
*output += redis::Integer(info.size);
261+
return Status::OK();
262+
}
263+
};
264+
265+
REDIS_REGISTER_COMMANDS(CMS, MakeCmdAttr<CommandCMSInitByDim>("cms.initbydim", 4, "write", 1, 1, 1),
266+
MakeCmdAttr<CommandCMSInitByProb>("cms.initbyprob", 4, "write", 1, 1, 1),
267+
MakeCmdAttr<CommandCMSIncrBy>("cms.incrby", -4, "write", 1, 1, 1),
268+
MakeCmdAttr<CommandCMSQuery>("cms.query", -3, "read-only", 1, 1, 1),
269+
MakeCmdAttr<CommandCMSMerge>("cms.merge", -4, "write", 1, 1, 1),
270+
MakeCmdAttr<CommandCMSInfo>("cms.info", 2, "read-only", 1, 1, 1), )
271+
272+
} // namespace redis

src/commands/commander.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ enum class CommandCategory : uint8_t {
9595
Unknown = 0,
9696
Bit,
9797
BloomFilter,
98+
CMS,
9899
Cluster,
99100
Function,
100101
Geo,

src/storage/redis_metadata.cc

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ bool Metadata::IsSingleKVType() const { return Type() == kRedisString || Type()
334334

335335
bool Metadata::IsEmptyableType() const {
336336
return IsSingleKVType() || Type() == kRedisStream || Type() == kRedisBloomFilter || Type() == kRedisHyperLogLog ||
337-
Type() == kRedisTDigest || Type() == kRedisTimeSeries;
337+
Type() == kRedisTDigest || Type() == kRedisTimeSeries || Type() == kRedisCMS;
338338
}
339339

340340
bool Metadata::Expired() const { return ExpireAt(util::GetTimeStampMS()); }
@@ -569,3 +569,32 @@ rocksdb::Status TimeSeriesMetadata::Decode(Slice *input) {
569569

570570
return rocksdb::Status::OK();
571571
}
572+
573+
void CMSMetadata::Encode(std::string *dst) const {
574+
Metadata::Encode(dst);
575+
576+
PutFixed32(dst, width);
577+
PutFixed32(dst, depth);
578+
PutFixed64(dst, total_count);
579+
PutFixed8(dst, static_cast<uint8_t>(storage_mode));
580+
}
581+
582+
rocksdb::Status CMSMetadata::Decode(Slice *input) {
583+
if (auto s = Metadata::Decode(input); !s.ok()) {
584+
return s;
585+
}
586+
587+
if (input->size() < 4 + 4 + 8 + 1) {
588+
return rocksdb::Status::InvalidArgument(kErrMetadataTooShort);
589+
}
590+
591+
GetFixed32(input, &width);
592+
GetFixed32(input, &depth);
593+
GetFixed64(input, &total_count);
594+
595+
uint8_t mode = 0;
596+
GetFixed8(input, &mode);
597+
storage_mode = static_cast<StorageMode>(mode);
598+
599+
return rocksdb::Status::OK();
600+
}

src/storage/redis_metadata.h

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,13 @@ enum RedisType : uint8_t {
5454
kRedisHyperLogLog = 11,
5555
kRedisTDigest = 12,
5656
kRedisTimeSeries = 13,
57+
kRedisCMS = 14,
5758
kRedisTypeMax
5859
};
5960

6061
inline constexpr const std::array<std::string_view, kRedisTypeMax> RedisTypeNames = {
6162
"none", "string", "hash", "list", "set", "zset", "bitmap",
62-
"sortedint", "stream", "MBbloom--", "ReJSON-RL", "hyperloglog", "TDIS-TYPE", "timeseries"};
63+
"sortedint", "stream", "MBbloom--", "ReJSON-RL", "hyperloglog", "TDIS-TYPE", "timeseries", "cms"};
6364

6465
struct RedisTypes {
6566
RedisTypes(std::initializer_list<RedisType> list) {
@@ -409,3 +410,32 @@ class TimeSeriesMetadata : public Metadata {
409410
void Encode(std::string *dst) const override;
410411
rocksdb::Status Decode(Slice *input) override;
411412
};
413+
414+
class CMSMetadata : public Metadata {
415+
public:
416+
enum class StorageMode : uint8_t {
417+
PER_BUCKET = 0, // 按桶存储(默认)
418+
SINGLE_KEY = 1, // 单 Key 存储
419+
};
420+
421+
/// Width of the count matrix (number of buckets per layer)
422+
uint32_t width;
423+
424+
/// Depth of the count matrix (number of layers)
425+
uint32_t depth;
426+
427+
/// Total count of all INCRBY operations
428+
uint64_t total_count;
429+
430+
/// Storage mode
431+
StorageMode storage_mode;
432+
433+
explicit CMSMetadata(bool generate_version = true)
434+
: Metadata(kRedisCMS, generate_version), width(0), depth(0), total_count(0), storage_mode(StorageMode::PER_BUCKET) {}
435+
436+
CMSMetadata(uint32_t width, uint32_t depth, bool generate_version = true)
437+
: Metadata(kRedisCMS, generate_version), width(width), depth(depth), total_count(0), storage_mode(StorageMode::PER_BUCKET) {}
438+
439+
void Encode(std::string *dst) const override;
440+
rocksdb::Status Decode(Slice *input) override;
441+
};

0 commit comments

Comments
 (0)