Skip to content

Commit 2f30580

Browse files
authored
Merge pull request #181 from vincenzopalazzo/claude/infallible-davinci-0b72bb
Scope MCP tool calls to the authenticated tenant
2 parents 9225e24 + 8c2d056 commit 2f30580

3 files changed

Lines changed: 271 additions & 9 deletions

File tree

.github/workflows/ci.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ jobs:
5454
npm ci
5555
5656
- name: Run Verification Test
57+
# `tests/test_omnibus.ts` was ported to vitest as
58+
# `tests/omnibus.test.ts` in cbc8c84; this step has been failing
59+
# ever since. Run the whole vitest suite so omnibus + the rest of
60+
# the per-tenant / temporal / multilingual specs are exercised.
5761
run: |
5862
cd packages/openmemory-js
59-
npx tsx tests/test_omnibus.ts
63+
npm test

packages/openmemory-js/src/ai/mcp.ts

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,42 @@ const send_err = (
8181

8282
/**
 * Normalise an optional identifier.
 *
 * @param val - Raw identifier; may be missing, null, or padded with whitespace.
 * @returns The trimmed identifier, or `undefined` when nothing but
 *          whitespace (or nothing at all) was supplied.
 */
const uid = (val?: string | null) => {
    const trimmed = val?.trim();
    return trimmed ? trimmed : undefined;
};
8383

84-
export const create_mcp_srv = () => {
84+
/**
 * Decide which user_id a tool call operates on.
 *
 * When a tenant is bound to the server (the HTTP MCP route passes the
 * tenant taken from the authenticated request), that tenant always wins:
 *
 *   tenant bound, user_id omitted/blank   -> tenant
 *   tenant bound, user_id equals tenant   -> tenant
 *   tenant bound, user_id differs         -> throws `tenant_mismatch`
 *   no tenant (e.g. stdio transport)      -> the trimmed user_id, if any
 *
 * With no tenant there is nothing to enforce against, so the caller's
 * value is passed through unchanged apart from trimming; a blank value
 * becomes `undefined` and downstream defaults apply.
 *
 * @param tenant - Tenant bound to this server instance, if any. An empty
 *                 string is treated the same as "no tenant".
 * @param arg    - The `user_id` argument supplied by the tool caller.
 * @returns The effective user_id, or `undefined` when neither a tenant
 *          nor a usable argument is available.
 * @throws Error with a `tenant_mismatch:` message when a tenant is bound
 *         and the caller names a different user_id.
 */
const resolve_user_id = (
    tenant: string | undefined,
    arg: string | null | undefined,
): string | undefined => {
    const requested = uid(arg);

    // Unbound server: nothing to enforce, hand back what the caller sent.
    if (!tenant) return requested;

    // Bound server: the caller may stay silent or repeat the tenant, but
    // may not name anybody else.
    const conflicts = requested !== undefined && requested !== tenant;
    if (conflicts) {
        throw new Error(
            "tenant_mismatch: user_id does not match authenticated tenant; omit user_id or pass the tenant identifier",
        );
    }
    return tenant;
};
118+
119+
export const create_mcp_srv = (tenant?: string) => {
85120
const srv = new McpServer(
86121
{
87122
name: "openmemory-mcp",
@@ -182,7 +217,7 @@ export const create_mcp_srv = () => {
182217
user_id,
183218
project_id,
184219
}) => {
185-
const u = uid(user_id);
220+
const u = resolve_user_id(tenant, user_id);
186221
const proj = uid(project_id);
187222
const results: any = { type, query };
188223
const at_date = at ? new Date(at) : new Date();
@@ -368,7 +403,7 @@ export const create_mcp_srv = () => {
368403
metadata,
369404
user_id,
370405
}) => {
371-
const u = uid(user_id);
406+
const u = resolve_user_id(tenant, user_id);
372407
const proj = uid(project_id);
373408
const results: any = { type };
374409

@@ -485,7 +520,7 @@ export const create_mcp_srv = () => {
485520
metadata,
486521
user_id,
487522
}) => {
488-
const u = uid(user_id);
523+
const u = resolve_user_id(tenant, user_id);
489524
// Force global scope for this tool
490525
const proj = "system_global";
491526
const results: any = { type };
@@ -571,6 +606,15 @@ export const create_mcp_srv = () => {
571606
.describe("Salience boost amount (default 0.1)"),
572607
},
573608
async ({ id, boost }) => {
609+
if (tenant) {
610+
// When HTTP-bound, refuse to reinforce another tenant's memory.
611+
const mem = await q.get_mem.get(id);
612+
if (!mem || mem.user_id !== tenant) {
613+
throw new Error(
614+
`Memory ${id} not found for user ${tenant}`,
615+
);
616+
}
617+
}
574618
await reinforce_memory(id, boost);
575619
return {
576620
content: [
@@ -602,7 +646,7 @@ export const create_mcp_srv = () => {
602646
.describe("Validate project identifier"),
603647
},
604648
async ({ id, user_id, project_id }) => {
605-
const u = uid(user_id);
649+
const u = resolve_user_id(tenant, user_id);
606650
const proj = uid(project_id);
607651
if (u || proj) {
608652
// Pre-check ownership if user_id/project_id provided
@@ -675,7 +719,7 @@ export const create_mcp_srv = () => {
675719
.describe("Restrict results to a specific project identifier"),
676720
},
677721
async ({ limit, sector, user_id, project_id }) => {
678-
const u = uid(user_id);
722+
const u = resolve_user_id(tenant, user_id);
679723
const proj = uid(project_id);
680724
let rows: mem_row[];
681725

@@ -750,7 +794,7 @@ export const create_mcp_srv = () => {
750794
),
751795
},
752796
async ({ id, include_vectors, user_id }) => {
753-
const u = uid(user_id);
797+
const u = resolve_user_id(tenant, user_id);
754798
const mem = await q.get_mem.get(id);
755799
if (!mem)
756800
return {
@@ -875,7 +919,20 @@ export const mcp = (app: any) => {
875919
// Create a fresh transport + server per request to support
876920
// multiple clients (MCP SDK 1.27 rejects re-initialization
877921
// on a single transport instance).
878-
const srv = create_mcp_srv();
922+
//
923+
// `req.tenant` is set by the global `authenticate_api_request`
924+
// middleware (src/server/index.ts). Threading it into the
925+
// per-request server is what scopes MCP tool calls to the
926+
// authenticated tenant — without this, tools either wrote
927+
// memories with user_id="anonymous" (invisible to REST
928+
// `/memory/all` which is tenant-scoped) or read across
929+
// every tenant. See resolve_user_id() for the per-tool
930+
// contract.
931+
const tenant_from_req =
932+
typeof (req as any).tenant === "string"
933+
? ((req as any).tenant as string)
934+
: undefined;
935+
const srv = create_mcp_srv(tenant_from_req);
879936
const trans = new StreamableHTTPServerTransport({
880937
sessionIdGenerator: undefined,
881938
enableJsonResponse: true,
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
// Force synthetic embeddings + sqlite backend BEFORE importing anything
2+
// that loads cfg/db. vitest.config.ts already sets these via env, but keep
3+
// this guard for standalone tsx runs.
4+
process.env.OM_EMBEDDINGS = "synthetic";
5+
process.env.OM_EMBEDDING_FALLBACK = "synthetic";
6+
process.env.OM_METADATA_BACKEND = process.env.OM_METADATA_BACKEND || "sqlite";
7+
process.env.OM_VECTOR_BACKEND = process.env.OM_VECTOR_BACKEND || "sqlite";
8+
9+
import { beforeEach, describe, expect, it } from "vitest";
10+
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
11+
import { InMemoryTransport } from "@modelcontextprotocol/sdk/inMemory.js";
12+
import { create_mcp_srv } from "../src/ai/mcp";
13+
import { run_async, q } from "../src/core/db";
14+
15+
const T_ALICE = "tenant-alice-mcp";
16+
const T_BOB = "tenant-bob-mcp";
17+
18+
/**
 * Empty the tables this spec writes to so each test starts from a clean
 * database.
 *
 * Clearing `memories` must succeed (a failure there propagates). The
 * remaining tables are cleared on a best-effort basis because not every
 * schema variant has all of them; a failure on one is ignored and the
 * next is still attempted.
 */
async function cleanup() {
    await run_async(`DELETE FROM memories`);

    const optional_statements = [
        `DELETE FROM vectors`,
        `DELETE FROM openmemory_vectors`,
        `DELETE FROM waypoints`,
    ];
    for (const statement of optional_statements) {
        try {
            await run_async(statement);
        } catch {
            /* schema variant */
        }
    }
}
36+
37+
/**
 * Build an MCP server (optionally bound to `tenant`) and a test client,
 * joined by a linked in-memory transport pair.
 *
 * The server end is connected before the client end.
 *
 * @param tenant - Tenant to bind the server to; omit for an unbound server.
 * @returns The connected `client` together with the `srv` it talks to.
 */
async function connect_client(tenant?: string) {
    const srv = create_mcp_srv(tenant);
    const pair = InMemoryTransport.createLinkedPair();
    const client_end = pair[0];
    const server_end = pair[1];

    await srv.connect(server_end);

    const client = new Client({ name: "test-client", version: "0.0.0" });
    await client.connect(client_end);

    return { client, srv };
}
46+
47+
/**
 * Extract the `items` array from a tool result.
 *
 * Scans the result's content blocks in order for the first text block
 * whose text (ignoring leading whitespace) starts with `{`, parses it as
 * JSON and returns its `items` field.
 *
 * @param result - Raw value returned by `client.callTool`.
 * @returns The parsed `items`, or an empty array when there is no
 *          JSON-looking text block or the payload has no `items`.
 */
function parse_items(result: any): Array<{ id: string; user_id?: string }> {
    const blocks: Array<{ type: string; text: string }> = result?.content ?? [];

    const json_block = blocks.find(
        (block) => block.type === "text" && block.text.trim().startsWith("{"),
    );
    if (json_block === undefined) return [];

    const payload = JSON.parse(json_block.text);
    return payload.items ?? [];
}
61+
62+
/**
 * Extract the stored memory's id and project from a tool result.
 *
 * Scans the result's content blocks in order for the first text block
 * whose text (ignoring leading whitespace) starts with `{`, parses it as
 * JSON and reads `hsg.id` and `project_id` from it.
 *
 * @param result - Raw value returned by `client.callTool`.
 * @returns `{ id, project_id }` (either may be undefined), or an empty
 *          object when no JSON-looking text block is present.
 */
function parse_store(result: any): { id?: string; project_id?: string } {
    const blocks: Array<{ type: string; text: string }> = result?.content ?? [];

    const json_block = blocks.find(
        (block) => block.type === "text" && block.text.trim().startsWith("{"),
    );
    if (json_block === undefined) return {};

    const payload = JSON.parse(json_block.text);
    return { id: payload?.hsg?.id, project_id: payload?.project_id };
}
75+
76+
// End-to-end checks, over an in-memory MCP transport, that a server bound
// to a tenant scopes store/list tool calls to that tenant, and that an
// unbound server keeps its previous behaviour.
describe("MCP per-tenant scoping", () => {
    /** Call `openmemory_store` with the given arguments. */
    const store = (client: Client, args: Record<string, unknown>) =>
        client.callTool({ name: "openmemory_store", arguments: args });

    /** Call `openmemory_list` (limit 50) and return the parsed items. */
    const list_items = async (client: Client) => {
        const listed = await client.callTool({
            name: "openmemory_list",
            arguments: { limit: 50 },
        });
        return parse_items(listed);
    };

    /** True when every listed item carries `owner` as its user_id. */
    const all_owned_by = (
        items: Array<{ id: string; user_id?: string }>,
        owner: string,
    ) => items.every((item) => item.user_id === owner);

    // Every test starts from empty tables.
    beforeEach(async () => {
        await cleanup();
    });

    it("openmemory_store binds writes to the authenticated tenant", async () => {
        const { client } = await connect_client(T_ALICE);

        const stored = await store(client, {
            content:
                "Nginx 502 on a fresh VM: check that the upstream service is actually running before looking at nginx config.",
            tags: ["nginx", "sysadmin"],
        });
        const { id } = parse_store(stored);
        expect(id).toBeTruthy();

        // No user_id was passed, so the persisted row must have picked up
        // the tenant the server was created with.
        const row = await q.get_mem.get(id!);
        expect(row).toBeTruthy();
        expect(row.user_id).toBe(T_ALICE);
        expect(row.project_id).toBe("system_global");
    });

    it("openmemory_list returns the tenant's own MCP-stored memories (regression)", async () => {
        // A memory written through openmemory_store must show up in
        // openmemory_list on the same tenant-bound session.
        const { client } = await connect_client(T_ALICE);

        await store(client, {
            content:
                "Nginx 502 on a fresh VM: check the upstream service is running before touching nginx config.",
            tags: ["nginx"],
        });

        const items = await list_items(client);
        expect(items.length).toBeGreaterThan(0);
        expect(all_owned_by(items, T_ALICE)).toBe(true);
    });

    it("openmemory_list isolates tenants from each other", async () => {
        const alice = await connect_client(T_ALICE);
        const bob = await connect_client(T_BOB);

        await store(alice.client, {
            content: "Alice's private dev notes about nginx.",
        });
        await store(bob.client, {
            content: "Bob's private dev notes about postgres.",
        });

        // Bob sees exactly his one memory and none of Alice's.
        const bob_list = await list_items(bob.client);
        expect(all_owned_by(bob_list, T_BOB)).toBe(true);
        expect(bob_list.length).toBe(1);

        // And the same holds the other way round.
        const alice_list = await list_items(alice.client);
        expect(all_owned_by(alice_list, T_ALICE)).toBe(true);
        expect(alice_list.length).toBe(1);
    });

    it("openmemory_store rejects a user_id arg that disagrees with the tenant", async () => {
        const { client } = await connect_client(T_ALICE);

        const result: any = await store(client, {
            content: "attempt to forge another tenant's identity",
            user_id: T_BOB,
        });

        // The thrown error is expected to come back as an isError tool
        // result whose text carries the error message.
        expect(result.isError).toBe(true);
        const pieces: string[] = (result.content ?? []).map(
            (block: any) => block.text ?? "",
        );
        expect(pieces.join("\n")).toMatch(/tenant_mismatch/);
    });

    it("stdio-style server (no tenant) preserves legacy behaviour", async () => {
        // Server created without a tenant: the stored row is expected to
        // carry the "anonymous" user_id and listing is not tenant-filtered.
        const { client } = await connect_client(undefined);

        const stored = await store(client, {
            content: "stdio-mode memory with no tenant binding",
        });
        const { id } = parse_store(stored);
        expect(id).toBeTruthy();

        const row = await q.get_mem.get(id!);
        expect(row.user_id).toBe("anonymous");

        const items = await list_items(client);
        expect(items.length).toBe(1);
        expect(items[0].id).toBe(id);
    });
});

0 commit comments

Comments
 (0)