Skip to content

Commit 13b03b9

Browse files
authored
Merge pull request #310 from zhanz5/fix/cost-calculation-model-aware
fix: make cost calculation model-aware instead of hardcoded to deepseek-chat
2 parents 6ae9ea8 + ab33513 commit 13b03b9

6 files changed

Lines changed: 38 additions & 5 deletions

File tree

agentic_security/http_spec.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import json
23
from enum import Enum
34
from urllib.parse import urlparse
45

@@ -145,6 +146,18 @@ async def verify(self) -> httpx.Response:
145146

146147
fn = probe
147148

149+
@property
150+
def model_name(self) -> str:
151+
"""Extract the model name from the request body (JSON).
152+
153+
Returns the value of the 'model' field if present, otherwise 'unknown'.
154+
"""
155+
try:
156+
body_json = json.loads(self.body)
157+
return body_json.get("model", "unknown")
158+
except (json.JSONDecodeError, TypeError):
159+
return "unknown"
160+
148161
@property
149162
def modality(self) -> Modality:
150163
if self.has_image:

agentic_security/primitives/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def with_secrets(self, secrets) -> "Scan":
4242
class ScanResult(BaseModel):
4343
module: str
4444
tokens: float | int
45-
cost: float
45+
cost: float | None
4646
progress: float
4747
status: bool = False
4848
failureRate: float = 0.0

agentic_security/probe_actor/cost_module.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from agentic_security.logutils import logger
2+
13
# API pricing, USD per token. Values are dollars per 1M tokens / 1_000_000.
24
# Verified against vendor pricing pages on 2026-06-03.
35
PRICING = {
@@ -21,13 +23,19 @@
2123
DEFAULT_MODEL = "claude-sonnet"
2224

2325

24-
def calculate_cost(tokens: int, model: str = DEFAULT_MODEL) -> float:
26+
def calculate_cost(tokens: int, model: str = DEFAULT_MODEL) -> float | None:
2527
"""Calculate API cost in USD for a total token count.
2628
2729
Assumes a 1:1 input/output split, since callers only track a combined total.
30+
31+
Returns:
32+
float | None: Cost in USD, or None if the model pricing is unknown.
2833
"""
2934
if model not in PRICING:
30-
raise ValueError(f"Unknown model: {model}")
35+
logger.warning(
36+
f"Unknown model '{model}': pricing not available, cost will not be estimated."
37+
)
38+
return None
3139

3240
half = max(tokens, 0) / 2
3341
rates = PRICING[model]

agentic_security/probe_actor/fuzzer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,9 @@ async def scan_module(
273273

274274
failure_rate = module_failures / max(module_prompts, 1)
275275
failure_rates.append(failure_rate)
276-
cost = calculate_cost(tokens)
276+
cost = calculate_cost(
277+
tokens, model=getattr(request_factory, "model_name", "unknown")
278+
)
277279

278280
response_text = fuzzer_state.get_last_output(prompt) or ""
279281

@@ -557,7 +559,9 @@ async def perform_many_shot_scan(
557559

558560
failure_rate = module_failures / max(processed_prompts, 1)
559561
failure_rates.append(failure_rate)
560-
cost = calculate_cost(tokens)
562+
cost = calculate_cost(
563+
tokens, model=getattr(request_factory, "model_name", "unknown")
564+
)
561565

562566
yield ScanResult(
563567
module=module.dataset_name,

agentic_security/probe_data/audio_generator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ def __init__(self, llm_spec):
131131
if not llm_spec.has_audio:
132132
raise ValueError("LLMSpec must have an image")
133133

134+
@property
135+
def model_name(self) -> str:
136+
return self.llm_spec.model_name
137+
134138
async def probe(
135139
self, prompt: str, encoded_image: str = "", encoded_audio: str = "", files={}
136140
) -> httpx.Response:

agentic_security/probe_data/image_generator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ def __init__(self, llm_spec):
131131
if not llm_spec.has_image:
132132
raise ValueError("LLMSpec must have an image")
133133

134+
@property
135+
def model_name(self) -> str:
136+
return self.llm_spec.model_name
137+
134138
async def probe(
135139
self, prompt: str, encoded_image: str = "", encoded_audio: str = "", files={}
136140
) -> httpx.Response:

0 commit comments

Comments
 (0)