Skip to content

Commit e3c44bf

Browse files
committed
feat: add support for all models to agent config
1 parent 50b2c79 commit e3c44bf

File tree

8 files changed

+824
-5
lines changed

8 files changed

+824
-5
lines changed

src/strands/experimental/agent_config.py

Lines changed: 142 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,33 @@
99
agent = config_to_agent("config.json")
1010
# Add tools that need code-based instantiation
1111
agent.tool_registry.process_tools([ToolWithConfigArg(HttpsConnection("localhost"))])
12+
13+
The ``model`` field supports two formats:
14+
15+
**String format (backward compatible — defaults to Bedrock):**
16+
{"model": "us.anthropic.claude-sonnet-4-20250514-v1:0"}
17+
18+
**Object format (supports all providers):**
19+
{
20+
"model": {
21+
"provider": "anthropic",
22+
"model_id": "claude-sonnet-4-20250514",
23+
"max_tokens": 10000,
24+
"client_args": {"api_key": "$ANTHROPIC_API_KEY"}
25+
}
26+
}
27+
28+
Environment variable references (``$VAR`` or ``${VAR}``) in model config values are resolved
29+
automatically before provider instantiation.
30+
31+
Note: The following constructor parameters cannot be specified from JSON because they require
32+
code-based instantiation: ``boto_session`` (Bedrock, SageMaker), ``client`` (OpenAI, Gemini),
33+
``gemini_tools`` (Gemini). Use ``region_name`` / ``client_args`` as JSON-friendly alternatives.
1234
"""
1335

1436
import json
37+
import os
38+
import re
1539
from pathlib import Path
1640
from typing import Any
1741

@@ -27,8 +51,25 @@
2751
"properties": {
2852
"name": {"description": "Name of the agent", "type": ["string", "null"], "default": None},
2953
"model": {
30-
"description": "The model ID to use for this agent. If not specified, uses the default model.",
31-
"type": ["string", "null"],
54+
"description": (
55+
"The model to use for this agent. Can be a string (Bedrock model_id) "
56+
"or an object with a 'provider' field for any supported provider."
57+
),
58+
"oneOf": [
59+
{"type": "string"},
60+
{"type": "null"},
61+
{
62+
"type": "object",
63+
"properties": {
64+
"provider": {
65+
"description": "The model provider name",
66+
"type": "string",
67+
}
68+
},
69+
"required": ["provider"],
70+
"additionalProperties": True,
71+
},
72+
],
3273
"default": None,
3374
},
3475
"prompt": {
@@ -50,6 +91,87 @@
5091
# Pre-compile validator for better performance
5192
_VALIDATOR = jsonschema.Draft7Validator(AGENT_CONFIG_SCHEMA)
5293

94+
# Pattern for matching environment variable references
95+
_ENV_VAR_PATTERN = re.compile(r"^\$\{([^}]+)\}$|^\$([A-Za-z_][A-Za-z0-9_]*)$")
96+
97+
# Provider name to model class name — resolved via strands.models lazy __getattr__
98+
PROVIDER_MAP: dict[str, str] = {
99+
"bedrock": "BedrockModel",
100+
"anthropic": "AnthropicModel",
101+
"openai": "OpenAIModel",
102+
"gemini": "GeminiModel",
103+
"ollama": "OllamaModel",
104+
"litellm": "LiteLLMModel",
105+
"mistral": "MistralModel",
106+
"llamaapi": "LlamaAPIModel",
107+
"llamacpp": "LlamaCppModel",
108+
"sagemaker": "SageMakerAIModel",
109+
"writer": "WriterModel",
110+
"openai_responses": "OpenAIResponsesModel",
111+
}
112+
113+
114+
def _resolve_env_vars(value: Any) -> Any:
115+
"""Recursively resolve environment variable references in config values.
116+
117+
String values matching ``$VAR_NAME`` or ``${VAR_NAME}`` are replaced with the
118+
corresponding environment variable value. Dicts and lists are traversed recursively.
119+
120+
Args:
121+
value: The value to resolve. Can be a string, dict, list, or any other type.
122+
123+
Returns:
124+
The resolved value with environment variable references replaced.
125+
126+
Raises:
127+
ValueError: If a referenced environment variable is not set.
128+
"""
129+
if isinstance(value, str):
130+
match = _ENV_VAR_PATTERN.match(value)
131+
if match:
132+
var_name = match.group(1) or match.group(2)
133+
env_value = os.environ.get(var_name)
134+
if env_value is None:
135+
raise ValueError(f"Environment variable '{var_name}' is not set")
136+
return env_value
137+
return value
138+
if isinstance(value, dict):
139+
return {k: _resolve_env_vars(v) for k, v in value.items()}
140+
if isinstance(value, list):
141+
return [_resolve_env_vars(item) for item in value]
142+
return value
143+
144+
145+
def _create_model_from_dict(model_config: dict[str, Any]) -> Any:
146+
"""Create a Model instance from a provider config dict.
147+
148+
Routes the config to the appropriate model class based on the ``provider`` field,
149+
then delegates to the class's ``from_dict`` method. All imports are lazy to avoid
150+
requiring optional dependencies that are not installed.
151+
152+
Args:
153+
model_config: Dict containing at least a ``provider`` key and provider-specific params.
154+
155+
Returns:
156+
A configured Model instance for the specified provider.
157+
158+
Raises:
159+
ValueError: If the provider name is not recognized.
160+
ImportError: If the provider's optional dependencies are not installed.
161+
"""
162+
config = model_config.copy()
163+
provider = config.pop("provider")
164+
165+
class_name = PROVIDER_MAP.get(provider)
166+
if class_name is None:
167+
supported = ", ".join(sorted(PROVIDER_MAP.keys()))
168+
raise ValueError(f"Unknown model provider: '{provider}'. Supported providers: {supported}")
169+
170+
from .. import models
171+
172+
model_cls = getattr(models, class_name)
173+
return model_cls.from_dict(config)
174+
53175

54176
def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> Any:
55177
"""Create an Agent from a configuration file or dictionary.
@@ -83,6 +205,12 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A
83205
Create agent from dictionary:
84206
>>> config = {"model": "anthropic.claude-3-5-sonnet-20241022-v2:0", "tools": ["calculator"]}
85207
>>> agent = config_to_agent(config)
208+
209+
Create agent with object model config:
210+
>>> config = {
211+
... "model": {"provider": "openai", "model_id": "gpt-4o", "client_args": {"api_key": "$OPENAI_API_KEY"}}
212+
... }
213+
>>> agent = config_to_agent(config)
86214
"""
87215
# Parse configuration
88216
if isinstance(config, str):
@@ -114,11 +242,20 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A
114242
raise ValueError(f"Configuration validation error at {error_path}: {e.message}") from e
115243

116244
# Prepare Agent constructor arguments
117-
agent_kwargs = {}
245+
agent_kwargs: dict[str, Any] = {}
246+
247+
# Handle model field — string vs object format
248+
model_value = config_dict.get("model")
249+
if isinstance(model_value, dict):
250+
# Object format: resolve env vars and create Model instance via factory
251+
resolved_config = _resolve_env_vars(model_value)
252+
agent_kwargs["model"] = _create_model_from_dict(resolved_config)
253+
elif model_value is not None:
254+
# String format (backward compat): pass directly as model_id to Agent
255+
agent_kwargs["model"] = model_value
118256

119-
# Map configuration keys to Agent constructor parameters
257+
# Map remaining configuration keys to Agent constructor parameters
120258
config_mapping = {
121-
"model": "model",
122259
"prompt": "system_prompt",
123260
"tools": "tools",
124261
"name": "name",

src/strands/models/bedrock.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,32 @@ class BedrockConfig(TypedDict, total=False):
127127
temperature: float | None
128128
top_p: float | None
129129

130+
@classmethod
131+
def from_dict(cls, config: dict[str, Any]) -> "BedrockModel":
132+
"""Create a BedrockModel from a configuration dictionary.
133+
134+
Handles extraction of ``region_name``, ``endpoint_url``, and conversion of
135+
``boto_client_config`` from a plain dict to ``botocore.config.Config``.
136+
137+
Args:
138+
config: Model configuration dictionary.
139+
140+
Returns:
141+
A configured BedrockModel instance.
142+
"""
143+
kwargs: dict[str, Any] = {}
144+
145+
if "region_name" in config:
146+
kwargs["region_name"] = config.pop("region_name")
147+
if "endpoint_url" in config:
148+
kwargs["endpoint_url"] = config.pop("endpoint_url")
149+
if "boto_client_config" in config:
150+
raw = config.pop("boto_client_config")
151+
kwargs["boto_client_config"] = BotocoreConfig(**raw) if isinstance(raw, dict) else raw
152+
153+
kwargs.update(config)
154+
return cls(**kwargs)
155+
130156
def __init__(
131157
self,
132158
*,

src/strands/models/llamacpp.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,26 @@ class LlamaCppConfig(TypedDict, total=False):
131131
model_id: str
132132
params: dict[str, Any] | None
133133

134+
@classmethod
135+
def from_dict(cls, config: dict[str, Any]) -> "LlamaCppModel":
136+
"""Create a LlamaCppModel from a configuration dictionary.
137+
138+
Handles extraction of ``base_url`` and ``timeout`` as separate constructor parameters.
139+
140+
Args:
141+
config: Model configuration dictionary.
142+
143+
Returns:
144+
A configured LlamaCppModel instance.
145+
"""
146+
kwargs: dict[str, Any] = {}
147+
if "base_url" in config:
148+
kwargs["base_url"] = config.pop("base_url")
149+
if "timeout" in config:
150+
kwargs["timeout"] = config.pop("timeout")
151+
kwargs.update(config)
152+
return cls(**kwargs)
153+
134154
def __init__(
135155
self,
136156
base_url: str = "http://localhost:8080",

src/strands/models/mistral.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,28 @@ class MistralConfig(TypedDict, total=False):
5353
top_p: float | None
5454
stream: bool | None
5555

56+
@classmethod
57+
def from_dict(cls, config: dict[str, Any]) -> "MistralModel":
58+
"""Create a MistralModel from a configuration dictionary.
59+
60+
Handles extraction of ``api_key`` and ``client_args`` as separate constructor parameters.
61+
62+
Args:
63+
config: Model configuration dictionary.
64+
65+
Returns:
66+
A configured MistralModel instance.
67+
"""
68+
api_key = config.pop("api_key", None)
69+
client_args = config.pop("client_args", None)
70+
kwargs: dict[str, Any] = {}
71+
if api_key is not None:
72+
kwargs["api_key"] = api_key
73+
if client_args is not None:
74+
kwargs["client_args"] = client_args
75+
kwargs.update(config)
76+
return cls(**kwargs)
77+
5678
def __init__(
5779
self,
5880
api_key: str | None = None,

src/strands/models/model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Abstract base class for Agent model providers."""
22

3+
from __future__ import annotations
4+
35
import abc
46
import logging
57
from collections.abc import AsyncGenerator, AsyncIterable
@@ -51,6 +53,27 @@ def stateful(self) -> bool:
5153
"""
5254
return False
5355

56+
@classmethod
57+
def from_dict(cls, config: dict[str, Any]) -> Model:
58+
"""Create a Model instance from a configuration dictionary.
59+
60+
The default implementation extracts ``client_args`` (if present) and passes
61+
all remaining keys as keyword arguments to the constructor. Subclasses with
62+
non-standard constructor signatures should override this method.
63+
64+
Args:
65+
config: Provider-specific configuration dictionary.
66+
67+
Returns:
68+
A configured Model instance.
69+
"""
70+
client_args = config.pop("client_args", None)
71+
kwargs: dict[str, Any] = {}
72+
if client_args is not None:
73+
kwargs["client_args"] = client_args
74+
kwargs.update(config)
75+
return cls(**kwargs)
76+
5477
@abc.abstractmethod
5578
# pragma: no cover
5679
def update_config(self, **model_config: Any) -> None:

src/strands/models/ollama.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,27 @@ class OllamaConfig(TypedDict, total=False):
5656
temperature: float | None
5757
top_p: float | None
5858

59+
@classmethod
60+
def from_dict(cls, config: dict[str, Any]) -> "OllamaModel":
61+
"""Create an OllamaModel from a configuration dictionary.
62+
63+
Handles extraction of ``host`` as a positional argument and mapping of
64+
``client_args`` to the ``ollama_client_args`` constructor parameter.
65+
66+
Args:
67+
config: Model configuration dictionary.
68+
69+
Returns:
70+
A configured OllamaModel instance.
71+
"""
72+
host = config.pop("host", None)
73+
client_args = config.pop("client_args", None)
74+
kwargs: dict[str, Any] = {}
75+
if client_args is not None:
76+
kwargs["ollama_client_args"] = client_args
77+
kwargs.update(config)
78+
return cls(host, **kwargs)
79+
5980
def __init__(
6081
self,
6182
host: str | None,

src/strands/models/sagemaker.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,27 @@ class SageMakerAIEndpointConfig(TypedDict, total=False):
133133
target_variant: str | None | None
134134
additional_args: dict[str, Any] | None
135135

136+
@classmethod
137+
def from_dict(cls, config: dict[str, Any]) -> "SageMakerAIModel":
138+
"""Create a SageMakerAIModel from a configuration dictionary.
139+
140+
Handles extraction of ``endpoint_config``, ``payload_config``, and conversion of
141+
``boto_client_config`` from a plain dict to ``botocore.config.Config``.
142+
143+
Args:
144+
config: Model configuration dictionary.
145+
146+
Returns:
147+
A configured SageMakerAIModel instance.
148+
"""
149+
kwargs: dict[str, Any] = {}
150+
kwargs["endpoint_config"] = config.pop("endpoint_config", {})
151+
kwargs["payload_config"] = config.pop("payload_config", {})
152+
if "boto_client_config" in config:
153+
raw = config.pop("boto_client_config")
154+
kwargs["boto_client_config"] = BotocoreConfig(**raw) if isinstance(raw, dict) else raw
155+
return cls(**kwargs)
156+
136157
def __init__(
137158
self,
138159
endpoint_config: SageMakerAIEndpointConfig,

0 commit comments

Comments
 (0)