Skip to content

Commit 6e1ea26

Browse files
chore: move resolve_default_pricing to ModelPricing
1 parent aa2cf3c commit 6e1ea26

4 files changed

Lines changed: 97 additions & 68 deletions

File tree

src/askui/model_providers/anthropic_vlm_provider.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717
from askui.models.shared.prompts import SystemPrompt
1818
from askui.models.shared.tools import ToolCollection
19-
from askui.utils.model_pricing import ModelPricing, resolve_default_pricing
19+
from askui.utils.model_pricing import ModelPricing
2020

2121
_DEFAULT_MODEL_ID = "claude-sonnet-4-6"
2222

@@ -80,17 +80,11 @@ def __init__(
8080
base_url=base_url,
8181
auth_token=auth_token,
8282
)
83-
self._pricing: ModelPricing | None
84-
if (
85-
input_cost_per_million_tokens is not None
86-
and output_cost_per_million_tokens is not None
87-
):
88-
self._pricing = ModelPricing(
89-
input_cost_per_million_tokens=input_cost_per_million_tokens,
90-
output_cost_per_million_tokens=output_cost_per_million_tokens,
91-
)
92-
else:
93-
self._pricing = resolve_default_pricing(self._model_id_value)
83+
self._pricing = ModelPricing.for_model(
84+
self._model_id_value,
85+
input_cost_per_million_tokens=input_cost_per_million_tokens,
86+
output_cost_per_million_tokens=output_cost_per_million_tokens,
87+
)
9488

9589
@property
9690
@override

src/askui/model_providers/askui_vlm_provider.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from askui.models.shared.prompts import SystemPrompt
1919
from askui.models.shared.tools import ToolCollection
20-
from askui.utils.model_pricing import ModelPricing, resolve_default_pricing
20+
from askui.utils.model_pricing import ModelPricing
2121

2222
_DEFAULT_MODEL_ID = "claude-sonnet-4-6"
2323

@@ -72,17 +72,11 @@ def __init__(
7272
model_id or os.environ.get("VLM_PROVIDER_MODEL_ID") or _DEFAULT_MODEL_ID
7373
)
7474
self._injected_client = client
75-
self._pricing: ModelPricing | None
76-
if (
77-
input_cost_per_million_tokens is not None
78-
and output_cost_per_million_tokens is not None
79-
):
80-
self._pricing = ModelPricing(
81-
input_cost_per_million_tokens=input_cost_per_million_tokens,
82-
output_cost_per_million_tokens=output_cost_per_million_tokens,
83-
)
84-
else:
85-
self._pricing = resolve_default_pricing(self._model_id_value)
75+
self._pricing = ModelPricing.for_model(
76+
self._model_id_value,
77+
input_cost_per_million_tokens=input_cost_per_million_tokens,
78+
output_cost_per_million_tokens=output_cost_per_million_tokens,
79+
)
8680

8781
@property
8882
@override

src/askui/utils/model_pricing.py

Lines changed: 58 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from pydantic import BaseModel
44

5+
_DEFAULT_PRICING: dict[str, "ModelPricing"] = {}
6+
57

68
class ModelPricing(BaseModel):
79
"""Cost per 1 million tokens for a model.
@@ -16,42 +18,62 @@ class ModelPricing(BaseModel):
1618
output_cost_per_million_tokens: float
1719
currency: str = "USD"
1820

21+
@classmethod
22+
def for_model(
23+
cls,
24+
model_id: str,
25+
input_cost_per_million_tokens: float | None = None,
26+
output_cost_per_million_tokens: float | None = None,
27+
) -> "ModelPricing | None":
28+
"""Resolve pricing for a model.
1929
20-
_DEFAULT_PRICING: dict[str, ModelPricing] = {
21-
"claude-haiku-4-5-20251001": ModelPricing(
22-
input_cost_per_million_tokens=1.0,
23-
output_cost_per_million_tokens=5.0,
24-
),
25-
"claude-sonnet-4-5-20250929": ModelPricing(
26-
input_cost_per_million_tokens=3.0,
27-
output_cost_per_million_tokens=15.0,
28-
),
29-
"claude-opus-4-5-20251101": ModelPricing(
30-
input_cost_per_million_tokens=5.0,
31-
output_cost_per_million_tokens=25.0,
32-
),
33-
"claude-sonnet-4-6": ModelPricing(
34-
input_cost_per_million_tokens=3.0,
35-
output_cost_per_million_tokens=15.0,
36-
),
37-
"claude-opus-4-6": ModelPricing(
38-
input_cost_per_million_tokens=5.0,
39-
output_cost_per_million_tokens=25.0,
40-
),
41-
}
42-
43-
44-
def resolve_default_pricing(model_id: str) -> ModelPricing | None:
45-
"""Resolve default pricing for a model ID by prefix matching.
46-
47-
Tries exact match first, then longest-prefix match.
30+
If both cost parameters are provided, creates a ``ModelPricing``
31+
with those values. Otherwise, looks up built-in defaults by
32+
``model_id``.
4833
49-
Args:
50-
model_id (str): The model identifier.
34+
Args:
35+
model_id (str): The model identifier.
36+
input_cost_per_million_tokens (float | None, optional): Override
37+
cost in USD per 1M input tokens.
38+
output_cost_per_million_tokens (float | None, optional): Override
39+
cost in USD per 1M output tokens.
5140
52-
Returns:
53-
ModelPricing | None: Default pricing, or ``None`` if no match found.
54-
"""
55-
if model_id in _DEFAULT_PRICING:
56-
return _DEFAULT_PRICING[model_id]
57-
return None
41+
Returns:
42+
ModelPricing | None: Resolved pricing, or ``None`` if no match
43+
and no overrides provided.
44+
"""
45+
if (
46+
input_cost_per_million_tokens is not None
47+
and output_cost_per_million_tokens is not None
48+
):
49+
return cls(
50+
input_cost_per_million_tokens=input_cost_per_million_tokens,
51+
output_cost_per_million_tokens=output_cost_per_million_tokens,
52+
)
53+
return _DEFAULT_PRICING.get(model_id)
54+
55+
56+
_DEFAULT_PRICING.update(
57+
{
58+
"claude-haiku-4-5-20251001": ModelPricing(
59+
input_cost_per_million_tokens=1.0,
60+
output_cost_per_million_tokens=5.0,
61+
),
62+
"claude-sonnet-4-5-20250929": ModelPricing(
63+
input_cost_per_million_tokens=3.0,
64+
output_cost_per_million_tokens=15.0,
65+
),
66+
"claude-opus-4-5-20251101": ModelPricing(
67+
input_cost_per_million_tokens=5.0,
68+
output_cost_per_million_tokens=25.0,
69+
),
70+
"claude-sonnet-4-6": ModelPricing(
71+
input_cost_per_million_tokens=3.0,
72+
output_cost_per_million_tokens=15.0,
73+
),
74+
"claude-opus-4-6": ModelPricing(
75+
input_cost_per_million_tokens=5.0,
76+
output_cost_per_million_tokens=25.0,
77+
),
78+
}
79+
)

tests/unit/model_providers/test_model_pricing.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,55 @@
77

88
from askui.models.shared.agent_message_param import UsageParam
99
from askui.models.shared.usage_tracking_callback import UsageTrackingCallback
10-
from askui.utils.model_pricing import ModelPricing, resolve_default_pricing
10+
from askui.utils.model_pricing import ModelPricing
1111

1212

13-
class TestResolveDefaultPricing:
13+
class TestModelPricingForModel:
1414
def test_exact_match_sonnet_4_6(self) -> None:
15-
pricing = resolve_default_pricing("claude-sonnet-4-6")
15+
pricing = ModelPricing.for_model("claude-sonnet-4-6")
1616
assert pricing is not None
1717
assert pricing.input_cost_per_million_tokens == 3.0
1818
assert pricing.output_cost_per_million_tokens == 15.0
1919

2020
def test_exact_match_opus_4_6(self) -> None:
21-
pricing = resolve_default_pricing("claude-opus-4-6")
21+
pricing = ModelPricing.for_model("claude-opus-4-6")
2222
assert pricing is not None
2323
assert pricing.input_cost_per_million_tokens == 5.0
2424
assert pricing.output_cost_per_million_tokens == 25.0
2525

2626
def test_exact_match_haiku(self) -> None:
27-
pricing = resolve_default_pricing("claude-haiku-4-5-20251001")
27+
pricing = ModelPricing.for_model("claude-haiku-4-5-20251001")
2828
assert pricing is not None
2929
assert pricing.input_cost_per_million_tokens == 1.0
3030
assert pricing.output_cost_per_million_tokens == 5.0
3131

3232
def test_unknown_model_returns_none(self) -> None:
33-
assert resolve_default_pricing("unknown-model-v1") is None
33+
assert ModelPricing.for_model("unknown-model-v1") is None
3434

3535
def test_empty_string_returns_none(self) -> None:
36-
assert resolve_default_pricing("") is None
36+
assert ModelPricing.for_model("") is None
3737

3838
def test_partial_model_id_returns_none(self) -> None:
39-
assert resolve_default_pricing("claude-sonnet-4") is None
39+
assert ModelPricing.for_model("claude-sonnet-4") is None
40+
41+
def test_override_costs(self) -> None:
42+
pricing = ModelPricing.for_model(
43+
"claude-sonnet-4-6",
44+
input_cost_per_million_tokens=99.0,
45+
output_cost_per_million_tokens=199.0,
46+
)
47+
assert pricing is not None
48+
assert pricing.input_cost_per_million_tokens == 99.0
49+
assert pricing.output_cost_per_million_tokens == 199.0
50+
51+
def test_override_costs_unknown_model(self) -> None:
52+
pricing = ModelPricing.for_model(
53+
"unknown-model",
54+
input_cost_per_million_tokens=1.0,
55+
output_cost_per_million_tokens=2.0,
56+
)
57+
assert pricing is not None
58+
assert pricing.input_cost_per_million_tokens == 1.0
4059

4160

4261
def _get_usage_dict(reporter_mock: MagicMock) -> dict[str, Any]:

0 commit comments

Comments
 (0)