Skip to content

Commit 9dbe7fb

Browse files
Merge pull request #45 from authzed/new_langchain_library
refactor(agentic-rag): Use langchain-spicedb SpiceDBAuthorizer for authorization
2 parents 40b92fc + b56e661 commit 9dbe7fb

6 files changed

Lines changed: 69 additions & 276 deletions

File tree

agentic-rag-authorization/.env.example

Lines changed: 0 additions & 17 deletions
This file was deleted.

agentic-rag-authorization/agentic_rag/authorization_helpers.py

Lines changed: 0 additions & 105 deletions
This file was deleted.
Lines changed: 8 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,129 +1,41 @@
1-
"""Helper functions for gRPC and SpiceDB authentication."""
1+
"""Helper functions for SpiceDB client creation."""
22

3-
import grpc
43
from threading import Lock
54
from typing import Optional
65

6+
from authzed.api.v1 import InsecureClient
77

8-
class BearerTokenInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor):
9-
"""
10-
gRPC interceptor that adds bearer token to all requests.
11-
12-
This is for local development with SpiceDB's --grpc-no-tls flag.
13-
"""
14-
15-
def __init__(self, token: str):
16-
self._token = token
17-
18-
def _add_authorization(self, client_call_details):
19-
"""Add authorization metadata to the call."""
20-
metadata = []
21-
if client_call_details.metadata is not None:
22-
metadata = list(client_call_details.metadata)
23-
metadata.append(("authorization", f"Bearer {self._token}"))
24-
25-
return grpc._interceptor._ClientCallDetails(
26-
client_call_details.method,
27-
client_call_details.timeout,
28-
metadata,
29-
client_call_details.credentials,
30-
client_call_details.wait_for_ready,
31-
client_call_details.compression,
32-
)
33-
34-
def intercept_unary_unary(self, continuation, client_call_details, request):
35-
"""Intercept unary-unary calls."""
36-
new_details = self._add_authorization(client_call_details)
37-
return continuation(new_details, request)
38-
39-
def intercept_unary_stream(self, continuation, client_call_details, request):
40-
"""Intercept unary-stream calls."""
41-
new_details = self._add_authorization(client_call_details)
42-
return continuation(new_details, request)
43-
44-
45-
# Global singleton for SpiceDB client with thread-safe initialization
46-
_spicedb_client: Optional["Client"] = None
8+
_spicedb_client: Optional[InsecureClient] = None
479
_spicedb_lock = Lock()
4810

4911

50-
def create_insecure_spicedb_client(endpoint: str, token: str):
12+
def create_insecure_spicedb_client(endpoint: str, token: str) -> InsecureClient:
5113
"""
5214
Create a SpiceDB client for insecure connections (local development).
5315
54-
This is for SpiceDB running with --grpc-no-tls flag.
55-
56-
Args:
57-
endpoint: The SpiceDB endpoint (e.g., "localhost:50051")
58-
token: The bearer token (e.g., "devtoken")
59-
60-
Returns:
61-
authzed.api.v1.Client configured for insecure connection
16+
For SpiceDB running with --grpc-no-tls flag.
6217
"""
63-
from authzed.api.v1 import Client
64-
65-
# Create insecure channel with bearer token interceptor
66-
channel = grpc.insecure_channel(endpoint)
67-
interceptor = BearerTokenInterceptor(token)
68-
intercepted_channel = grpc.intercept_channel(channel, interceptor)
69-
70-
# Create client bypassing __init__ and initialize with our channel
71-
client = Client.__new__(Client)
72-
client.init_stubs(intercepted_channel)
18+
return InsecureClient(endpoint, token)
7319

74-
return client
7520

76-
77-
def get_spicedb_client(endpoint: str, token: str):
21+
def get_spicedb_client(endpoint: str, token: str) -> InsecureClient:
7822
"""
7923
Get or create reusable SpiceDB client (singleton, thread-safe).
80-
81-
This function provides connection pooling for SpiceDB by maintaining
82-
a single client instance across requests, eliminating connection overhead.
83-
84-
Args:
85-
endpoint: The SpiceDB endpoint (e.g., "localhost:50051")
86-
token: The bearer token (e.g., "devtoken")
87-
88-
Returns:
89-
authzed.api.v1.Client configured for insecure connection
9024
"""
91-
from authzed.api.v1 import Client
92-
9325
global _spicedb_client
9426

95-
# Fast path: client already exists
9627
if _spicedb_client is not None:
9728
return _spicedb_client
9829

99-
# Slow path: create new client with thread-safe lock
10030
with _spicedb_lock:
101-
# Double-check after acquiring lock
10231
if _spicedb_client is None:
10332
_spicedb_client = create_insecure_spicedb_client(endpoint, token)
10433

10534
return _spicedb_client
10635

10736

10837
def reset_spicedb_client():
109-
"""
110-
Reset singleton (useful for testing).
111-
112-
This allows tests to clear the cached client and create a fresh one.
113-
"""
38+
"""Reset singleton (useful for testing)."""
11439
global _spicedb_client
11540
with _spicedb_lock:
11641
_spicedb_client = None
117-
118-
119-
# Backward compatibility - keep the old function name
120-
def insecure_bearer_token_credentials(token: str):
121-
"""
122-
Deprecated: Use create_insecure_spicedb_client instead.
123-
124-
This function is kept for backward compatibility but doesn't work
125-
with authzed Client for insecure connections.
126-
"""
127-
raise NotImplementedError(
128-
"For insecure SpiceDB connections, use create_insecure_spicedb_client() instead"
129-
)
Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,70 @@
11
"""Authorization node - deterministic permission filtering via SpiceDB."""
22

33
from langchain_core.messages import SystemMessage
4+
from langchain_spicedb.core import SpiceDBAuthorizer
45

56
from ..state import AgenticRAGState
67
from ..config import get_config
7-
from ..grpc_helpers import get_spicedb_client
88
from ..logging_config import get_logger
9-
from ..authorization_helpers import batch_check_permissions
10-
from ..node_helpers import log_node_execution
119

1210
logger = get_logger("nodes.authorization")
1311

12+
_authorizer: SpiceDBAuthorizer | None = None
1413

15-
def authorization_node(state: AgenticRAGState) -> dict:
14+
15+
def _get_authorizer() -> SpiceDBAuthorizer:
16+
global _authorizer
17+
if _authorizer is None:
18+
config = get_config()
19+
_authorizer = SpiceDBAuthorizer(
20+
spicedb_endpoint=config.spicedb_endpoint,
21+
spicedb_token=config.spicedb_token,
22+
resource_type="document",
23+
subject_type="user",
24+
permission="view",
25+
resource_id_key="doc_id",
26+
)
27+
return _authorizer
28+
29+
30+
async def authorization_node(state: AgenticRAGState) -> dict:
1631
"""
1732
Deterministic authorization node - ALWAYS runs, cannot be bypassed.
1833
19-
This node filters retrieved documents based on SpiceDB permissions.
34+
Filters retrieved documents through SpiceDB's CheckBulkPermissions API.
2035
This is a security boundary - the agent cannot bypass this check.
2136
"""
22-
config = get_config()
37+
authorizer = _get_authorizer()
2338

24-
with log_node_execution(
25-
logger,
26-
"authorization",
27-
{
39+
logger.info(
40+
"Starting authorization",
41+
extra={
2842
"subject_id": state["subject_id"],
2943
"document_count": len(state["retrieved_documents"]),
30-
}
31-
):
32-
# Get or create SpiceDB client (reused across requests)
33-
client = get_spicedb_client(
34-
config.spicedb_endpoint,
35-
config.spicedb_token,
36-
)
37-
38-
# Batch check permissions using SpiceDB's bulk API
39-
authorized_docs, denied_doc_ids = batch_check_permissions(
40-
client,
41-
state["subject_id"],
42-
state["retrieved_documents"],
43-
)
44+
},
45+
)
4446

45-
denied_count = len(denied_doc_ids)
47+
result = await authorizer.filter_documents(
48+
documents=state["retrieved_documents"],
49+
subject_id=state["subject_id"],
50+
)
4651

47-
logger.info(
48-
"Authorization results",
49-
extra={
50-
"authorized": len(authorized_docs),
51-
"denied": denied_count,
52-
"denied_doc_ids": denied_doc_ids,
53-
},
54-
)
52+
logger.info(
53+
"Authorization results",
54+
extra={
55+
"authorized": result.total_authorized,
56+
"denied": len(result.denied_resource_ids),
57+
"denied_doc_ids": result.denied_resource_ids,
58+
},
59+
)
5560

56-
return {
57-
"authorized_documents": authorized_docs,
58-
"denied_count": denied_count,
59-
"authorization_passed": len(authorized_docs) > 0,
60-
"messages": [
61-
SystemMessage(
62-
content=f"Authorization: {len(authorized_docs)}/{len(state['retrieved_documents'])} documents authorized"
63-
)
64-
],
65-
}
61+
return {
62+
"authorized_documents": result.authorized_documents,
63+
"denied_count": len(result.denied_resource_ids),
64+
"authorization_passed": result.total_authorized > 0,
65+
"messages": [
66+
SystemMessage(
67+
content=f"Authorization: {result.total_authorized}/{result.total_retrieved} documents authorized"
68+
)
69+
],
70+
}

agentic-rag-authorization/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
langchain>=0.1.0
33
langchain-openai>=0.1.0
44
langgraph>=0.0.20
5+
langchain-spicedb>=0.2.0
56
weaviate-client>=3.26.0,<4.0 # v3 for REST API stability (no gRPC issues)
67
authzed>=0.7.0
78
python-dotenv>=1.0.0

0 commit comments

Comments
 (0)