|
1 | | -"""Helper functions for gRPC and SpiceDB authentication.""" |
| 1 | +"""Helper functions for SpiceDB client creation.""" |
2 | 2 |
|
3 | | -import grpc |
4 | 3 | from threading import Lock |
5 | 4 | from typing import Optional |
6 | 5 |
|
| 6 | +from authzed.api.v1 import InsecureClient |
7 | 7 |
|
8 | | -class BearerTokenInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor): |
9 | | - """ |
10 | | - gRPC interceptor that adds bearer token to all requests. |
11 | | -
|
12 | | - This is for local development with SpiceDB's --grpc-no-tls flag. |
13 | | - """ |
14 | | - |
15 | | - def __init__(self, token: str): |
16 | | - self._token = token |
17 | | - |
18 | | - def _add_authorization(self, client_call_details): |
19 | | - """Add authorization metadata to the call.""" |
20 | | - metadata = [] |
21 | | - if client_call_details.metadata is not None: |
22 | | - metadata = list(client_call_details.metadata) |
23 | | - metadata.append(("authorization", f"Bearer {self._token}")) |
24 | | - |
25 | | - return grpc._interceptor._ClientCallDetails( |
26 | | - client_call_details.method, |
27 | | - client_call_details.timeout, |
28 | | - metadata, |
29 | | - client_call_details.credentials, |
30 | | - client_call_details.wait_for_ready, |
31 | | - client_call_details.compression, |
32 | | - ) |
33 | | - |
34 | | - def intercept_unary_unary(self, continuation, client_call_details, request): |
35 | | - """Intercept unary-unary calls.""" |
36 | | - new_details = self._add_authorization(client_call_details) |
37 | | - return continuation(new_details, request) |
38 | | - |
39 | | - def intercept_unary_stream(self, continuation, client_call_details, request): |
40 | | - """Intercept unary-stream calls.""" |
41 | | - new_details = self._add_authorization(client_call_details) |
42 | | - return continuation(new_details, request) |
43 | | - |
44 | | - |
45 | | -# Global singleton for SpiceDB client with thread-safe initialization |
46 | | -_spicedb_client: Optional["Client"] = None |
| 8 | +_spicedb_client: Optional[InsecureClient] = None |
47 | 9 | _spicedb_lock = Lock() |
48 | 10 |
|
49 | 11 |
|
50 | | -def create_insecure_spicedb_client(endpoint: str, token: str): |
| 12 | +def create_insecure_spicedb_client(endpoint: str, token: str) -> InsecureClient: |
51 | 13 | """ |
52 | 14 | Create a SpiceDB client for insecure connections (local development). |
53 | 15 |
|
54 | | - This is for SpiceDB running with --grpc-no-tls flag. |
55 | | -
|
56 | | - Args: |
57 | | - endpoint: The SpiceDB endpoint (e.g., "localhost:50051") |
58 | | - token: The bearer token (e.g., "devtoken") |
59 | | -
|
60 | | - Returns: |
61 | | - authzed.api.v1.Client configured for insecure connection |
| 16 | + For SpiceDB running with --grpc-no-tls flag. |
62 | 17 | """ |
63 | | - from authzed.api.v1 import Client |
64 | | - |
65 | | - # Create insecure channel with bearer token interceptor |
66 | | - channel = grpc.insecure_channel(endpoint) |
67 | | - interceptor = BearerTokenInterceptor(token) |
68 | | - intercepted_channel = grpc.intercept_channel(channel, interceptor) |
69 | | - |
70 | | - # Create client bypassing __init__ and initialize with our channel |
71 | | - client = Client.__new__(Client) |
72 | | - client.init_stubs(intercepted_channel) |
| 18 | + return InsecureClient(endpoint, token) |
73 | 19 |
|
74 | | - return client |
75 | 20 |
|
76 | | - |
77 | | -def get_spicedb_client(endpoint: str, token: str): |
| 21 | +def get_spicedb_client(endpoint: str, token: str) -> InsecureClient: |
78 | 22 | """ |
79 | 23 | Get or create reusable SpiceDB client (singleton, thread-safe). |
80 | | -
|
81 | | - This function provides connection pooling for SpiceDB by maintaining |
82 | | - a single client instance across requests, eliminating connection overhead. |
83 | | -
|
84 | | - Args: |
85 | | - endpoint: The SpiceDB endpoint (e.g., "localhost:50051") |
86 | | - token: The bearer token (e.g., "devtoken") |
87 | | -
|
88 | | - Returns: |
89 | | - authzed.api.v1.Client configured for insecure connection |
90 | 24 | """ |
91 | | - from authzed.api.v1 import Client |
92 | | - |
93 | 25 | global _spicedb_client |
94 | 26 |
|
95 | | - # Fast path: client already exists |
96 | 27 | if _spicedb_client is not None: |
97 | 28 | return _spicedb_client |
98 | 29 |
|
99 | | - # Slow path: create new client with thread-safe lock |
100 | 30 | with _spicedb_lock: |
101 | | - # Double-check after acquiring lock |
102 | 31 | if _spicedb_client is None: |
103 | 32 | _spicedb_client = create_insecure_spicedb_client(endpoint, token) |
104 | 33 |
|
105 | 34 | return _spicedb_client |
106 | 35 |
|
107 | 36 |
|
108 | 37 | def reset_spicedb_client(): |
109 | | - """ |
110 | | - Reset singleton (useful for testing). |
111 | | -
|
112 | | - This allows tests to clear the cached client and create a fresh one. |
113 | | - """ |
| 38 | + """Reset singleton (useful for testing).""" |
114 | 39 | global _spicedb_client |
115 | 40 | with _spicedb_lock: |
116 | 41 | _spicedb_client = None |
117 | | - |
118 | | - |
119 | | -# Backward compatibility - keep the old function name |
120 | | -def insecure_bearer_token_credentials(token: str): |
121 | | - """ |
122 | | - Deprecated: Use create_insecure_spicedb_client instead. |
123 | | -
|
124 | | - This function is kept for backward compatibility but doesn't work |
125 | | - with authzed Client for insecure connections. |
126 | | - """ |
127 | | - raise NotImplementedError( |
128 | | - "For insecure SpiceDB connections, use create_insecure_spicedb_client() instead" |
129 | | - ) |
0 commit comments