Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,6 @@ $RECYCLE.BIN/
.pytest_cache/
software-agent-sdk/
snapshot_report.html

# Local developer overrides (never commit)
.local.env
161 changes: 137 additions & 24 deletions openhands_cli/stores/agent_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
AgentContext,
LLMSummarizingCondenser,
LocalFileStore,
create_llm,
)
from openhands.sdk.context import load_project_skills
from openhands.sdk.conversation.persistence_const import BASE_STATE
Expand Down Expand Up @@ -132,6 +133,8 @@ def get_default_critic(llm: LLM, *, enable_critic: bool = True) -> CriticBase |
ENV_LLM_API_KEY = "LLM_API_KEY"
ENV_LLM_BASE_URL = "LLM_BASE_URL"
ENV_LLM_MODEL = "LLM_MODEL"
ENV_DATABRICKS_HOST = "DATABRICKS_HOST"
ENV_DATABRICKS_TOKEN = "DATABRICKS_TOKEN"


class MissingEnvironmentVariablesError(Exception):
Expand Down Expand Up @@ -165,6 +168,10 @@ def check_and_warn_env_vars() -> None:
env_vars_set.append(ENV_LLM_BASE_URL)
if os.environ.get(ENV_LLM_MODEL):
env_vars_set.append(ENV_LLM_MODEL)
if os.environ.get(ENV_DATABRICKS_HOST):
env_vars_set.append(ENV_DATABRICKS_HOST)
if os.environ.get(ENV_DATABRICKS_TOKEN):
env_vars_set.append(ENV_DATABRICKS_TOKEN)

if env_vars_set:
console = Console(stderr=True)
Expand All @@ -184,13 +191,19 @@ class LLMEnvOverrides(BaseModel):
Environment variables take precedence over stored settings and are
NOT persisted to disk (temporary override only).

Databricks: ``DATABRICKS_HOST`` / ``DATABRICKS_TOKEN`` supplement ``LLM_*``.
M2M (``DATABRICKS_CLIENT_ID`` / ``DATABRICKS_CLIENT_SECRET``) is read only by
the SDK from the process environment, not through this model.

Use the `from_env()` class method to load values from environment
variables when env overrides are enabled.
"""

api_key: SecretStr | None = None
base_url: str | None = None
model: str | None = None
databricks_host: str | None = None
databricks_token: str | None = None

@classmethod
def from_env(cls, enabled: bool = False) -> LLMEnvOverrides:
Expand Down Expand Up @@ -221,10 +234,48 @@ def from_env(cls, enabled: bool = False) -> LLMEnvOverrides:
if model:
result["model"] = model

db_host = os.environ.get(ENV_DATABRICKS_HOST) or None
if db_host:
result["databricks_host"] = db_host.strip().rstrip("/")

db_tok = os.environ.get(ENV_DATABRICKS_TOKEN) or None
if db_tok:
result["databricks_token"] = db_tok

return cls(**result)

def to_llm_kwargs(self) -> dict[str, Any]:
"""Map overrides to ``create_llm()`` / ``LLM`` kwargs."""
kwargs: dict[str, Any] = {}
is_db = (self.model or "").startswith("databricks/")
if self.model:
kwargs["model"] = self.model
if self.api_key is not None:
kwargs["api_key"] = self.api_key
elif is_db and self.databricks_token:
kwargs["api_key"] = SecretStr(self.databricks_token)
merged_base = self.base_url or self.databricks_host
if merged_base:
kwargs["base_url"] = merged_base
if self.databricks_host:
kwargs["databricks_host"] = self.databricks_host
return kwargs

def require_for_headless(self) -> None:
missing: list[str] = []
"""Validate env overrides for headless agent creation."""
is_databricks = (self.model or "").startswith("databricks/")
has_host = bool(self.databricks_host or self.base_url)
if is_databricks:
missing: list[str] = []
if self.model is None:
missing.append(ENV_LLM_MODEL)
if not has_host:
missing.append(f"{ENV_DATABRICKS_HOST} or {ENV_LLM_BASE_URL}")
if missing:
raise MissingEnvironmentVariablesError(missing)
return

missing = []
if self.api_key is None:
missing.append(ENV_LLM_API_KEY)
if self.model is None:
Expand All @@ -234,12 +285,23 @@ def require_for_headless(self) -> None:

def has_overrides(self) -> bool:
"""Check if any overrides are set."""
return any([self.api_key, self.base_url, self.model])
return any(
[
self.api_key,
self.base_url,
self.model,
self.databricks_host,
self.databricks_token,
]
)


def apply_llm_overrides(llm: LLM, overrides: LLMEnvOverrides) -> LLM:
"""Apply environment variable overrides to an LLM instance.

Rebuilds via ``create_llm`` for Databricks native models so private client
state is not stale (``model_copy`` skips ``DatabricksLLM`` init).

Args:
llm: The LLM instance to update
overrides: LLMEnvOverrides instance from get_env_llm_overrides()
Expand All @@ -250,7 +312,46 @@ def apply_llm_overrides(llm: LLM, overrides: LLMEnvOverrides) -> LLM:
if not overrides.has_overrides():
return llm

return llm.model_copy(update=overrides.model_dump(exclude_none=True))
kw = overrides.to_llm_kwargs()
target_model = kw.get("model", llm.model)

is_databricks_model = isinstance(target_model, str) and target_model.startswith(
"databricks/"
)
is_databricks_instance = False
try:
from openhands.sdk.llm.providers.databricks.llm import DatabricksLLM

is_databricks_instance = isinstance(llm, DatabricksLLM)
except ImportError:
pass

if is_databricks_instance or is_databricks_model:
if is_databricks_instance:
# Close the existing httpx client before discarding the instance to
# prevent connection pool leaks (model_copy skips __init__).
try:
llm.close() # type: ignore[union-attr]
except Exception:
pass
base = llm.model_dump(exclude_none=True)
base.pop("provider", None)
for key, val in kw.items():
if val is not None:
base[key] = val
filtered = {k: v for k, v in base.items() if v is not None}
return create_llm(**filtered)

patch: dict[str, Any] = {}
if overrides.api_key is not None:
patch["api_key"] = overrides.api_key
if overrides.base_url is not None:
patch["base_url"] = overrides.base_url
if overrides.model is not None:
patch["model"] = overrides.model
if not patch:
return llm
return llm.model_copy(update=patch)


class AgentStore:
Expand Down Expand Up @@ -289,15 +390,9 @@ def _ensure_agent(self, agent: Agent | None, overrides: LLMEnvOverrides) -> Agen

# In env override mode, require enough info to create an agent.
overrides.require_for_headless()
assert overrides.api_key is not None
assert overrides.model is not None

llm = LLM(
model=overrides.model,
api_key=overrides.api_key.get_secret_value(),
base_url=overrides.base_url,
usage_id="agent",
)
llm_kwargs = overrides.to_llm_kwargs()
llm_kwargs["usage_id"] = "agent"
llm = create_llm(**llm_kwargs)
return get_default_cli_agent(llm)

def _apply_env_overrides(self, agent: Agent, overrides: LLMEnvOverrides) -> Agent:
Expand Down Expand Up @@ -493,19 +588,37 @@ def create_and_save_from_settings(
model = settings.get("llm_model", default_model)
base_url = settings.get("llm_base_url")

llm = LLM(
model=model,
api_key=llm_api_key,
base_url=base_url,
usage_id="agent",
)
if isinstance(model, str) and model.startswith("databricks/"):
from types import SimpleNamespace

condenser_llm = LLM(
model=model,
api_key=llm_api_key,
base_url=base_url,
usage_id="condenser",
)
from openhands.sdk.llm.providers.databricks.settings_bridge import (
kwargs_from_settings,
)

db_settings = SimpleNamespace(
model=model,
api_key=llm_api_key if llm_api_key else None,
base_url=base_url,
databricks_host=base_url if base_url else None,
)
llm = create_llm(**kwargs_from_settings(db_settings, usage_id="agent"))
condenser_llm = create_llm(
**kwargs_from_settings(db_settings, usage_id="condenser")
)
else:
llm = LLM(
model=model,
api_key=llm_api_key,
base_url=base_url,
usage_id="agent",
)

condenser_llm = LLM(
model=model,
api_key=llm_api_key,
base_url=base_url,
usage_id="condenser",
)

condenser = LLMSummarizingCondenser(llm=condenser_llm)

Expand Down
Loading