Skip to content

Commit a390a41

Browse files
committed
Fix loading princing error when using as a lib
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent abe8591 commit a390a41

3 files changed

Lines changed: 254 additions & 8 deletions

File tree

alphatrion/utils/pricing.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""LLM pricing utilities for cost calculation."""
22

33
import logging
4+
from importlib import resources
45
from pathlib import Path
56
from typing import Any
67

@@ -23,14 +24,36 @@ def load_pricing_config() -> dict[str, Any]:
2324
if _PRICING_CACHE is not None:
2425
return _PRICING_CACHE
2526

26-
config_path = Path(__file__).parent.parent.parent / "config" / "modelspec.yaml"
27-
2827
try:
29-
with open(config_path) as f:
30-
config = yaml.safe_load(f)
31-
_PRICING_CACHE = config
32-
logger.info(f"Loaded pricing config from {config_path}")
33-
return config
28+
# Try to load from package resources (when installed)
29+
try:
30+
if hasattr(resources, "files"):
31+
# Python 3.9+
32+
config_file = resources.files("alphatrion").joinpath(
33+
"config/modelspec.yaml"
34+
)
35+
config_data = config_file.read_text()
36+
else:
37+
# Python 3.7-3.8 fallback
38+
import importlib.resources as pkg_resources
39+
40+
config_data = pkg_resources.read_text(
41+
"alphatrion.config", "modelspec.yaml"
42+
)
43+
44+
config = yaml.safe_load(config_data)
45+
logger.info("Loaded pricing config from package resources")
46+
except (FileNotFoundError, ModuleNotFoundError):
47+
# Fall back to relative path (for development)
48+
config_path = (
49+
Path(__file__).parent.parent.parent / "config" / "modelspec.yaml"
50+
)
51+
with open(config_path) as f:
52+
config = yaml.safe_load(f)
53+
logger.info(f"Loaded pricing config from {config_path}")
54+
55+
_PRICING_CACHE = config
56+
return config
3457
except Exception as e:
3558
logger.error(f"Failed to load pricing config: {e}")
3659
raise

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ alphatrion = "alphatrion.server.cmd.main:main"
5151
requires = ["hatchling"]
5252
build-backend = "hatchling.build"
5353

54+
[tool.hatch.build.targets.wheel]
55+
packages = ["alphatrion"]
56+
57+
[tool.hatch.build.targets.wheel.force-include]
58+
"config/modelspec.yaml" = "alphatrion/config/modelspec.yaml"
59+
5460
# Configuration for ruff linter and formatter
5561

5662
[tool.ruff]
@@ -82,4 +88,4 @@ quote-style = "double"
8288
indent-style = "space"
8389

8490
[tool.ruff.lint.per-file-ignores]
85-
"tests/*" = ["PLR2004"]
91+
"tests/*" = ["PLR2004", "B017"]

tests/unit/utils/test_pricing.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""Tests for pricing utilities."""
2+
3+
from unittest import mock
4+
5+
import pytest
6+
7+
from alphatrion.utils import pricing
8+
9+
10+
@pytest.fixture(autouse=True)
11+
def reset_pricing_cache():
12+
"""Reset pricing cache before each test."""
13+
pricing._PRICING_CACHE = None
14+
yield
15+
pricing._PRICING_CACHE = None
16+
17+
18+
def test_load_pricing_config_dev_mode():
19+
"""Test loading pricing config in development mode (relative path)."""
20+
# Clear cache to force reload
21+
pricing._PRICING_CACHE = None
22+
23+
config = pricing.load_pricing_config()
24+
25+
assert isinstance(config, dict)
26+
assert "anthropic" in config or "deepinfra" in config
27+
# Verify it's cached
28+
assert pricing._PRICING_CACHE is not None
29+
30+
31+
def test_load_pricing_config_cached():
32+
"""Test that pricing config is cached after first load."""
33+
# First load
34+
config1 = pricing.load_pricing_config()
35+
36+
# Second load should return cached value
37+
config2 = pricing.load_pricing_config()
38+
39+
assert config1 is config2
40+
41+
42+
def test_load_pricing_config_as_installed_package(tmp_path, monkeypatch):
43+
"""Test loading pricing config when installed as a library."""
44+
# Create a mock package structure
45+
mock_config_content = """
46+
anthropic:
47+
models: []
48+
49+
deepinfra:
50+
models:
51+
test-model:
52+
description: "Test model"
53+
input_tokens_price: 0.1
54+
output_tokens_price: 0.5
55+
cache_read_input_tokens_price: 0.05
56+
cache_creation_input_tokens_price: 0.1
57+
"""
58+
59+
# Mock importlib.resources to simulate installed package
60+
mock_file = mock.MagicMock()
61+
mock_file.read_text.return_value = mock_config_content
62+
63+
mock_files = mock.MagicMock()
64+
mock_files.joinpath.return_value = mock_file
65+
66+
with mock.patch(
67+
"alphatrion.utils.pricing.resources.files", return_value=mock_files
68+
):
69+
config = pricing.load_pricing_config()
70+
71+
assert isinstance(config, dict)
72+
assert "anthropic" in config
73+
assert "deepinfra" in config
74+
assert "test-model" in config["deepinfra"]["models"]
75+
mock_files.joinpath.assert_called_once_with("config/modelspec.yaml")
76+
77+
78+
def test_load_pricing_config_fallback_to_relative_path(monkeypatch):
79+
"""Test fallback to relative path when package resources fail."""
80+
81+
def mock_files_error(*args, **kwargs):
82+
raise ModuleNotFoundError("Package not found")
83+
84+
with mock.patch(
85+
"alphatrion.utils.pricing.resources.files", side_effect=mock_files_error
86+
):
87+
# Should fall back to relative path
88+
config = pricing.load_pricing_config()
89+
90+
assert isinstance(config, dict)
91+
# Should successfully load from relative path
92+
assert "anthropic" in config or "deepinfra" in config
93+
94+
95+
def test_load_pricing_config_missing_file_raises_error(tmp_path, monkeypatch):
96+
"""Test that missing config file raises appropriate error."""
97+
98+
def mock_files_error(*args, **kwargs):
99+
raise FileNotFoundError("Config not found")
100+
101+
# Mock both package resources and file path to fail
102+
with mock.patch(
103+
"alphatrion.utils.pricing.resources.files", side_effect=mock_files_error
104+
):
105+
# Also mock Path to point to non-existent location
106+
with mock.patch("alphatrion.utils.pricing.Path") as mock_path:
107+
mock_path.return_value.parent.parent.parent.__truediv__.return_value.__truediv__.return_value = (
108+
tmp_path / "nonexistent.yaml"
109+
)
110+
111+
with pytest.raises(Exception):
112+
pricing.load_pricing_config()
113+
114+
115+
def test_get_model_pricing():
116+
"""Test getting pricing for a specific model."""
117+
# First ensure config is loaded
118+
config = pricing.load_pricing_config()
119+
120+
# Find a model from the loaded config
121+
provider = None
122+
model = None
123+
for prov, prov_data in config.items():
124+
models = prov_data.get("models", {})
125+
if models:
126+
provider = prov
127+
model = next(iter(models.keys()))
128+
break
129+
130+
if provider and model:
131+
pricing_info = pricing.get_model_pricing(provider, model)
132+
133+
assert isinstance(pricing_info, dict)
134+
assert "input_tokens_price" in pricing_info
135+
assert "output_tokens_price" in pricing_info
136+
assert "cache_creation_input_tokens_price" in pricing_info
137+
assert "cache_read_input_tokens_price" in pricing_info
138+
139+
140+
def test_get_model_pricing_fallback_to_default():
141+
"""Test fallback to default pricing for unknown model."""
142+
pricing_info = pricing.get_model_pricing("unknown-provider", "unknown-model")
143+
144+
assert isinstance(pricing_info, dict)
145+
assert pricing_info["input_tokens_price"] == 3.3
146+
assert pricing_info["output_tokens_price"] == 16.5
147+
assert pricing_info["cache_creation_input_tokens_price"] == 3.3
148+
assert pricing_info["cache_read_input_tokens_price"] == 3.3
149+
150+
151+
def test_calculate_cost():
152+
"""Test cost calculation."""
153+
cost = pricing.calculate_cost(
154+
provider="deepinfra",
155+
model="test-model",
156+
input_tokens=1_000_000, # 1M tokens
157+
output_tokens=500_000, # 0.5M tokens
158+
cache_creation_input_tokens=200_000, # 0.2M tokens
159+
cache_read_input_tokens=100_000, # 0.1M tokens
160+
)
161+
162+
assert isinstance(cost, dict)
163+
assert "total_cost" in cost
164+
assert "input_cost" in cost
165+
assert "output_cost" in cost
166+
assert "cache_creation_input_cost" in cost
167+
assert "cache_read_input_cost" in cost
168+
169+
# All costs should be non-negative
170+
assert cost["total_cost"] >= 0
171+
assert cost["input_cost"] >= 0
172+
assert cost["output_cost"] >= 0
173+
174+
# Total should be sum of all components
175+
expected_total = (
176+
cost["input_cost"]
177+
+ cost["output_cost"]
178+
+ cost["cache_creation_input_cost"]
179+
+ cost["cache_read_input_cost"]
180+
)
181+
assert abs(cost["total_cost"] - expected_total) < 0.00000001
182+
183+
184+
def test_calculate_cost_zero_tokens():
185+
"""Test cost calculation with zero tokens."""
186+
cost = pricing.calculate_cost(
187+
provider="deepinfra",
188+
model="test-model",
189+
input_tokens=0,
190+
output_tokens=0,
191+
cache_creation_input_tokens=0,
192+
cache_read_input_tokens=0,
193+
)
194+
195+
assert cost["total_cost"] == 0
196+
assert cost["input_cost"] == 0
197+
assert cost["output_cost"] == 0
198+
assert cost["cache_creation_input_cost"] == 0
199+
assert cost["cache_read_input_cost"] == 0
200+
201+
202+
def test_calculate_cost_precision():
203+
"""Test that costs are rounded to 8 decimal places."""
204+
cost = pricing.calculate_cost(
205+
provider="deepinfra",
206+
model="test-model",
207+
input_tokens=1, # Very small number
208+
output_tokens=1,
209+
)
210+
211+
# Check that all values are rounded to 8 decimal places
212+
for key, value in cost.items():
213+
# Convert to string and check decimal places
214+
str_value = str(value)
215+
if "." in str_value:
216+
decimal_places = len(str_value.split(".")[1])
217+
assert decimal_places <= 8, f"{key} has {decimal_places} decimal places"

0 commit comments

Comments
 (0)