Skip to content

Commit 2d3d410

Browse files
committed
fix: address PR review nits for managed database helpers
Rename cross-module helpers, remove unused imports, simplify parquet detection, and clarify upload_id/file handling in load_managed_table.
1 parent a370abb commit 2d3d410

3 files changed

Lines changed: 18 additions & 18 deletions

File tree

hotdata_runtime/client.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434
ManagedDatabase,
3535
ManagedTable,
3636
MANAGED_SOURCE_TYPE,
37-
_api_error,
38-
_managed_database,
37+
api_error_message,
3938
create_connection_request,
4039
is_parquet_path,
40+
managed_database_from_connection,
4141
)
4242
from hotdata_runtime.http import default_http_retries
4343
from hotdata_runtime.result import QueryResult
@@ -154,7 +154,7 @@ def uploads(self) -> UploadsApi:
154154
def list_managed_databases(self) -> list[ManagedDatabase]:
155155
listing = self.connections().list_connections()
156156
return [
157-
_managed_database(c)
157+
managed_database_from_connection(c)
158158
for c in listing.connections
159159
if c.source_type == MANAGED_SOURCE_TYPE
160160
]
@@ -173,7 +173,7 @@ def resolve_managed_database(self, name_or_id: str) -> ManagedDatabase:
173173
f"{match.name!r} is not a managed database "
174174
f"(source_type: {match.source_type})"
175175
)
176-
return _managed_database(match)
176+
return managed_database_from_connection(match)
177177

178178
def create_managed_database(
179179
self,
@@ -186,15 +186,15 @@ def create_managed_database(
186186
try:
187187
created = self.connections().create_connection(request)
188188
except ApiException as e:
189-
raise RuntimeError(_api_error(e)) from e
190-
return _managed_database(created)
189+
raise RuntimeError(api_error_message(e)) from e
190+
return managed_database_from_connection(created)
191191

192192
def delete_managed_database(self, name_or_id: str) -> None:
193193
db = self.resolve_managed_database(name_or_id)
194194
try:
195195
self.connections().delete_connection(db.id)
196196
except ApiException as e:
197-
raise RuntimeError(_api_error(e)) from e
197+
raise RuntimeError(api_error_message(e)) from e
198198

199199
def list_managed_tables(
200200
self,
@@ -232,7 +232,7 @@ def upload_parquet(self, path: str) -> str:
232232
_content_type="application/octet-stream",
233233
)
234234
except ApiException as e:
235-
raise RuntimeError(_api_error(e)) from e
235+
raise RuntimeError(api_error_message(e)) from e
236236
return uploaded.id
237237

238238
def load_managed_table(
@@ -247,7 +247,11 @@ def load_managed_table(
247247
if (upload_id is None) == (file is None):
248248
raise ValueError("Exactly one of upload_id or file is required")
249249
db = self.resolve_managed_database(database)
250-
resolved_upload_id = upload_id or self.upload_parquet(file or "")
250+
if upload_id is not None:
251+
resolved_upload_id = upload_id
252+
else:
253+
assert file is not None
254+
resolved_upload_id = self.upload_parquet(file)
251255
request = LoadManagedTableRequest(
252256
mode="replace",
253257
upload_id=resolved_upload_id,
@@ -260,7 +264,7 @@ def load_managed_table(
260264
request,
261265
)
262266
except ApiException as e:
263-
raise RuntimeError(_api_error(e)) from e
267+
raise RuntimeError(api_error_message(e)) from e
264268
return LoadManagedTableResult(
265269
connection_id=loaded.connection_id,
266270
schema_name=loaded.schema_name,
@@ -280,7 +284,7 @@ def delete_managed_table(
280284
try:
281285
self.connections().delete_managed_table(db.id, schema, table)
282286
except ApiException as e:
283-
raise RuntimeError(_api_error(e)) from e
287+
raise RuntimeError(api_error_message(e)) from e
284288

285289
def list_recent_results(
286290
self,

hotdata_runtime/databases.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from hotdata.exceptions import ApiException
1010
from hotdata.models.create_connection_request import CreateConnectionRequest
11-
from hotdata.models.load_managed_table_request import LoadManagedTableRequest
1211

1312
MANAGED_SOURCE_TYPE = "managed"
1413
DEFAULT_SCHEMA = "public"
@@ -49,9 +48,6 @@ def to_dict(self) -> dict[str, Any]:
4948

5049

5150
def is_parquet_path(path: str) -> bool:
52-
lowered = path.lower()
53-
if lowered.endswith(".parquet"):
54-
return True
5551
return Path(path).suffix.lower() == ".parquet"
5652

5753

@@ -83,13 +79,13 @@ def create_connection_request(
8379
)
8480

8581

86-
def _managed_database(conn: Any) -> ManagedDatabase:
82+
def managed_database_from_connection(conn: Any) -> ManagedDatabase:
8783
return ManagedDatabase(
8884
id=str(conn.id),
8985
name=str(conn.name),
9086
source_type=str(conn.source_type),
9187
)
9288

9389

94-
def _api_error(exc: ApiException) -> str:
90+
def api_error_message(exc: ApiException) -> str:
9591
return exc.reason or str(exc)

tests/test_databases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from types import SimpleNamespace
4-
from unittest.mock import MagicMock, mock_open, patch
4+
from unittest.mock import mock_open, patch
55

66
import pytest
77

0 commit comments

Comments
 (0)