Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 8 additions & 21 deletions src/xe_forge/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,14 @@


def _setup_dspy(config: Config) -> None:
"""Configure DSPy LM from the shared config. Used by all LLM-driven paths."""
import dspy
import httpx
import litellm

if config.llm.api_base:
os.environ["OPENAI_API_BASE"] = config.llm.api_base
if config.llm.api_key:
os.environ["OPENAI_API_KEY"] = config.llm.api_key

litellm.client_session = httpx.Client(verify=False)
lm = dspy.LM(
model=config.llm.model,
api_base=config.llm.api_base,
model_type="responses",
api_key=config.llm.api_key or "",
temperature=config.llm.temperature,
max_tokens=config.llm.max_tokens,
cache=False,
)
dspy.configure(lm=lm, warn_on_type_mismatch=False)
"""Configure DSPy LM from the shared config. Used by all LLM-driven paths.

Supports OpenAI-compatible endpoints and AWS Bedrock (model ids prefixed
with ``bedrock/``). See :mod:`xe_forge.llm_setup`.
"""
from xe_forge.llm_setup import configure_dspy

configure_dspy(config.llm)


def _parse_args():
Expand Down
66 changes: 66 additions & 0 deletions src/xe_forge/llm_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Shared DSPy / LiteLLM configuration for all LLM-driven paths.

Supports OpenAI-compatible endpoints (the default) and AWS Bedrock. Bedrock is
selected automatically when the configured model id is prefixed with
``bedrock/`` (e.g. ``bedrock/us.anthropic.claude-sonnet-4-6``).

No credentials are read, logged, or stored here. For Bedrock, auth is handled
entirely by LiteLLM via the standard AWS environment variables
(``AWS_BEARER_TOKEN_BEDROCK`` or ``AWS_ACCESS_KEY_ID``/``AWS_SECRET_ACCESS_KEY``),
and the region is resolved from ``AWS_REGION_NAME``/``AWS_REGION``.
"""

from __future__ import annotations

import os


def is_bedrock(model: str) -> bool:
"""Return True if ``model`` targets AWS Bedrock."""
return model.startswith("bedrock/")


def configure_dspy(llm_config) -> None:
"""Build the DSPy LM from an LLM config section and register it globally.

The OpenAI-compatible path is unchanged from the original inline setup.
Bedrock uses the Converse (chat) API instead of the OpenAI Responses API
and does not take an OpenAI ``api_base``; its region comes from the
environment (default ``us-east-1``).
"""
import dspy
import httpx
import litellm

# TLS may terminate at a MITM proxy; keep verification disabled (as before)
# and mirror it onto the async client used by Bedrock/Converse calls.
litellm.ssl_verify = False
litellm.client_session = httpx.Client(verify=False)
litellm.aclient_session = httpx.AsyncClient(verify=False)

if is_bedrock(llm_config.model):
region = os.environ.get("AWS_REGION_NAME") or os.environ.get("AWS_REGION") or "us-east-1"
lm = dspy.LM(
model=llm_config.model,
model_type="chat",
temperature=llm_config.temperature,
max_tokens=llm_config.max_tokens,
cache=False,
aws_region_name=region,
)
else:
if llm_config.api_base:
os.environ["OPENAI_API_BASE"] = llm_config.api_base
if llm_config.api_key:
os.environ["OPENAI_API_KEY"] = llm_config.api_key
lm = dspy.LM(
model=llm_config.model,
api_base=llm_config.api_base,
model_type="responses",
api_key=llm_config.api_key or "",
temperature=llm_config.temperature,
max_tokens=llm_config.max_tokens,
cache=False,
)

dspy.configure(lm=lm, warn_on_type_mismatch=False)
23 changes: 3 additions & 20 deletions src/xe_forge/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import logging
import os
from datetime import datetime
from pathlib import Path

import dspy
import httpx
import litellm

from xe_forge.agents import AnalyzerAgent, Optimizer, OptimizerAgent, OptimizerReActAgent
from xe_forge.config import Config, get_config
from xe_forge.core.device_query import get_device_config_for_pipeline
Expand Down Expand Up @@ -129,22 +124,10 @@ def _setup_logging(self):
Path(self.config.logging.kernel_dir).mkdir(parents=True, exist_ok=True)

def _setup_llm(self):
if self.config.llm.api_base:
os.environ["OPENAI_API_BASE"] = self.config.llm.api_base
if self.config.llm.api_key:
os.environ["OPENAI_API_KEY"] = self.config.llm.api_key
from xe_forge.llm_setup import configure_dspy

try:
litellm.client_session = httpx.Client(verify=False)
lm = dspy.LM(
model=self.config.llm.model,
api_base=self.config.llm.api_base,
model_type="responses",
api_key=self.config.llm.api_key or "",
temperature=self.config.llm.temperature,
max_tokens=self.config.llm.max_tokens,
cache=False,
)
dspy.configure(lm=lm, warn_on_type_mismatch=False)
configure_dspy(self.config.llm)
except Exception as e:
raise RuntimeError(f"Failed to initialize LLM: {e}") from e

Expand Down