Skip to content

Commit 5f00a87

Browse files
authored
feat: add support for experimental host Spanner endpoints (#62)
* compatibility with experimental host * added parameter validations * make use_plain_text as a toggle flag
1 parent b1c7873 commit 5f00a87

File tree

5 files changed

+126
-5
lines changed

5 files changed

+126
-5
lines changed

spanner_graphs/cloud_database.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,29 @@ def get_as_field_info_list(fields: List[StructType.Field]) -> List[SpannerFieldI
3939

4040
class CloudSpannerDatabase(SpannerDatabase):
4141
"""Concrete implementation for Spanner database on the cloud."""
42-
def __init__(self, project_id: str, instance_id: str,
43-
database_id: str) -> None:
44-
credentials, _ = _get_default_credentials_with_project()
45-
self.client = spanner.Client(
46-
project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id))
42+
def __init__(
43+
self,
44+
project_id: str,
45+
instance_id: str,
46+
database_id: str,
47+
experimental_host: str | None = None,
48+
use_plain_text: bool = False,
49+
ca_certificate: str | None = None,
50+
client_certificate: str | None = None,
51+
client_key: str | None = None,
52+
) -> None:
53+
if experimental_host:
54+
self.client = spanner.Client(
55+
use_plain_text=use_plain_text,
56+
experimental_host=experimental_host,
57+
ca_certificate=ca_certificate,
58+
client_certificate=client_certificate,
59+
client_key=client_key,
60+
)
61+
else:
62+
credentials, _ = _get_default_credentials_with_project()
63+
self.client = spanner.Client(
64+
project=project_id, credentials=credentials, client_options=ClientOptions(quota_project_id=project_id))
4765
self.instance = self.client.instance(instance_id)
4866
logger = logging.getLogger("spanner_graphs")
4967
logger.setLevel(logging.CRITICAL)

spanner_graphs/database.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class SpannerEnv(Enum):
3232
CLOUD = auto()
3333
INFRA = auto()
3434
MOCK = auto()
35+
EXPERIMENTAL_HOST = auto()
36+
3537

3638
@dataclass
3739
class DatabaseSelector:
@@ -47,12 +49,22 @@ class DatabaseSelector:
4749
instance: The Spanner instance.
4850
database: The Spanner database.
4951
infra_db_path: The path for an internal infrastructure database.
52+
experimental_host: The Spanner experimental host endpoint.
53+
use_plain_text: Whether to use plain text for the experimental host endpoint.
54+
ca_certificate: CA certificate path for the experimental host endpoint.
55+
client_certificate: Client certificate path for the experimental host endpoint.
56+
client_key: Client key path for the experimental host endpoint.
5057
"""
5158
env: SpannerEnv
5259
project: str | None = None
5360
instance: str | None = None
5461
database: str | None = None
5562
infra_db_path: str | None = None
63+
experimental_host: str | None = None
64+
use_plain_text: bool = False
65+
ca_certificate: str | None = None
66+
client_certificate: str | None = None
67+
client_key: str | None = None
5668

5769
@classmethod
5870
def cloud(cls, project: str, instance: str, database: str) -> 'DatabaseSelector':
@@ -73,13 +85,36 @@ def mock(cls) -> 'DatabaseSelector':
7385
"""Creates a selector for a mock Spanner database."""
7486
return cls(env=SpannerEnv.MOCK)
7587

88+
@classmethod
89+
def experimental_host(
90+
cls, experimental_host: str, database: str, use_plain_text: bool = False, ca_certificate: str | None = None, client_certificate: str | None = None, client_key: str | None = None
91+
) -> "DatabaseSelector":
92+
"""Creates a selector for a Google Experimental Host Spanner database."""
93+
if not database:
94+
raise ValueError(
95+
"database is required for Experimental Host Spanner Endpoint"
96+
)
97+
return cls(
98+
env=SpannerEnv.EXPERIMENTAL_HOST,
99+
project="default",
100+
instance="default",
101+
database=database,
102+
experimental_host=experimental_host,
103+
use_plain_text=use_plain_text,
104+
ca_certificate=ca_certificate,
105+
client_certificate=client_certificate,
106+
client_key=client_key,
107+
)
108+
76109
def get_key(self) -> str:
77110
if self.env == SpannerEnv.CLOUD:
78111
return f"cloud_{self.project}_{self.instance}_{self.database}"
79112
elif self.env == SpannerEnv.INFRA:
80113
return f"infra_{self.infra_db_path}"
81114
elif self.env == SpannerEnv.MOCK:
82115
return "mock"
116+
elif self.env == SpannerEnv.EXPERIMENTAL_HOST:
117+
return f"experimental_host_{self.database}"
83118
else:
84119
raise ValueError("Unknown Spanner environment")
85120

spanner_graphs/exec_env.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,24 @@ def get_database_instance(
8181
raise RuntimeError(
8282
"Infra Spanner support is not available in this environment."
8383
)
84+
elif selector.env == SpannerEnv.EXPERIMENTAL_HOST:
85+
try:
86+
cloud_db_module = importlib.import_module("spanner_graphs.cloud_database")
87+
CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase")
88+
db = CloudSpannerDatabase(
89+
selector.project,
90+
selector.instance,
91+
selector.database,
92+
selector.experimental_host,
93+
selector.use_plain_text,
94+
selector.ca_certificate,
95+
selector.client_certificate,
96+
selector.client_key,
97+
)
98+
except ImportError:
99+
raise RuntimeError(
100+
"Spanner experimental host support is not available in this environment."
101+
)
84102
else:
85103
raise ValueError(f"Unsupported Spanner environment: {selector.env}")
86104

spanner_graphs/graph_server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ def dict_to_selector(selector_dict: Dict[str, Any]) -> DatabaseSelector:
6565
return DatabaseSelector.infra(selector_dict['infra_db_path'])
6666
elif env == SpannerEnv.MOCK:
6767
return DatabaseSelector.mock()
68+
elif env == SpannerEnv.EXPERIMENTAL_HOST:
69+
return DatabaseSelector.experimental_host(
70+
selector_dict["experimental_host"], selector_dict["database"], selector_dict["use_plain_text"], selector_dict["ca_certificate"], selector_dict["client_certificate"], selector_dict["client_key"]
71+
)
6872
raise ValueError(f"Invalid env in selector dict: {selector_dict}")
6973
except Exception as e:
7074
print (f"Unexpected error when fetching selector: {e}")

spanner_graphs/magics.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,59 @@ def spanner_graph(self, line: str, cell: str):
197197
parser.add_argument("--infra_db_path",
198198
action="store_true",
199199
help="Connect to internal Infra Spanner")
200+
parser.add_argument(
201+
"--experimental_host",
202+
type=str,
203+
required=False,
204+
help="Spanner experimental host endpoint",
205+
)
206+
parser.add_argument(
207+
"--use_plain_text",
208+
action="store_true",
209+
help="[Experimental Host Only] Use plain text communication for the experimental host",
210+
)
211+
parser.add_argument(
212+
"--ca_certificate",
213+
type=str,
214+
required=False,
215+
help="[Experimental Host Only] CA certificate path for the experimental host",
216+
)
217+
parser.add_argument(
218+
"--client_certificate",
219+
type=str,
220+
required=False,
221+
help="[Experimental Host Only] Client certificate path for the experimental host",
222+
)
223+
parser.add_argument(
224+
"--client_key",
225+
type=str,
226+
required=False,
227+
help="[Experimental Host Only] Client key path for the experimental host",
228+
)
200229

201230
try:
202231
args = parser.parse_args(line.split())
203232
selector = None
233+
if not args.experimental_host:
234+
if args.use_plain_text or args.ca_certificate or args.client_certificate or args.client_key:
235+
raise ValueError("use_plain_text, ca_certificate, client_certificate and client_key are only supported for Experimental Host")
204236
if args.mock:
205237
selector = DatabaseSelector.mock()
206238
elif args.infra_db_path:
207239
selector = DatabaseSelector.infra(infra_db_path=args.database)
240+
elif args.experimental_host:
241+
if args.use_plain_text:
242+
if args.ca_certificate or args.client_certificate or args.client_key:
243+
raise ValueError("When use_plain_text is true, no other certificate parameters should be set.")
244+
elif not args.ca_certificate:
245+
raise ValueError("Either use_plain_text must be true or ca_certificate must be set.")
246+
247+
if bool(args.client_certificate) != bool(args.client_key):
248+
raise ValueError("client_certificate and client_key must both be provided together.")
249+
250+
selector = DatabaseSelector.experimental_host(
251+
experimental_host=args.experimental_host, database=args.database, use_plain_text=args.use_plain_text, ca_certificate=args.ca_certificate, client_certificate=args.client_certificate, client_key=args.client_key
252+
)
208253
else:
209254
if not (args.project and args.instance):
210255
raise ValueError(
@@ -226,6 +271,7 @@ def spanner_graph(self, line: str, cell: str):
226271
print(f"Error: {e}")
227272
print(" %%spanner_graph --project <proj> --instance <inst> --database <db>")
228273
print(" %%spanner_graph --mock")
274+
print(" %%spanner_graph --experimental_host <host> --database <db> [--use_plain_text] [--ca_certificate <path>] [--client_certificate <path>] [--client_key <path>]")
229275
print(" Graph query here...")
230276

231277
def load_ipython_extension(ipython):

0 commit comments

Comments
 (0)