|
13 | 13 | from hotdata.api.query_api import QueryApi |
14 | 14 | from hotdata.api.query_runs_api import QueryRunsApi |
15 | 15 | from hotdata.api.results_api import ResultsApi |
| 16 | +from hotdata.api.uploads_api import UploadsApi |
16 | 17 | from hotdata.exceptions import ApiException |
17 | 18 | from hotdata.models.async_query_response import AsyncQueryResponse |
18 | 19 | from hotdata.models.query_request import QueryRequest |
19 | 20 | from hotdata.models.query_response import QueryResponse |
| 21 | +from hotdata.models.load_managed_table_request import LoadManagedTableRequest |
20 | 22 | from hotdata.models.table_info import TableInfo |
21 | 23 |
|
22 | 24 | from hotdata_runtime.env import ( |
|
26 | 28 | normalize_host, |
27 | 29 | pick_workspace, |
28 | 30 | ) |
| 31 | +from hotdata_runtime.databases import ( |
| 32 | + DEFAULT_SCHEMA, |
| 33 | + LoadManagedTableResult, |
| 34 | + ManagedDatabase, |
| 35 | + ManagedTable, |
| 36 | + MANAGED_SOURCE_TYPE, |
| 37 | + api_error_message, |
| 38 | + create_connection_request, |
| 39 | + is_parquet_path, |
| 40 | + managed_database_from_connection, |
| 41 | +) |
29 | 42 | from hotdata_runtime.http import default_http_retries |
30 | 43 | from hotdata_runtime.result import QueryResult |
31 | 44 |
|
@@ -135,6 +148,144 @@ def query_runs(self) -> QueryRunsApi: |
135 | 148 | def results(self) -> ResultsApi: |
136 | 149 | return self._results_api() |
137 | 150 |
|
| 151 | + def uploads(self) -> UploadsApi: |
| 152 | + return UploadsApi(self._api) |
| 153 | + |
| 154 | + def list_managed_databases(self) -> list[ManagedDatabase]: |
| 155 | + listing = self.connections().list_connections() |
| 156 | + return [ |
| 157 | + managed_database_from_connection(c) |
| 158 | + for c in listing.connections |
| 159 | + if c.source_type == MANAGED_SOURCE_TYPE |
| 160 | + ] |
| 161 | + |
| 162 | + def resolve_managed_database(self, name_or_id: str) -> ManagedDatabase: |
| 163 | + listing = self.connections().list_connections() |
| 164 | + match = None |
| 165 | + for c in listing.connections: |
| 166 | + if c.id == name_or_id or c.name == name_or_id: |
| 167 | + match = c |
| 168 | + break |
| 169 | + if match is None: |
| 170 | + raise KeyError(f"No database named or with id {name_or_id!r}") |
| 171 | + if match.source_type != MANAGED_SOURCE_TYPE: |
| 172 | + raise ValueError( |
| 173 | + f"{match.name!r} is not a managed database " |
| 174 | + f"(source_type: {match.source_type})" |
| 175 | + ) |
| 176 | + return managed_database_from_connection(match) |
| 177 | + |
| 178 | + def create_managed_database( |
| 179 | + self, |
| 180 | + name: str, |
| 181 | + *, |
| 182 | + schema: str = DEFAULT_SCHEMA, |
| 183 | + tables: list[str] | None = None, |
| 184 | + ) -> ManagedDatabase: |
| 185 | + request = create_connection_request(name, schema=schema, tables=tables) |
| 186 | + try: |
| 187 | + created = self.connections().create_connection(request) |
| 188 | + except ApiException as e: |
| 189 | + raise RuntimeError(api_error_message(e)) from e |
| 190 | + return managed_database_from_connection(created) |
| 191 | + |
| 192 | + def delete_managed_database(self, name_or_id: str) -> None: |
| 193 | + db = self.resolve_managed_database(name_or_id) |
| 194 | + try: |
| 195 | + self.connections().delete_connection(db.id) |
| 196 | + except ApiException as e: |
| 197 | + raise RuntimeError(api_error_message(e)) from e |
| 198 | + |
| 199 | + def list_managed_tables( |
| 200 | + self, |
| 201 | + database: str, |
| 202 | + *, |
| 203 | + schema: str | None = None, |
| 204 | + ) -> list[ManagedTable]: |
| 205 | + db = self.resolve_managed_database(database) |
| 206 | + rows: list[ManagedTable] = [] |
| 207 | + for t in self.iter_tables(connection_id=db.id): |
| 208 | + if schema is not None and t.var_schema != schema: |
| 209 | + continue |
| 210 | + rows.append( |
| 211 | + ManagedTable( |
| 212 | + full_name=f"{db.name}.{t.var_schema}.{t.table}", |
| 213 | + schema=t.var_schema, |
| 214 | + table=t.table, |
| 215 | + synced=t.synced, |
| 216 | + last_sync=t.last_sync, |
| 217 | + ) |
| 218 | + ) |
| 219 | + rows.sort(key=lambda row: (row.schema, row.table)) |
| 220 | + return rows |
| 221 | + |
| 222 | + def upload_parquet(self, path: str) -> str: |
| 223 | + if not is_parquet_path(path): |
| 224 | + raise ValueError( |
| 225 | + f"Managed table loads require a parquet file (got {path!r})" |
| 226 | + ) |
| 227 | + with open(path, "rb") as f: |
| 228 | + data = f.read() |
| 229 | + try: |
| 230 | + uploaded = self.uploads().upload_file( |
| 231 | + data, |
| 232 | + _content_type="application/octet-stream", |
| 233 | + ) |
| 234 | + except ApiException as e: |
| 235 | + raise RuntimeError(api_error_message(e)) from e |
| 236 | + return uploaded.id |
| 237 | + |
| 238 | + def load_managed_table( |
| 239 | + self, |
| 240 | + database: str, |
| 241 | + table: str, |
| 242 | + *, |
| 243 | + schema: str = DEFAULT_SCHEMA, |
| 244 | + upload_id: str | None = None, |
| 245 | + file: str | None = None, |
| 246 | + ) -> LoadManagedTableResult: |
| 247 | + if (upload_id is None) == (file is None): |
| 248 | + raise ValueError("Exactly one of upload_id or file is required") |
| 249 | + db = self.resolve_managed_database(database) |
| 250 | + if upload_id is not None: |
| 251 | + resolved_upload_id = upload_id |
| 252 | + else: |
| 253 | + assert file is not None |
| 254 | + resolved_upload_id = self.upload_parquet(file) |
| 255 | + request = LoadManagedTableRequest( |
| 256 | + mode="replace", |
| 257 | + upload_id=resolved_upload_id, |
| 258 | + ) |
| 259 | + try: |
| 260 | + loaded = self.connections().load_managed_table( |
| 261 | + db.id, |
| 262 | + schema, |
| 263 | + table, |
| 264 | + request, |
| 265 | + ) |
| 266 | + except ApiException as e: |
| 267 | + raise RuntimeError(api_error_message(e)) from e |
| 268 | + return LoadManagedTableResult( |
| 269 | + connection_id=loaded.connection_id, |
| 270 | + schema_name=loaded.schema_name, |
| 271 | + table_name=loaded.table_name, |
| 272 | + row_count=loaded.row_count, |
| 273 | + full_name=f"{db.name}.{loaded.schema_name}.{loaded.table_name}", |
| 274 | + ) |
| 275 | + |
| 276 | + def delete_managed_table( |
| 277 | + self, |
| 278 | + database: str, |
| 279 | + table: str, |
| 280 | + *, |
| 281 | + schema: str = DEFAULT_SCHEMA, |
| 282 | + ) -> None: |
| 283 | + db = self.resolve_managed_database(database) |
| 284 | + try: |
| 285 | + self.connections().delete_managed_table(db.id, schema, table) |
| 286 | + except ApiException as e: |
| 287 | + raise RuntimeError(api_error_message(e)) from e |
| 288 | + |
138 | 289 | def list_recent_results( |
139 | 290 | self, |
140 | 291 | *, |
|
0 commit comments