Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 261 additions & 0 deletions src/ucode/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
import shlex
import shutil
import subprocess
import time
from concurrent.futures import (
ThreadPoolExecutor,
as_completed,
)
from concurrent.futures import (
TimeoutError as FutureTimeoutError,
)
from pathlib import Path
from typing import Literal, cast, overload
from urllib import error as urllib_error
Expand Down Expand Up @@ -1251,6 +1259,259 @@ def build_mcp_service_url(workspace: str, full_name: str) -> str:
return f"{workspace}/ai-gateway/mcp-services/{full_name}"


# Catalog/schema discovery helpers.
# `list_vector_search_catalog_schemas` walks Vector Search endpoints+indexes.
# `list_uc_functions_catalog_schemas` walks UC catalogs+schemas in parallel and
# keeps only schemas with at least one user function.

# Default page size for `_paginated_json_items`; sent as the `max_results`
# query parameter on every page request.
_UC_LIST_PAGE_SIZE = 200
# Default cap on the number of pages `_paginated_json_items` fetches per walk.
_UC_LIST_MAX_PAGES = 50
# Upper bound on worker threads for every fan-out pool in this section
# (per-endpoint index listing, per-catalog schema listing, function probes).
_UC_FUNCTION_PROBE_WORKERS = 16
# Per-request timeout passed to the listing calls (presumably seconds; the
# unit is defined by `_http_get_json` -- TODO confirm).
_UC_LIST_HTTP_TIMEOUT = 10
# Per-request timeout for the single-item function-existence probe
# (same unit as above).
_UC_FUNCTION_PROBE_TIMEOUT = 5
# Default wall-clock budgets, in seconds, added to `time.monotonic()` by the
# two public discovery functions.
_VECTOR_SEARCH_DEADLINE_SECONDS = 15.0
_UC_FUNCTIONS_DEADLINE_SECONDS = 20.0
# Skip UC catalogs whose schemas almost never carry user-callable functions
# you'd want to expose as agent tools.
_UC_FUNCTIONS_SKIP_CATALOGS = frozenset(
{"__databricks_internal", "hive_metastore", "samples", "system"}
)


def _drain_with_deadline(futures: dict, deadline: float, on_result) -> None:
    """Feed finished futures to `on_result` until `deadline` passes.

    `futures` maps each future to an arbitrary key. For every future that
    completes successfully, `on_result(value, key)` is invoked on the calling
    thread. A future whose `result()` raises is skipped so a single failed
    task cannot abort the rest. `deadline` is a `time.monotonic()` timestamp;
    once it is reached the function returns, leaving unfinished futures
    untouched.
    """
    remaining = deadline - time.monotonic()
    # Clamp to zero: an already-expired deadline still harvests whatever has
    # finished, then times out immediately on anything still pending.
    budget = remaining if remaining > 0.0 else 0.0
    try:
        for finished in as_completed(futures, timeout=budget):
            try:
                outcome = finished.result()
            except Exception:  # noqa: BLE001
                # Best-effort fan-out: ignore this task's failure.
                pass
            else:
                on_result(outcome, futures[finished])
                if time.monotonic() > deadline:
                    return
    except FutureTimeoutError:
        # Budget exhausted while tasks were still pending.
        return


def _paginated_json_items(
    base_url: str,
    token: str,
    *,
    items_key: str,
    extra_params: dict[str, str] | None = None,
    page_size: int = _UC_LIST_PAGE_SIZE,
    max_pages: int = _UC_LIST_MAX_PAGES,
    timeout: int = 30,
) -> tuple[list[dict], str | None]:
    """Walk a Databricks `next_page_token` listing and return all items.

    Each request sends `max_results=page_size`, any `extra_params`, and the
    `page_token` from the previous response. Entries under `items_key` that
    are not dicts are dropped.

    Returns `(items, reason)`. `reason` is None only when the listing was
    walked to its natural end (a response without `next_page_token`).
    Otherwise it is a short description of why the walk stopped early, and
    `items` holds whatever was collected up to that point:

    * the reason reported by `_http_get_json` when a request fails,
    * a repeated-token message when the server hands back a page token that
      was already used (guards against an endless loop),
    * a page-cap message when `max_pages` pages were fetched and more remain.
    """
    items: list[dict] = []
    page_token: str | None = None
    seen_tokens: set[str] = set()
    for _ in range(max_pages):
        params: dict[str, str] = {"max_results": str(page_size)}
        if extra_params:
            params.update(extra_params)
        if page_token:
            params["page_token"] = page_token
        url = f"{base_url}?{urlencode(params)}"
        payload, reason = _http_get_json(url, token, timeout=timeout)
        if payload is None:
            # Request failed: surface the transport's reason with partial items.
            return items, reason
        data = cast(dict, payload) if isinstance(payload, dict) else {}
        raw = data.get(items_key) or []
        if isinstance(raw, list):
            items.extend(item for item in raw if isinstance(item, dict))
        next_token = data.get("next_page_token") or None
        if not next_token:
            # Natural end of the listing.
            return items, None
        if next_token in seen_tokens:
            return items, "pagination stopped: server repeated a page token"
        seen_tokens.add(next_token)
        page_token = next_token
    # The loop ran out of pages while the server still advertised more data;
    # report the truncation instead of pretending the walk succeeded.
    return items, f"pagination stopped after {max_pages} pages; results may be incomplete"


def _vector_index_catalog_schema(index: dict) -> tuple[str, str] | None:
"""Pull (catalog, schema) from one vector-search index entry."""
catalog = index.get("catalog_name")
schema = index.get("schema_name")
if isinstance(catalog, str) and isinstance(schema, str) and catalog and schema:
return catalog, schema
# Fallback: `name` is the fully-qualified UC name `catalog.schema.index`.
name = index.get("name")
if isinstance(name, str):
parts = name.split(".")
if len(parts) >= 3 and parts[0] and parts[1]:
return parts[0], parts[1]
return None


def list_vector_search_catalog_schemas(
    workspace: str,
    token: str,
    *,
    deadline_seconds: float = _VECTOR_SEARCH_DEADLINE_SECONDS,
) -> tuple[list[tuple[str, str]], str | None]:
    """Return sorted unique `(catalog, schema)` pairs that contain at least
    one Databricks Vector Search index.

    Lists the workspace's vector-search endpoints, then walks each endpoint's
    index listing in parallel under a wall-clock budget of `deadline_seconds`.
    Once the budget is spent the function returns whatever pairs were
    collected so far.

    Returns `(pairs, reason)`. `reason` is None when at least one pair was
    found; otherwise `pairs` is empty and `reason` says why (listing failure,
    nothing found, or deadline exceeded).
    """
    hostname = workspace_hostname(workspace)
    deadline = time.monotonic() + deadline_seconds
    endpoints, reason = _paginated_json_items(
        f"https://{hostname}/api/2.0/vector-search/endpoints",
        token,
        items_key="endpoints",
        timeout=_UC_LIST_HTTP_TIMEOUT,
    )
    if not endpoints:
        return [], reason or "no vector search endpoints found"

    endpoint_names = [e["name"] for e in endpoints if isinstance(e.get("name"), str) and e["name"]]
    if not endpoint_names:
        return [], "no vector search endpoints with names"
    if time.monotonic() > deadline:
        # Don't start a fan-out that has no budget left to be collected.
        return [], "deadline exceeded while listing vector search endpoints"

    pairs: set[tuple[str, str]] = set()

    def collect(result, _endpoint):
        # Runs on the calling thread only, so `pairs` needs no locking.
        indexes, _ = result
        for index in indexes:
            pair = _vector_index_catalog_schema(index)
            if pair:
                pairs.add(pair)

    workers = max(1, min(_UC_FUNCTION_PROBE_WORKERS, len(endpoint_names)))
    # Deliberately NOT a `with` block: leaving the context manager calls
    # `shutdown(wait=True)`, which would block on in-flight HTTP walks and
    # defeat the deadline. Pending tasks are cancelled; tasks already running
    # finish in the background and their results are discarded.
    pool = ThreadPoolExecutor(max_workers=workers)
    try:
        futures = {
            pool.submit(
                _paginated_json_items,
                f"https://{hostname}/api/2.0/vector-search/indexes",
                token,
                items_key="vector_indexes",
                extra_params={"endpoint_name": name},
                timeout=_UC_LIST_HTTP_TIMEOUT,
            ): name
            for name in endpoint_names
        }
        _drain_with_deadline(futures, deadline, collect)
    finally:
        pool.shutdown(wait=False, cancel_futures=True)

    if not pairs:
        if time.monotonic() > deadline:
            return [], "deadline exceeded while listing vector search indexes"
        return [], "no vector search indexes found"
    return sorted(pairs), None


def _schema_has_user_function(hostname: str, token: str, catalog: str, schema: str) -> bool:
    """Report whether `{catalog}.{schema}` lists at least one UC function.

    Issues a single functions-list request capped at one result and returns
    True only if the response is a dict whose `functions` list holds at least
    one dict entry. Any request failure or unexpected payload shape counts as
    "no functions" (False).
    """
    query = urlencode({"catalog_name": catalog, "schema_name": schema, "max_results": "1"})
    endpoint = f"https://{hostname}/api/2.1/unity-catalog/functions?{query}"
    response, _ = _http_get_json(endpoint, token, timeout=_UC_FUNCTION_PROBE_TIMEOUT)
    if not isinstance(response, dict):
        return False
    listed = response.get("functions") or []
    if not isinstance(listed, list):
        return False
    for entry in listed:
        if isinstance(entry, dict):
            return True
    return False


def list_uc_functions_catalog_schemas(
    workspace: str,
    token: str,
    *,
    deadline_seconds: float = _UC_FUNCTIONS_DEADLINE_SECONDS,
) -> tuple[list[tuple[str, str]], str | None]:
    """Return sorted unique `(catalog, schema)` pairs containing at least one
    user-defined UC function.

    Works in three phases under one wall-clock budget of `deadline_seconds`:
    list catalogs (skipping `_UC_FUNCTIONS_SKIP_CATALOGS`), list each
    catalog's schemas in parallel (skipping `information_schema`), then probe
    every candidate schema in parallel for at least one function.

    Returns `(pairs, reason)`. `reason` is None when at least one pair was
    found; otherwise `pairs` is empty and `reason` says why (listing failure,
    nothing found, or which phase hit the deadline).
    """
    hostname = workspace_hostname(workspace)
    deadline = time.monotonic() + deadline_seconds

    catalogs, catalogs_reason = _paginated_json_items(
        f"https://{hostname}/api/2.1/unity-catalog/catalogs",
        token,
        items_key="catalogs",
        timeout=_UC_LIST_HTTP_TIMEOUT,
    )
    if not catalogs:
        return [], catalogs_reason or "no UC catalogs found"

    catalog_names = [
        c["name"]
        for c in catalogs
        if isinstance(c.get("name"), str)
        and c["name"]
        and c["name"] not in _UC_FUNCTIONS_SKIP_CATALOGS
    ]
    if not catalog_names:
        return [], "no user UC catalogs found"
    if time.monotonic() > deadline:
        return [], "deadline exceeded while listing UC catalogs"

    # Parallel per-catalog schema listing.
    candidate_pairs: list[tuple[str, str]] = []

    def collect_schemas(result, catalog):
        # Runs on the calling thread only, so the list needs no locking.
        schemas, _ = result
        for schema in schemas:
            schema_name = schema.get("name")
            # `information_schema` is auto-attached to every catalog and
            # never holds user functions.
            if (
                isinstance(schema_name, str)
                and schema_name
                and schema_name != "information_schema"
            ):
                candidate_pairs.append((catalog, schema_name))

    schema_workers = max(1, min(_UC_FUNCTION_PROBE_WORKERS, len(catalog_names)))
    # Deliberately NOT a `with` block: leaving the context manager calls
    # `shutdown(wait=True)`, which would block on in-flight requests and
    # defeat the deadline. Pending tasks are cancelled; running ones finish in
    # the background and their results are discarded.
    schema_pool = ThreadPoolExecutor(max_workers=schema_workers)
    try:
        schema_futures = {
            schema_pool.submit(
                _paginated_json_items,
                f"https://{hostname}/api/2.1/unity-catalog/schemas",
                token,
                items_key="schemas",
                extra_params={"catalog_name": cat},
                timeout=_UC_LIST_HTTP_TIMEOUT,
            ): cat
            for cat in catalog_names
        }
        _drain_with_deadline(schema_futures, deadline, collect_schemas)
    finally:
        schema_pool.shutdown(wait=False, cancel_futures=True)

    if not candidate_pairs:
        if time.monotonic() > deadline:
            return [], "deadline exceeded while listing UC schemas"
        return [], "no UC schemas found"
    if time.monotonic() > deadline:
        # No budget left to collect probe results; don't start the fan-out.
        return [], "deadline exceeded probing UC schemas for functions"

    # Parallel function-existence probes.
    pairs: set[tuple[str, str]] = set()

    def collect_pair(has_fn, pair):
        if has_fn:
            pairs.add(pair)

    probe_workers = max(1, min(_UC_FUNCTION_PROBE_WORKERS, len(candidate_pairs)))
    # Same non-blocking shutdown as the schema pool above.
    probe_pool = ThreadPoolExecutor(max_workers=probe_workers)
    try:
        probe_futures = {
            probe_pool.submit(_schema_has_user_function, hostname, token, cat, schema): (
                cat,
                schema,
            )
            for cat, schema in candidate_pairs
        }
        _drain_with_deadline(probe_futures, deadline, collect_pair)
    finally:
        probe_pool.shutdown(wait=False, cancel_futures=True)

    if not pairs:
        if time.monotonic() > deadline:
            return [], "deadline exceeded probing UC schemas for functions"
        return [], "no UC schemas with user functions found"
    return sorted(pairs), None


def discover_claude_models(workspace: str, token: str) -> tuple[dict[str, str], str | None]:
"""Discover Claude families on this workspace's AI Gateway.

Expand Down
Loading
Loading