|
| 1 | +/** |
| 2 | + * mcpd — Role-Based Access Control (RBAC) for tools |
| 3 | + * |
| 4 | + * Associates API keys with roles, and restricts tool access by role. |
| 5 | + * When enabled, tool calls are checked against the caller's role. |
| 6 | + * Unauthenticated callers get the "guest" role (configurable). |
| 7 | + * |
| 8 | + * Usage: |
| 9 | + * auto& ac = server.accessControl(); |
| 10 | + * ac.addRole("admin"); |
| 11 | + * ac.addRole("viewer"); |
| 12 | + * ac.mapKeyToRole("secret-admin-key", "admin"); |
| 13 | + * ac.mapKeyToRole("read-only-key", "viewer"); |
| 14 | + * ac.restrictTool("gpio_write", {"admin"}); |
| 15 | + * ac.restrictTool("gpio_read", {"admin", "viewer"}); |
| 16 | + * // Tools without restrictions are accessible to all roles |
| 17 | + */ |
| 18 | + |
| 19 | +#ifndef MCPD_ACCESS_CONTROL_H |
| 20 | +#define MCPD_ACCESS_CONTROL_H |
| 21 | + |
| 22 | +#include <Arduino.h> |
| 23 | +#include <functional> |
| 24 | +#include <map> |
| 25 | +#include <set> |
| 26 | +#include <vector> |
| 27 | +#include <string> |
| 28 | + |
| 29 | +namespace mcpd { |
| 30 | + |
| 31 | +class AccessControl { |
| 32 | +public: |
| 33 | + AccessControl() = default; |
| 34 | + |
| 35 | + // ── Role Management ────────────────────────────────────────────── |
| 36 | + |
| 37 | + /** Add a role definition. */ |
| 38 | + void addRole(const char* role) { |
| 39 | + _roles.insert(String(role)); |
| 40 | + } |
| 41 | + |
| 42 | + /** Remove a role definition and all its associations. */ |
| 43 | + void removeRole(const char* role) { |
| 44 | + String r(role); |
| 45 | + _roles.erase(r); |
| 46 | + // Remove key→role mappings for this role |
| 47 | + for (auto it = _keyToRole.begin(); it != _keyToRole.end(); ) { |
| 48 | + if (it->second == r) { |
| 49 | + it = _keyToRole.erase(it); |
| 50 | + } else { |
| 51 | + ++it; |
| 52 | + } |
| 53 | + } |
| 54 | + // Remove from tool restrictions |
| 55 | + for (auto& pair : _toolRoles) { |
| 56 | + pair.second.erase(r); |
| 57 | + } |
| 58 | + } |
| 59 | + |
| 60 | + /** Check if a role exists. */ |
| 61 | + bool hasRole(const char* role) const { |
| 62 | + return _roles.count(String(role)) > 0; |
| 63 | + } |
| 64 | + |
| 65 | + /** Get all defined roles. */ |
| 66 | + std::set<String> roles() const { return _roles; } |
| 67 | + |
| 68 | + // ── Key-to-Role Mapping ────────────────────────────────────────── |
| 69 | + |
| 70 | + /** Map an API key to a role. A key can only have one role. */ |
| 71 | + void mapKeyToRole(const char* apiKey, const char* role) { |
| 72 | + _keyToRole[String(apiKey)] = String(role); |
| 73 | + _roles.insert(String(role)); // Auto-add role if not yet defined |
| 74 | + } |
| 75 | + |
| 76 | + /** Remove a key mapping. */ |
| 77 | + void unmapKey(const char* apiKey) { |
| 78 | + _keyToRole.erase(String(apiKey)); |
| 79 | + } |
| 80 | + |
| 81 | + /** Get the role for a key, or empty string if not mapped. */ |
| 82 | + String roleForKey(const char* apiKey) const { |
| 83 | + auto it = _keyToRole.find(String(apiKey)); |
| 84 | + if (it != _keyToRole.end()) return it->second; |
| 85 | + return String(); |
| 86 | + } |
| 87 | + |
| 88 | + // ── Tool Restrictions ──────────────────────────────────────────── |
| 89 | + |
| 90 | + /** |
| 91 | + * Restrict a tool to specific roles. |
| 92 | + * Only callers with one of the listed roles can call this tool. |
| 93 | + * An empty set means the tool is restricted to no one (effectively disabled). |
| 94 | + */ |
| 95 | + void restrictTool(const char* toolName, const std::vector<const char*>& allowedRoles) { |
| 96 | + std::set<String> roleSet; |
| 97 | + for (auto r : allowedRoles) roleSet.insert(String(r)); |
| 98 | + _toolRoles[String(toolName)] = roleSet; |
| 99 | + } |
| 100 | + |
| 101 | + /** Overload accepting String set. */ |
| 102 | + void restrictToolSet(const char* toolName, const std::set<String>& allowedRoles) { |
| 103 | + _toolRoles[String(toolName)] = allowedRoles; |
| 104 | + } |
| 105 | + |
| 106 | + /** Remove restrictions from a tool (makes it accessible to all). */ |
| 107 | + void unrestrictTool(const char* toolName) { |
| 108 | + _toolRoles.erase(String(toolName)); |
| 109 | + } |
| 110 | + |
| 111 | + /** Check if a tool has restrictions. */ |
| 112 | + bool isToolRestricted(const char* toolName) const { |
| 113 | + return _toolRoles.count(String(toolName)) > 0; |
| 114 | + } |
| 115 | + |
| 116 | + /** Get allowed roles for a tool (empty set = unrestricted). */ |
| 117 | + std::set<String> toolAllowedRoles(const char* toolName) const { |
| 118 | + auto it = _toolRoles.find(String(toolName)); |
| 119 | + if (it != _toolRoles.end()) return it->second; |
| 120 | + return {}; |
| 121 | + } |
| 122 | + |
| 123 | + // ── Access Checks ──────────────────────────────────────────────── |
| 124 | + |
| 125 | + /** Set the default role for unauthenticated/unmapped callers. */ |
| 126 | + void setDefaultRole(const char* role) { |
| 127 | + _defaultRole = String(role); |
| 128 | + _roles.insert(String(role)); |
| 129 | + } |
| 130 | + |
| 131 | + /** Get the default role. */ |
| 132 | + String defaultRole() const { return _defaultRole; } |
| 133 | + |
| 134 | + /** |
| 135 | + * Check if a caller with the given API key can access a tool. |
| 136 | + * - If the tool has no restrictions, returns true. |
| 137 | + * - If the tool is restricted, checks the key's role (or default role). |
| 138 | + * - If RBAC is not enabled, always returns true. |
| 139 | + */ |
| 140 | + bool canAccess(const char* toolName, const char* apiKey = nullptr) const { |
| 141 | + if (!_enabled) return true; |
| 142 | + |
| 143 | + // Tool not restricted → allow |
| 144 | + auto toolIt = _toolRoles.find(String(toolName)); |
| 145 | + if (toolIt == _toolRoles.end()) return true; |
| 146 | + |
| 147 | + // Determine caller's role |
| 148 | + String callerRole = _defaultRole; |
| 149 | + if (apiKey && apiKey[0] != '\0') { |
| 150 | + auto keyIt = _keyToRole.find(String(apiKey)); |
| 151 | + if (keyIt != _keyToRole.end()) { |
| 152 | + callerRole = keyIt->second; |
| 153 | + } |
| 154 | + } |
| 155 | + |
| 156 | + // Check if caller's role is in allowed set |
| 157 | + if (callerRole.isEmpty()) return false; |
| 158 | + return toolIt->second.count(callerRole) > 0; |
| 159 | + } |
| 160 | + |
| 161 | + /** Enable/disable RBAC checking. When disabled, canAccess always returns true. */ |
| 162 | + void enable(bool v = true) { _enabled = v; } |
| 163 | + void disable() { _enabled = false; } |
| 164 | + bool isEnabled() const { return _enabled; } |
| 165 | + |
| 166 | + // ── Bulk Operations ────────────────────────────────────────────── |
| 167 | + |
| 168 | + /** |
| 169 | + * Restrict all tools with a given annotation to specific roles. |
| 170 | + * Useful for restricting all destructive tools to admin only. |
| 171 | + */ |
| 172 | + void restrictDestructiveTools(const std::vector<const char*>& toolNames, |
| 173 | + const std::vector<const char*>& allowedRoles) { |
| 174 | + for (auto name : toolNames) { |
| 175 | + restrictTool(name, allowedRoles); |
| 176 | + } |
| 177 | + } |
| 178 | + |
| 179 | + /** |
| 180 | + * Get the list of tools accessible to a given role. |
| 181 | + * Returns tool names that are either unrestricted or explicitly allow this role. |
| 182 | + */ |
| 183 | + std::vector<String> toolsForRole(const char* role, const std::vector<String>& allTools) const { |
| 184 | + String r(role); |
| 185 | + std::vector<String> result; |
| 186 | + for (const auto& tool : allTools) { |
| 187 | + auto it = _toolRoles.find(tool); |
| 188 | + if (it == _toolRoles.end()) { |
| 189 | + // Unrestricted |
| 190 | + result.push_back(tool); |
| 191 | + } else if (it->second.count(r) > 0) { |
| 192 | + result.push_back(tool); |
| 193 | + } |
| 194 | + } |
| 195 | + return result; |
| 196 | + } |
| 197 | + |
| 198 | + // ── JSON Serialization ─────────────────────────────────────────── |
| 199 | + |
| 200 | + /** Serialize RBAC config to JSON string. */ |
| 201 | + String toJSON() const { |
| 202 | + String json = "{\"enabled\":"; |
| 203 | + json += _enabled ? "true" : "false"; |
| 204 | + json += ",\"defaultRole\":\""; |
| 205 | + json += _defaultRole; |
| 206 | + json += "\",\"roles\":["; |
| 207 | + bool first = true; |
| 208 | + for (const auto& r : _roles) { |
| 209 | + if (!first) json += ","; |
| 210 | + json += "\"" + r + "\""; |
| 211 | + first = false; |
| 212 | + } |
| 213 | + json += "],\"toolRestrictions\":{"; |
| 214 | + first = true; |
| 215 | + for (const auto& pair : _toolRoles) { |
| 216 | + if (!first) json += ","; |
| 217 | + json += "\"" + pair.first + "\":["; |
| 218 | + bool f2 = true; |
| 219 | + for (const auto& r : pair.second) { |
| 220 | + if (!f2) json += ","; |
| 221 | + json += "\"" + r + "\""; |
| 222 | + f2 = false; |
| 223 | + } |
| 224 | + json += "]"; |
| 225 | + first = false; |
| 226 | + } |
| 227 | + json += "},\"keyMappings\":"; |
| 228 | + json += String((int)_keyToRole.size()); |
| 229 | + json += "}"; |
| 230 | + return json; |
| 231 | + } |
| 232 | + |
| 233 | + /** Get stats as JSON. */ |
| 234 | + String statsJSON() const { |
| 235 | + String json = "{\"enabled\":"; |
| 236 | + json += _enabled ? "true" : "false"; |
| 237 | + json += ",\"roles\":"; |
| 238 | + json += String((int)_roles.size()); |
| 239 | + json += ",\"keyMappings\":"; |
| 240 | + json += String((int)_keyToRole.size()); |
| 241 | + json += ",\"restrictedTools\":"; |
| 242 | + json += String((int)_toolRoles.size()); |
| 243 | + json += "}"; |
| 244 | + return json; |
| 245 | + } |
| 246 | + |
| 247 | +private: |
| 248 | + bool _enabled = false; |
| 249 | + String _defaultRole = "guest"; |
| 250 | + std::set<String> _roles; |
| 251 | + std::map<String, String> _keyToRole; // API key → role |
| 252 | + std::map<String, std::set<String>> _toolRoles; // tool name → allowed roles |
| 253 | +}; |
| 254 | + |
| 255 | +} // namespace mcpd |
| 256 | + |
| 257 | +#endif // MCPD_ACCESS_CONTROL_H |
0 commit comments