diff --git a/.github/workflows/kvrocks.yaml b/.github/workflows/kvrocks.yaml index fda6d3c209e..049dcabf45c 100644 --- a/.github/workflows/kvrocks.yaml +++ b/.github/workflows/kvrocks.yaml @@ -326,8 +326,8 @@ jobs: curl -O https://download.redis.io/releases/redis-6.2.14.tar.gz tar -xzvf redis-6.2.14.tar.gz mkdir -p $HOME/local/bin - pushd redis-6.2.14 && BUILD_TLS=yes make -j$NPROC redis-cli && mv src/redis-cli $HOME/local/bin/ && popd - pushd redis-6.2.14 && BUILD_TLS=yes make -j$NPROC redis-server && mv src/redis-server $HOME/local/bin/ && popd + pushd redis-6.2.14 && USE_JEMALLOC=no BUILD_TLS=yes make -j$NPROC redis-cli && mv src/redis-cli $HOME/local/bin/ && popd + pushd redis-6.2.14 && USE_JEMALLOC=no BUILD_TLS=yes make -j$NPROC redis-server && mv src/redis-server $HOME/local/bin/ && popd - uses: actions/checkout@v6 with: diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index d5f0d75ca60..3b6d2022f72 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -40,6 +40,17 @@ namespace redis { +namespace { + +bool IsNamespaceCommandDisabled(Server *srv) { return srv->GetConfig()->redis_databases > 0; } + +bool IsNamespaceReadOnlyOnSlave(Server *srv) { + Config *config = srv->GetConfig(); + return config->repl_namespace_enabled && config->IsSlave(); +} + +} // namespace + class CommandAuth : public Commander { public: Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { @@ -65,63 +76,113 @@ class CommandAuth : public Commander { }; class CommandNamespace : public Commander { + public: + Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, [[maybe_unused]] Connection *conn, + [[maybe_unused]] std::string *output) override { + if (IsNamespaceCommandDisabled(srv)) { + return {Status::RedisExecErr, "namespace command is not allowed when redis-databases > 0"}; + } + + return {Status::RedisExecErr, "NAMESPACE subcommand must be one of GET, SET, DEL, ADD and CURRENT"}; + } +}; + +class CommandNamespaceGet : public Commander { public: Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { Config *config = srv->GetConfig(); - std::string sub_command = util::ToLower(args_[1]); - if (config->repl_namespace_enabled && config->IsSlave() && sub_command != "get") { - return {Status::RedisExecErr, "namespace is read-only for slave"}; - } - if (config->redis_databases > 0) { + if (IsNamespaceCommandDisabled(srv)) { return {Status::RedisExecErr, "namespace command is not allowed when redis-databases > 0"}; } - if (args_.size() == 3 && sub_command == "get") { - if (args_[2] == "*") { - std::vector namespaces; - auto tokens = srv->GetNamespace()->List(); - for (auto &token : tokens) { - namespaces.emplace_back(token.second); // namespace - namespaces.emplace_back(token.first); // token - } - namespaces.emplace_back(kDefaultNamespace); - namespaces.emplace_back(config->requirepass); - *output = ArrayOfBulkStrings(namespaces); - } else { - auto token = srv->GetNamespace()->Get(args_[2]); - if (token.Is()) { - *output = conn->NilString(); - } else { - *output = redis::BulkString(token.GetValue()); - } + + if (args_[2] == "*") { + std::vector namespaces; + auto tokens = srv->GetNamespace()->List(); + for (auto &token : tokens) { + namespaces.emplace_back(token.second); + namespaces.emplace_back(token.first); } - } else if (args_.size() == 4 && sub_command == "set") { - Status s = srv->GetNamespace()->Set(args_[2], args_[3]); - *output = s.IsOK() ? redis::RESP_OK : redis::Error(s); - WARN("Updated namespace: {} with token: {}, addr: {}, result: {}", args_[2], args_[3], conn->GetAddr(), s.Msg()); - } else if (args_.size() == 4 && sub_command == "add") { - Status s = srv->GetNamespace()->Add(args_[2], args_[3]); - *output = s.IsOK() ? redis::RESP_OK : redis::Error(s); - WARN("New namespace: {} with token: {}, addr: {}, result: {}", args_[2], args_[3], conn->GetAddr(), s.Msg()); - } else if (args_.size() == 3 && sub_command == "del") { - Status s = srv->GetNamespace()->Del(args_[2]); - *output = s.IsOK() ? redis::RESP_OK : redis::Error(s); - WARN("Deleted namespace: {}, addr: {}, result: {}", args_[2], conn->GetAddr(), s.Msg()); - } else if (args_.size() == 2 && sub_command == "current") { - *output = redis::BulkString(conn->GetNamespace()); + namespaces.emplace_back(kDefaultNamespace); + namespaces.emplace_back(config->requirepass); + *output = ArrayOfBulkStrings(namespaces); + return Status::OK(); + } + + auto token = srv->GetNamespace()->Get(args_[2]); + if (token.Is()) { + *output = conn->NilString(); } else { - return {Status::RedisExecErr, "NAMESPACE subcommand must be one of GET, SET, DEL, ADD and CURRENT"}; + *output = redis::BulkString(token.GetValue()); } return Status::OK(); } }; -static uint64_t GenerateNamespaceFlag(uint64_t flags, const std::vector &args) { - if (args.size() >= 2 && util::EqualICase(args[1], "current")) { - return flags & ~kCmdAdmin; +class CommandNamespaceSet : public Commander { + public: + Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + if (IsNamespaceReadOnlyOnSlave(srv)) { + return {Status::RedisExecErr, "namespace is read-only for slave"}; + } + if (IsNamespaceCommandDisabled(srv)) { + return {Status::RedisExecErr, "namespace command is not allowed when redis-databases > 0"}; + } + + auto s = srv->GetNamespace()->Set(args_[2], args_[3]); + *output = s.IsOK() ? redis::RESP_OK : redis::Error(s); + WARN("Updated namespace: {} with token: {}, addr: {}, result: {}", args_[2], args_[3], conn->GetAddr(), s.Msg()); + return Status::OK(); } +}; - return flags; -} +class CommandNamespaceAdd : public Commander { + public: + Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + if (IsNamespaceReadOnlyOnSlave(srv)) { + return {Status::RedisExecErr, "namespace is read-only for slave"}; + } + if (IsNamespaceCommandDisabled(srv)) { + return {Status::RedisExecErr, "namespace command is not allowed when redis-databases > 0"}; + } + + auto s = srv->GetNamespace()->Add(args_[2], args_[3]); + *output = s.IsOK() ? redis::RESP_OK : redis::Error(s); + WARN("New namespace: {} with token: {}, addr: {}, result: {}", args_[2], args_[3], conn->GetAddr(), s.Msg()); + return Status::OK(); + } +}; + +class CommandNamespaceDel : public Commander { + public: + Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + if (IsNamespaceReadOnlyOnSlave(srv)) { + return {Status::RedisExecErr, "namespace is read-only for slave"}; + } + if (IsNamespaceCommandDisabled(srv)) { + return {Status::RedisExecErr, "namespace command is not allowed when redis-databases > 0"}; + } + + auto s = srv->GetNamespace()->Del(args_[2]); + *output = s.IsOK() ? redis::RESP_OK : redis::Error(s); + WARN("Deleted namespace: {}, addr: {}, result: {}", args_[2], conn->GetAddr(), s.Msg()); + return Status::OK(); + } +}; + +class CommandNamespaceCurrent : public Commander { + public: + Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + if (IsNamespaceReadOnlyOnSlave(srv)) { + return {Status::RedisExecErr, "namespace is read-only for slave"}; + } + if (IsNamespaceCommandDisabled(srv)) { + return {Status::RedisExecErr, "namespace command is not allowed when redis-databases > 0"}; + } + + *output = redis::BulkString(conn->GetNamespace()); + return Status::OK(); + } +}; class CommandKeys : public Commander { public: @@ -829,13 +890,13 @@ class CommandCommand : public Commander { } else if (sub_command == "info") { CommandTable::GetCommandsInfo(output, std::vector(args_.begin() + 2, args_.end())); } else if (sub_command == "getkeys") { - auto cmd_iter = CommandTable::GetOriginal()->find(util::ToLower(args_[2])); - if (cmd_iter == CommandTable::GetOriginal()->end()) { + std::vector cmd_tokens(args_.begin() + 2, args_.end()); + auto resolved = CommandTable::Resolve(cmd_tokens); + if (!resolved) { return {Status::RedisUnknownCmd, "Invalid command specified"}; } - auto key_indexes = GET_OR_RET(CommandTable::GetKeysFromCommand( - cmd_iter->second, std::vector(args_.begin() + 2, args_.end()))); + auto key_indexes = GET_OR_RET(CommandTable::GetKeysFromCommand(resolved->attributes, cmd_tokens)); if (key_indexes.size() == 0) { return {Status::RedisExecErr, "Invalid arguments specified for command"}; @@ -844,7 +905,7 @@ class CommandCommand : public Commander { std::vector keys; keys.reserve(key_indexes.size()); for (const auto &key_index : key_indexes) { - keys.emplace_back(args_[key_index + 2]); + keys.emplace_back(cmd_tokens[key_index]); } *output = conn->MultiBulkString(keys); } else { @@ -1634,7 +1695,7 @@ REDIS_REGISTER_COMMANDS( MakeCmdAttr("info", -1, "read-only ok-loading", NO_KEY), MakeCmdAttr("role", 1, "read-only ok-loading", NO_KEY), MakeCmdAttr("config", -2, "read-only admin skip-monitor", NO_KEY, GenerateConfigFlag), - MakeCmdAttr("namespace", -2, "read-only admin skip-monitor", NO_KEY, GenerateNamespaceFlag), + MakeCmdAttr("namespace", -2, "read-only admin skip-monitor", NO_KEY), MakeCmdAttr("keys", 2, "read-only slow", NO_KEY), MakeCmdAttr("flushdb", 1, "write no-dbsize-check exclusive", NO_KEY), MakeCmdAttr("flushall", 1, "write no-dbsize-check exclusive admin", NO_KEY), @@ -1670,5 +1731,11 @@ REDIS_REGISTER_COMMANDS( MakeCmdAttr("pollupdates", -2, "read-only admin", NO_KEY), MakeCmdAttr("sst", -3, "write exclusive admin", 1, 1, 1), MakeCmdAttr("flushmemtable", -1, "exclusive write", NO_KEY), - MakeCmdAttr("flushblockcache", 1, "exclusive write", NO_KEY), ) + MakeCmdAttr("flushblockcache", 1, "exclusive write", NO_KEY), + MakeSubCmdAttr("namespace", "get", 3, "read-only admin skip-monitor", NO_KEY), + MakeSubCmdAttr("namespace", "set", 4, "read-only admin skip-monitor", NO_KEY), + MakeSubCmdAttr("namespace", "add", 4, "read-only admin skip-monitor", NO_KEY), + MakeSubCmdAttr("namespace", "del", 3, "read-only admin skip-monitor", NO_KEY), + MakeSubCmdAttr("namespace", "current", 2, "read-only skip-monitor", NO_KEY)) + } // namespace redis diff --git a/src/commands/commander.cc b/src/commands/commander.cc index d820d3fb851..f05eeb7db76 100644 --- a/src/commands/commander.cc +++ b/src/commands/commander.cc @@ -20,11 +20,76 @@ #include "commander.h" +#include + #include "cluster/cluster_defs.h" #include "server/redis_reply.h" namespace redis { +bool CommandTable::isSubcommandName(const std::string &name) { return name.find('|') != std::string::npos; } + +std::pair CommandTable::parseSubcommandName(const std::string &name) { + auto delimiter = name.find('|'); + if (delimiter == std::string::npos || delimiter == 0 || delimiter + 1 >= name.size()) { + std::cout << fmt::format("Encountered invalid subcommand name '{}'", name) << std::endl; + std::abort(); + } + + auto normalized_parent = util::ToLower(name.substr(0, delimiter)); + auto normalized_sub = util::ToLower(name.substr(delimiter + 1)); + if (normalized_sub.empty()) { + std::cout << fmt::format("Encountered invalid subcommand name '{}'", name) << std::endl; + std::abort(); + } + + return {normalized_parent, normalized_sub}; +} + +const CommandAttributes *CommandTable::registerCommand(CommandAttributes attr, CommandCategory category) { + if (original_commands.contains(attr.name) || commands.contains(attr.name)) { + std::cout << fmt::format("Duplicate command registration for '{}'", attr.name) << std::endl; + std::abort(); + } + + attr.category = category; + redis_command_table.emplace_back(std::move(attr)); + auto *registered_attr = &redis_command_table.back(); + original_commands[registered_attr->name] = registered_attr; + commands[registered_attr->name] = registered_attr; + return registered_attr; +} + +const CommandAttributes *CommandTable::registerSubCommand(CommandAttributes attr, CommandCategory category) { + auto [parent, sub] = parseSubcommandName(attr.name); + auto &subcommand_family = sub_commands[parent]; + if (subcommand_family.contains(sub)) { + std::cout << fmt::format("Duplicate subcommand registration for '{}|{}'", parent, sub) << std::endl; + std::abort(); + } + + attr.category = category; + attr.name = fmt::format("{}|{}", parent, sub); + redis_subcommand_table.emplace_back(std::move(attr)); + auto *registered_attr = &redis_subcommand_table.back(); + subcommand_family[sub] = registered_attr; + return registered_attr; +} + +const CommandAttributes *CommandTable::findSubCommand(const std::string &parent, const std::string &sub) { + auto family_iter = sub_commands.find(util::ToLower(parent)); + if (family_iter == sub_commands.end()) { + return nullptr; + } + + auto subcommand_iter = family_iter->second.find(util::ToLower(sub)); + if (subcommand_iter == family_iter->second.end()) { + return nullptr; + } + + return subcommand_iter->second; +} + RegisterToCommandTable::RegisterToCommandTable(CommandCategory category, std::initializer_list list) { if (category == CommandCategory::Disabled) { @@ -32,10 +97,11 @@ RegisterToCommandTable::RegisterToCommandTable(CommandCategory category, } for (auto attr : list) { - attr.category = category; - CommandTable::redis_command_table.emplace_back(attr); - CommandTable::original_commands[attr.name] = &CommandTable::redis_command_table.back(); - CommandTable::commands[attr.name] = &CommandTable::redis_command_table.back(); + if (CommandTable::isSubcommandName(attr.name)) { + CommandTable::registerSubCommand(std::move(attr), category); + continue; + } + CommandTable::registerCommand(std::move(attr), category); } } @@ -83,6 +149,32 @@ void CommandTable::GetCommandsInfo(std::string *info, const std::vector CommandTable::Resolve(const std::vector &cmd_tokens) { + if (cmd_tokens.empty()) { + return {Status::RedisUnknownCmd}; + } + + auto cmd_iter = commands.find(util::ToLower(cmd_tokens.front())); + if (cmd_iter == commands.end()) { + return {Status::RedisUnknownCmd}; + } + + const auto *root_attributes = cmd_iter->second; + ResolvedCommand resolved{root_attributes->name, root_attributes}; + + if (cmd_tokens.size() <= 1) { + return resolved; + } + + auto subcommand_attributes = findSubCommand(root_attributes->name, cmd_tokens[1]); + if (subcommand_attributes == nullptr) { + return resolved; + } + + resolved.attributes = subcommand_attributes; + return resolved; +} + StatusOr> CommandTable::GetKeysFromCommand(const CommandAttributes *attributes, const std::vector &cmd_tokens) { int argc = static_cast(cmd_tokens.size()); @@ -92,7 +184,9 @@ StatusOr> CommandTable::GetKeysFromCommand(const CommandAttribu } auto cmd = attributes->factory(); - if (auto s = cmd->Parse(cmd_tokens); !s) { + cmd->SetAttributes(attributes); + cmd->SetArgs(cmd_tokens); + if (auto s = cmd->Parse(); !s) { return {Status::NotOK, "Invalid syntax found in this command arguments: " + s.Msg()}; } diff --git a/src/commands/commander.h b/src/commands/commander.h index 3f38db02580..7b5e15ae91c 100644 --- a/src/commands/commander.h +++ b/src/commands/commander.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -57,6 +58,11 @@ namespace redis { class Connection; struct CommandAttributes; +struct ResolvedCommand { + std::string root; + const CommandAttributes *attributes = nullptr; +}; + enum CommandFlags : uint64_t { // "write" flag, for any command that performs rocksdb writing ops kCmdWrite = 1ULL << 0, @@ -125,6 +131,8 @@ class Commander { public: void SetAttributes(const CommandAttributes *attributes) { attributes_ = attributes; } const CommandAttributes *GetAttributes() const { return attributes_; } + void SetRootName(std::string root_name) { root_name_ = std::move(root_name); } + const std::string &GetRootName() const { return root_name_; } void SetArgs(const std::vector &args) { args_ = args; } virtual Status Parse() { return Parse(args_); } virtual Status Parse([[maybe_unused]] const std::vector &args) { return Status::OK(); } @@ -138,6 +146,7 @@ class Commander { protected: std::vector args_; const CommandAttributes *attributes_ = nullptr; + std::string root_name_; }; class CommanderWithParseMove : Commander { @@ -342,6 +351,7 @@ struct CommandAttributes { }; using CommandMap = std::map; +using SubCommandMap = std::map; inline uint64_t ParseCommandFlags(const std::string &description, const std::string &cmd_name) { uint64_t flags = 0; @@ -414,15 +424,25 @@ auto MakeCmdAttr(const std::string &name, int arity, const std::string &descript } template -auto MakeCmdAttr(const std::string &name, int arity, const std::string &description, int first_key, int last_key, - int key_step = 1, const AdditionalFlagGen &flag_gen = {}) { - CommandAttributes attr(name, arity, CommandCategory::Unknown, ParseCommandFlags(description, name), flag_gen, - {first_key, last_key, key_step}, - []() -> std::unique_ptr { return std::unique_ptr(new T()); }); +auto MakeCmdAttr(const std::string &name, int arity, const std::string &description, CommandKeyRange key_range, + const AdditionalFlagGen &flag_gen = {}) { + CommandAttributes attr{name, + arity, + CommandCategory::Unknown, + ParseCommandFlags(description, name), + flag_gen, + key_range, + []() -> std::unique_ptr { return std::unique_ptr(new T()); }}; return attr; } +template +auto MakeCmdAttr(const std::string &name, int arity, const std::string &description, int first_key, int last_key, + int key_step = 1, const AdditionalFlagGen &flag_gen = {}) { + return MakeCmdAttr(name, arity, description, {first_key, last_key, key_step}, flag_gen); +} + template auto MakeCmdAttr(const std::string &name, int arity, const std::string &description, const CommandKeyRangeGen &gen, const AdditionalFlagGen &flag_gen = {}) { @@ -451,6 +471,40 @@ auto MakeCmdAttr(const std::string &name, int arity, const std::string &descript return attr; } +template +auto MakeSubCmdAttr(const std::string &parent, const std::string &sub, int arity, const std::string &description, + NoKeyInThisCommand no_key, const AdditionalFlagGen &flag_gen = {}) { + return MakeCmdAttr(fmt::format("{}|{}", util::ToLower(parent), util::ToLower(sub)), arity, description, no_key, + flag_gen); +} + +template +auto MakeSubCmdAttr(const std::string &parent, const std::string &sub, int arity, const std::string &description, + CommandKeyRange key_range, const AdditionalFlagGen &flag_gen = {}) { + return MakeCmdAttr(fmt::format("{}|{}", util::ToLower(parent), util::ToLower(sub)), arity, description, key_range, + flag_gen); +} + +template +auto MakeSubCmdAttr(const std::string &parent, const std::string &sub, int arity, const std::string &description, + int first_key, int last_key, int key_step = 1, const AdditionalFlagGen &flag_gen = {}) { + return MakeSubCmdAttr(parent, sub, arity, description, {first_key, last_key, key_step}, flag_gen); +} + +template +auto MakeSubCmdAttr(const std::string &parent, const std::string &sub, int arity, const std::string &description, + const CommandKeyRangeGen &gen, const AdditionalFlagGen &flag_gen = {}) { + return MakeCmdAttr(fmt::format("{}|{}", util::ToLower(parent), util::ToLower(sub)), arity, description, gen, + flag_gen); +} + +template +auto MakeSubCmdAttr(const std::string &parent, const std::string &sub, int arity, const std::string &description, + const CommandKeyRangeVecGen &vec_gen, const AdditionalFlagGen &flag_gen = {}) { + return MakeCmdAttr(fmt::format("{}|{}", util::ToLower(parent), util::ToLower(sub)), arity, description, vec_gen, + flag_gen); +} + struct RegisterToCommandTable { RegisterToCommandTable(CommandCategory category, std::initializer_list list); }; @@ -466,6 +520,7 @@ struct CommandTable { static void GetAllCommandsInfo(std::string *info); static void GetCommandsInfo(std::string *info, const std::vector &cmd_names); static std::string GetCommandInfo(const CommandAttributes *command_attributes); + static StatusOr Resolve(const std::vector &cmd_tokens); static StatusOr> GetKeysFromCommand(const CommandAttributes *attributes, const std::vector &cmd_tokens); @@ -475,7 +530,14 @@ struct CommandTable { static Status ParseSlotRanges(const std::string &slots_str, std::vector &slots); private: + static bool isSubcommandName(const std::string &name); + static std::pair parseSubcommandName(const std::string &name); + static const CommandAttributes *registerCommand(CommandAttributes attr, CommandCategory category); + static const CommandAttributes *registerSubCommand(CommandAttributes attr, CommandCategory category); + static const CommandAttributes *findSubCommand(const std::string &parent, const std::string &sub); + static inline std::deque redis_command_table; + static inline std::deque redis_subcommand_table; // Original Command table before rename-command directive static inline CommandMap original_commands; @@ -483,6 +545,9 @@ struct CommandTable { // Command table after rename-command directive static inline CommandMap commands; + // Subcommand table indexed by root command name and subcommand name. + static inline SubCommandMap sub_commands; + friend struct RegisterToCommandTable; }; diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index f9db92b0ce9..4552bffce2a 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -438,7 +438,7 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { if (is_multi_exec) multi_error_ = true; }); - auto cmd_s = Server::LookupAndCreateCommand(cmd_tokens.front()); + auto cmd_s = Server::LookupAndCreateCommand(cmd_tokens); if (!cmd_s.IsOK()) { auto cmd_name = cmd_tokens.front(); if (util::EqualICase(cmd_name, "host:") || util::EqualICase(cmd_name, "post")) { @@ -458,7 +458,7 @@ void Connection::ExecuteCommands(std::deque *to_process_cmds) { auto current_cmd = std::move(*cmd_s); const auto &attributes = current_cmd->GetAttributes(); - auto cmd_name = attributes->name; + const auto &cmd_name = current_cmd->GetRootName(); int tokens = static_cast(cmd_tokens.size()); if (!attributes->CheckArity(tokens)) { diff --git a/src/server/server.cc b/src/server/server.cc index fc763863bf0..2e2763b2bc8 100644 --- a/src/server/server.cc +++ b/src/server/server.cc @@ -1863,18 +1863,12 @@ ReplState Server::GetReplicationState() { return kReplConnecting; } -StatusOr> Server::LookupAndCreateCommand(const std::string &cmd_name) { - if (cmd_name.empty()) return {Status::RedisUnknownCmd}; +StatusOr> Server::LookupAndCreateCommand(const std::vector &cmd_tokens) { + auto resolved = GET_OR_RET(redis::CommandTable::Resolve(cmd_tokens)); - auto commands = redis::CommandTable::Get(); - auto cmd_iter = commands->find(util::ToLower(cmd_name)); - if (cmd_iter == commands->end()) { - return {Status::RedisUnknownCmd}; - } - - auto cmd_attr = cmd_iter->second; - auto cmd = cmd_attr->factory(); - cmd->SetAttributes(cmd_attr); + auto cmd = resolved.attributes->factory(); + cmd->SetAttributes(resolved.attributes); + cmd->SetRootName(std::move(resolved.root)); return std::move(cmd); } diff --git a/src/server/server.h b/src/server/server.h index 63b83e28306..8396b5cbbb2 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -200,7 +200,7 @@ class Server { bool IsStopped() const { return stop_; } bool IsLoading() const { return is_loading_; } Config *GetConfig() { return config_; } - static StatusOr> LookupAndCreateCommand(const std::string &cmd_name); + static StatusOr> LookupAndCreateCommand(const std::vector &cmd_tokens); void AdjustOpenFilesLimit(); void AdjustWorkerThreads(); diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index d1175b04ec8..5d525f3eeea 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -784,7 +784,7 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { } } - auto cmd_s = Server::LookupAndCreateCommand(args[0]); + auto cmd_s = Server::LookupAndCreateCommand(args); if (!cmd_s) { PushError(lua, "Unknown Redis command called from Lua script"); return raise_error ? RaiseError(lua) : 1; @@ -792,6 +792,7 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { auto cmd = *std::move(cmd_s); auto attributes = cmd->GetAttributes(); + const auto &cmd_name = cmd->GetRootName(); if (!attributes->CheckArity(argc)) { PushError(lua, "Wrong number of args while calling Redis command from Lua script"); return raise_error ? RaiseError(lua) : 1; @@ -813,7 +814,6 @@ int RedisGenericCommand(lua_State *lua, int raise_error) { return raise_error ? RaiseError(lua) : 1; } - std::string cmd_name = attributes->name; cmd->SetArgs(args); auto s = cmd->Parse(); if (!s) { diff --git a/tests/cppunit/subcommand_resolution_test.cc b/tests/cppunit/subcommand_resolution_test.cc new file mode 100644 index 00000000000..0cdcb1c096c --- /dev/null +++ b/tests/cppunit/subcommand_resolution_test.cc @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include + +#include "commands/commander.h" +#include "common/scope_exit.h" +namespace { + +using redis::CommandCategory; +using redis::MakeCmdAttr; +using redis::MakeSubCmdAttr; +using redis::NO_KEY; +using redis::RegisterToCommandTable; + +class TestCommandRoot : public redis::Commander {}; +class TestCommandSub : public redis::Commander {}; + +REDIS_REGISTER_COMMANDS(Server, MakeCmdAttr("subcommandkeyrangetest", -2, "read-only", NO_KEY), + MakeSubCmdAttr("subcommandkeyrangetest", "load", -3, "read-only", 2, -1, 1)) + +// root command resolve +TEST(SubcommandResolution, ResolveRootCommandWithoutSubcommand) { + auto resolved = redis::CommandTable::Resolve({"ping"}); + + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->root, "ping"); + ASSERT_NE(resolved->attributes, nullptr); + EXPECT_EQ(resolved->attributes->name, "ping"); +} + +// registered namespace subcommand +TEST(SubcommandResolution, ResolveRegisteredNamespaceSubcommand) { + auto resolved = redis::CommandTable::Resolve({"namespace", "add", "ns-1", "token-1"}); + + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->root, "namespace"); + ASSERT_NE(resolved->attributes, nullptr); + EXPECT_EQ(resolved->attributes->name, "namespace|add"); + EXPECT_EQ(resolved->attributes->arity, 4); +} + +// subcommand key range +TEST(SubcommandResolution, RegisteredSubcommandUsesSubcommandKeyRange) { + std::vector cmd_tokens = {"subcommandkeyrangetest", "load", "key-1", "key-2"}; + + auto resolved = redis::CommandTable::Resolve(cmd_tokens); + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->root, "subcommandkeyrangetest"); + ASSERT_NE(resolved->attributes, nullptr); + EXPECT_EQ(resolved->attributes->name, "subcommandkeyrangetest|load"); + + auto key_indexes = redis::CommandTable::GetKeysFromCommand(resolved->attributes, cmd_tokens); + ASSERT_TRUE(key_indexes); + EXPECT_EQ(*key_indexes, (std::vector{2, 3})); +} + +// missing namespace subcommand arity +TEST(SubcommandResolution, MissingNamespaceSubcommandRejectsArity) { + auto resolved = redis::CommandTable::Resolve({"namespace"}); + + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->root, "namespace"); + ASSERT_NE(resolved->attributes, nullptr); + EXPECT_EQ(resolved->attributes->name, "namespace"); + EXPECT_FALSE(resolved->attributes->CheckArity(1)); +} + +// renamed root command resolve +TEST(SubcommandResolution, ResolveRenamedRootCommandWithSubcommand) { + auto reset_guard = MakeScopeExit([] { redis::CommandTable::Reset(); }); + + auto *commands = redis::CommandTable::Get(); + auto command_iter = commands->find("namespace"); + ASSERT_NE(command_iter, commands->end()); + (*commands)["ns"] = command_iter->second; + commands->erase(command_iter); + + auto resolved = redis::CommandTable::Resolve({"ns", "add", "ns-1", "token-1"}); + + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->root, "namespace"); + ASSERT_NE(resolved->attributes, nullptr); + EXPECT_EQ(resolved->attributes->name, "namespace|add"); +} + +// renamed root command key range +TEST(SubcommandResolution, RenamedRootCommandUsesSubcommandKeyRange) { + auto reset_guard = MakeScopeExit([] { redis::CommandTable::Reset(); }); + + auto *commands = redis::CommandTable::Get(); + auto command_iter = commands->find("subcommandkeyrangetest"); + ASSERT_NE(command_iter, commands->end()); + (*commands)["renamedsubcommandkeyrangetest"] = command_iter->second; + commands->erase(command_iter); + + std::vector cmd_tokens = {"renamedsubcommandkeyrangetest", "load", "key-1", "key-2"}; + + auto resolved = redis::CommandTable::Resolve(cmd_tokens); + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->root, "subcommandkeyrangetest"); + ASSERT_NE(resolved->attributes, nullptr); + EXPECT_EQ(resolved->attributes->name, "subcommandkeyrangetest|load"); + + auto key_indexes = redis::CommandTable::GetKeysFromCommand(resolved->attributes, cmd_tokens); + ASSERT_TRUE(key_indexes); + EXPECT_EQ(*key_indexes, (std::vector{2, 3})); +} + +} // namespace diff --git a/tests/gocase/unit/command/command_test.go b/tests/gocase/unit/command/command_test.go index 9a065c28124..08a6e41db80 100644 --- a/tests/gocase/unit/command/command_test.go +++ b/tests/gocase/unit/command/command_test.go @@ -433,3 +433,22 @@ func TestCommand(t *testing.T) { } }) } + +// renamed root command GETKEYS +func TestCommandGetKeysWithRenamedCommand(t *testing.T) { + srv := util.StartServer(t, map[string]string{ + "rename-command MGET": "RENAMED_MGET", + }) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "RENAMED_MGET", "k1", "k2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "k1", vs[0]) + require.Equal(t, "k2", vs[1]) +} diff --git a/tests/gocase/unit/namespace/namespace_test.go b/tests/gocase/unit/namespace/namespace_test.go index 663bcbe9e77..aff609c69d7 100644 --- a/tests/gocase/unit/namespace/namespace_test.go +++ b/tests/gocase/unit/namespace/namespace_test.go @@ -180,6 +180,40 @@ func TestNamespace(t *testing.T) { }) } +// unknown namespace subcommands +func TestCommandNamespaceSubcommands(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + assertInvalidNamespaceSubcommand := func(t *testing.T, args ...string) { + t.Helper() + + commandArgs := make([]interface{}, len(args)) + for i, arg := range args { + commandArgs[i] = arg + } + + err := rdb.Do(ctx, commandArgs...).Err() + require.Error(t, err) + require.Contains(t, err.Error(), "NAMESPACE subcommand must be one of") + for _, subcommand := range []string{"ADD", "CURRENT", "DEL", "GET", "SET"} { + require.Contains(t, err.Error(), subcommand) + } + } + + // legacy invalid subcommand error + t.Run("NAMESPACE keeps legacy invalid subcommand error", func(t *testing.T) { + assertInvalidNamespaceSubcommand(t, "NAMESPACE", "MISSING") + assertInvalidNamespaceSubcommand(t, "NAMESPACE", "MISSING", "arg1") + + }) + +} + func TestNamespaceReplicate(t *testing.T) { password := "pwd" masterSrv := util.StartServer(t, map[string]string{