diff --git a/src/xe_forge/cli.py b/src/xe_forge/cli.py index b7aa7f6..042752d 100644 --- a/src/xe_forge/cli.py +++ b/src/xe_forge/cli.py @@ -22,27 +22,14 @@ def _setup_dspy(config: Config) -> None: - """Configure DSPy LM from the shared config. Used by all LLM-driven paths.""" - import dspy - import httpx - import litellm - - if config.llm.api_base: - os.environ["OPENAI_API_BASE"] = config.llm.api_base - if config.llm.api_key: - os.environ["OPENAI_API_KEY"] = config.llm.api_key - - litellm.client_session = httpx.Client(verify=False) - lm = dspy.LM( - model=config.llm.model, - api_base=config.llm.api_base, - model_type="responses", - api_key=config.llm.api_key or "", - temperature=config.llm.temperature, - max_tokens=config.llm.max_tokens, - cache=False, - ) - dspy.configure(lm=lm, warn_on_type_mismatch=False) + """Configure DSPy LM from the shared config. Used by all LLM-driven paths. + + Supports OpenAI-compatible endpoints and AWS Bedrock (model ids prefixed + with ``bedrock/``). See :mod:`xe_forge.llm_setup`. + """ + from xe_forge.llm_setup import configure_dspy + + configure_dspy(config.llm) def _parse_args(): diff --git a/src/xe_forge/llm_setup.py b/src/xe_forge/llm_setup.py new file mode 100644 index 0000000..a8df937 --- /dev/null +++ b/src/xe_forge/llm_setup.py @@ -0,0 +1,66 @@ +"""Shared DSPy / LiteLLM configuration for all LLM-driven paths. + +Supports OpenAI-compatible endpoints (the default) and AWS Bedrock. Bedrock is +selected automatically when the configured model id is prefixed with +``bedrock/`` (e.g. ``bedrock/us.anthropic.claude-sonnet-4-6``). + +No credentials are read, logged, or stored here. For Bedrock, auth is handled +entirely by LiteLLM via the standard AWS environment variables +(``AWS_BEARER_TOKEN_BEDROCK`` or ``AWS_ACCESS_KEY_ID``/``AWS_SECRET_ACCESS_KEY``), +and the region is resolved from ``AWS_REGION_NAME``/``AWS_REGION``. +""" + +from __future__ import annotations + +import os + + +def is_bedrock(model: str) -> bool: + """Return True if ``model`` targets AWS Bedrock.""" + return model.startswith("bedrock/") + + +def configure_dspy(llm_config) -> None: + """Build the DSPy LM from an LLM config section and register it globally. + + The OpenAI-compatible path is unchanged from the original inline setup. + Bedrock uses the Converse (chat) API instead of the OpenAI Responses API + and does not take an OpenAI ``api_base``; its region comes from the + environment (default ``us-east-1``). + """ + import dspy + import httpx + import litellm + + # TLS may terminate at a MITM proxy; keep verification disabled (as before) + # and mirror it onto the async client used by Bedrock/Converse calls. + litellm.ssl_verify = False + litellm.client_session = httpx.Client(verify=False) + litellm.aclient_session = httpx.AsyncClient(verify=False) + + if is_bedrock(llm_config.model): + region = os.environ.get("AWS_REGION_NAME") or os.environ.get("AWS_REGION") or "us-east-1" + lm = dspy.LM( + model=llm_config.model, + model_type="chat", + temperature=llm_config.temperature, + max_tokens=llm_config.max_tokens, + cache=False, + aws_region_name=region, + ) + else: + if llm_config.api_base: + os.environ["OPENAI_API_BASE"] = llm_config.api_base + if llm_config.api_key: + os.environ["OPENAI_API_KEY"] = llm_config.api_key + lm = dspy.LM( + model=llm_config.model, + api_base=llm_config.api_base, + model_type="responses", + api_key=llm_config.api_key or "", + temperature=llm_config.temperature, + max_tokens=llm_config.max_tokens, + cache=False, + ) + + dspy.configure(lm=lm, warn_on_type_mismatch=False) diff --git a/src/xe_forge/pipeline.py b/src/xe_forge/pipeline.py index 8e29caf..15acac7 100644 --- a/src/xe_forge/pipeline.py +++ b/src/xe_forge/pipeline.py @@ -1,12 +1,7 @@ import logging -import os from datetime import datetime from pathlib import Path -import dspy -import httpx -import litellm - from xe_forge.agents import AnalyzerAgent, Optimizer, OptimizerAgent, OptimizerReActAgent from xe_forge.config import Config, get_config from xe_forge.core.device_query import get_device_config_for_pipeline @@ -129,22 +124,10 @@ def _setup_logging(self): Path(self.config.logging.kernel_dir).mkdir(parents=True, exist_ok=True) def _setup_llm(self): - if self.config.llm.api_base: - os.environ["OPENAI_API_BASE"] = self.config.llm.api_base - if self.config.llm.api_key: - os.environ["OPENAI_API_KEY"] = self.config.llm.api_key + from xe_forge.llm_setup import configure_dspy + try: - litellm.client_session = httpx.Client(verify=False) - lm = dspy.LM( - model=self.config.llm.model, - api_base=self.config.llm.api_base, - model_type="responses", - api_key=self.config.llm.api_key or "", - temperature=self.config.llm.temperature, - max_tokens=self.config.llm.max_tokens, - cache=False, - ) - dspy.configure(lm=lm, warn_on_type_mismatch=False) + configure_dspy(self.config.llm) except Exception as e: raise RuntimeError(f"Failed to initialize LLM: {e}") from e