Skip to content

Commit c723e52

Browse files
authored
feat: add context_window_limit to model configs (#2176)
1 parent 724b591 commit c723e52

18 files changed

Lines changed: 105 additions & 34 deletions

src/strands/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
from . import bedrock, model
99
from .bedrock import BedrockModel
10-
from .model import CacheConfig, Model
10+
from .model import BaseModelConfig, CacheConfig, Model
1111

1212
__all__ = [
1313
"bedrock",
1414
"model",
15+
"BaseModelConfig",
1516
"BedrockModel",
1617
"CacheConfig",
1718
"Model",

src/strands/models/anthropic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
import mimetypes
1010
from collections.abc import AsyncGenerator
11-
from typing import Any, TypedDict, TypeVar, cast
11+
from typing import Any, TypeVar, cast
1212

1313
import anthropic
1414
from pydantic import BaseModel
@@ -21,7 +21,7 @@
2121
from ..types.streaming import StreamEvent
2222
from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec
2323
from ._validation import _has_location_source, validate_config_keys
24-
from .model import Model
24+
from .model import BaseModelConfig, Model
2525

2626
logger = logging.getLogger(__name__)
2727

@@ -46,7 +46,7 @@ class AnthropicModel(Model):
4646
"input and output tokens exceed your context limit",
4747
}
4848

49-
class AnthropicConfig(TypedDict, total=False):
49+
class AnthropicConfig(BaseModelConfig, total=False):
5050
"""Configuration options for Anthropic models.
5151
5252
Attributes:

src/strands/models/bedrock.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from botocore.config import Config as BotocoreConfig
1616
from botocore.exceptions import ClientError
1717
from pydantic import BaseModel
18-
from typing_extensions import TypedDict, Unpack, override
18+
from typing_extensions import Unpack, override
1919

2020
from strands.types.media import S3Location, SourceLocation
2121

@@ -31,7 +31,7 @@
3131
from ..types.streaming import CitationsDelta, StreamEvent
3232
from ..types.tools import ToolChoice, ToolSpec
3333
from ._validation import validate_config_keys
34-
from .model import CacheConfig, Model
34+
from .model import BaseModelConfig, CacheConfig, Model
3535

3636
logger = logging.getLogger(__name__)
3737

@@ -69,7 +69,7 @@ class BedrockModel(Model):
6969
- Context window overflow detection
7070
"""
7171

72-
class BedrockConfig(TypedDict, total=False):
72+
class BedrockConfig(BaseModelConfig, total=False):
7373
"""Configuration options for Bedrock models.
7474
7575
Attributes:

src/strands/models/gemini.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import mimetypes
1010
import secrets
1111
from collections.abc import AsyncGenerator
12-
from typing import Any, TypedDict, TypeVar, cast
12+
from typing import Any, TypeVar, cast
1313

1414
import pydantic
1515
from google import genai
@@ -20,7 +20,7 @@
2020
from ..types.streaming import StreamEvent
2121
from ..types.tools import ToolChoice, ToolSpec
2222
from ._validation import _has_location_source, validate_config_keys
23-
from .model import Model
23+
from .model import BaseModelConfig, Model
2424

2525
logger = logging.getLogger(__name__)
2626

@@ -33,7 +33,7 @@ class GeminiModel(Model):
3333
- Docs: https://ai.google.dev/api
3434
"""
3535

36-
class GeminiConfig(TypedDict, total=False):
36+
class GeminiConfig(BaseModelConfig, total=False):
3737
"""Configuration options for Gemini models.
3838
3939
Attributes:

src/strands/models/litellm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import json
77
import logging
88
from collections.abc import AsyncGenerator
9-
from typing import Any, TypedDict, TypeVar, cast
9+
from typing import Any, TypeVar, cast
1010

1111
import litellm
1212
from litellm.exceptions import ContextWindowExceededError
@@ -21,6 +21,7 @@
2121
from ..types.streaming import MetadataEvent, StreamEvent
2222
from ..types.tools import ToolChoice, ToolSpec, ToolUse
2323
from ._validation import validate_config_keys
24+
from .model import BaseModelConfig
2425
from .openai import OpenAIModel
2526

2627
logger = logging.getLogger(__name__)
@@ -35,7 +36,7 @@
3536
class LiteLLMModel(OpenAIModel):
3637
"""LiteLLM model provider implementation."""
3738

38-
class LiteLLMConfig(TypedDict, total=False):
39+
class LiteLLMConfig(BaseModelConfig, total=False):
3940
"""Configuration options for LiteLLM models.
4041
4142
Attributes:

src/strands/models/llamaapi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
import llama_api_client
1515
from llama_api_client import LlamaAPIClient
1616
from pydantic import BaseModel
17-
from typing_extensions import TypedDict, Unpack, override
17+
from typing_extensions import Unpack, override
1818

1919
from ..types.content import ContentBlock, Messages
2020
from ..types.exceptions import ModelThrottledException
2121
from ..types.streaming import StreamEvent, Usage
2222
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
2323
from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported
24-
from .model import Model
24+
from .model import BaseModelConfig, Model
2525

2626
logger = logging.getLogger(__name__)
2727

@@ -31,7 +31,7 @@
3131
class LlamaAPIModel(Model):
3232
"""Llama API model provider implementation."""
3333

34-
class LlamaConfig(TypedDict, total=False):
34+
class LlamaConfig(BaseModelConfig, total=False):
3535
"""Configuration options for Llama API models.
3636
3737
Attributes:

src/strands/models/llamacpp.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from collections.abc import AsyncGenerator
1818
from typing import (
1919
Any,
20-
TypedDict,
2120
TypeVar,
2221
cast,
2322
)
@@ -31,7 +30,7 @@
3130
from ..types.streaming import StreamEvent
3231
from ..types.tools import ToolChoice, ToolSpec
3332
from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported
34-
from .model import Model
33+
from .model import BaseModelConfig, Model
3534

3635
logger = logging.getLogger(__name__)
3736

@@ -86,7 +85,7 @@ class LlamaCppModel(Model):
8685
>>> response = agent(image_content)
8786
"""
8887

89-
class LlamaCppConfig(TypedDict, total=False):
88+
class LlamaCppConfig(BaseModelConfig, total=False):
9089
"""Configuration options for llama.cpp models.
9190
9291
Attributes:

src/strands/models/mistral.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111

1212
import mistralai
1313
from pydantic import BaseModel
14-
from typing_extensions import TypedDict, Unpack, override
14+
from typing_extensions import Unpack, override
1515

1616
from ..types.content import ContentBlock, Messages
1717
from ..types.exceptions import ModelThrottledException
1818
from ..types.streaming import StopReason, StreamEvent
1919
from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
2020
from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported
21-
from .model import Model
21+
from .model import BaseModelConfig, Model
2222

2323
logger = logging.getLogger(__name__)
2424

@@ -36,7 +36,7 @@ class MistralModel(Model):
3636
- System prompts
3737
"""
3838

39-
class MistralConfig(TypedDict, total=False):
39+
class MistralConfig(BaseModelConfig, total=False):
4040
"""Configuration parameters for Mistral models.
4141
4242
Attributes:

src/strands/models/model.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
from collections.abc import AsyncGenerator, AsyncIterable
66
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, Any, Literal, TypeVar
7+
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar
88

99
from pydantic import BaseModel
1010

@@ -22,6 +22,17 @@
2222
T = TypeVar("T", bound=BaseModel)
2323

2424

25+
class BaseModelConfig(TypedDict, total=False):
26+
"""Base configuration shared by all model providers.
27+
28+
Attributes:
29+
context_window_limit: Maximum context window size in tokens for the model.
30+
This value represents the total token capacity shared between input and output.
31+
"""
32+
33+
context_window_limit: int | None
34+
35+
2536
@dataclass
2637
class CacheConfig:
2738
"""Configuration for prompt caching.
@@ -51,6 +62,16 @@ def stateful(self) -> bool:
5162
"""
5263
return False
5364

65+
@property
66+
def context_window_limit(self) -> int | None:
67+
"""Maximum context window size in tokens, or None if not configured."""
68+
config = self.get_config()
69+
return (
70+
config.get("context_window_limit")
71+
if isinstance(config, dict)
72+
else getattr(config, "context_window_limit", None)
73+
)
74+
5475
@abc.abstractmethod
5576
# pragma: no cover
5677
def update_config(self, **model_config: Any) -> None:

src/strands/models/ollama.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010

1111
import ollama
1212
from pydantic import BaseModel
13-
from typing_extensions import TypedDict, Unpack, override
13+
from typing_extensions import Unpack, override
1414

1515
from ..types.content import ContentBlock, Messages
1616
from ..types.streaming import StopReason, StreamEvent
1717
from ..types.tools import ToolChoice, ToolSpec
1818
from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported
19-
from .model import Model
19+
from .model import BaseModelConfig, Model
2020

2121
logger = logging.getLogger(__name__)
2222

@@ -33,7 +33,7 @@ class OllamaModel(Model):
3333
- Tool/function calling
3434
"""
3535

36-
class OllamaConfig(TypedDict, total=False):
36+
class OllamaConfig(BaseModelConfig, total=False):
3737
"""Configuration parameters for Ollama models.
3838
3939
Attributes:

0 commit comments

Comments
 (0)