Skip to content

Commit 05c66c2

Browse files
authored
fix(v1): record final/context token usage at write time (#1525)
1 parent daccece commit 05c66c2

3 files changed

Lines changed: 42 additions & 6 deletions

File tree

tests/test_v1_runtime_lifecycle.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,12 @@ async def test_v1_records_default_metrics_usage_and_timing() -> None:
732732
state = await harness.run(task)
733733

734734
assert state["metrics"]["num_turns"] == 1.0
735-
assert state["token_usage"] == {"input_tokens": 11.0, "output_tokens": 7.0}
735+
assert state["token_usage"] == {
736+
"input_tokens": 11.0,
737+
"output_tokens": 7.0,
738+
"final_output_tokens": 7.0,
739+
"final_input_tokens": 11.0,
740+
}
736741
assert state["usage"] == state["token_usage"]
737742
assert state["timing"]["total"] > 0.0
738743
assert state["timing"]["generation"]["duration"] > 0.0

verifiers/utils/save_utils.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,30 @@ def state_to_output(
243243
"input_tokens": usage.get("input_tokens", 0.0),
244244
"output_tokens": usage.get("output_tokens", 0.0),
245245
}
246-
# Add context token metrics from trajectory
247-
trajectory = state.get("trajectory", [])
248-
if isinstance(trajectory, list):
249-
from verifiers.utils.usage_utils import compute_context_token_metrics
246+
# Context ("final") token metrics. v1 records these at write time from
247+
# the live Response (they can't be re-derived from the serialized trajectory since
248+
# responses are plain dicts), so prefer them when present. Classic envs
249+
# keep live Response objects in the trajectory, so recompute there.
250+
raw_usage = state.get("token_usage")
251+
final_output = (
252+
raw_usage.get("final_output_tokens")
253+
if isinstance(raw_usage, Mapping)
254+
else None
255+
)
256+
final_input = (
257+
raw_usage.get("final_input_tokens")
258+
if isinstance(raw_usage, Mapping)
259+
else None
260+
)
261+
if final_output is not None and final_input is not None:
262+
token_usage["final_output_tokens"] = float(final_output)
263+
token_usage["final_input_tokens"] = float(final_input)
264+
else:
265+
trajectory = state.get("trajectory", [])
266+
if isinstance(trajectory, list):
267+
from verifiers.utils.usage_utils import compute_context_token_metrics
250268

251-
token_usage.update(compute_context_token_metrics(trajectory))
269+
token_usage.update(compute_context_token_metrics(trajectory))
252270
output["token_usage"] = token_usage
253271

254272
# sanitize messages (handle None for error cases)

verifiers/v1/utils/usage_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,17 @@ def record_response_usage(state: State, response: Response) -> None:
1818
usage["output_tokens"] = float(usage.get("output_tokens", 0.0)) + float(
1919
output_tokens
2020
)
21+
# Context ("final") token metrics, accumulated at write time from the live
22+
# Response. v1 serializes trajectory responses to plain dicts, so these metrics can't
23+
# be recomputed from the trajectory afterward (the isinstance(Response) gate
24+
# in compute_context_token_metrics fails). Mirror that helper's formula for a
25+
# linear rollout: final_output is the running sum of completions; final_input
26+
# is the latest step's full context minus that sum.
27+
usage["final_output_tokens"] = float(usage.get("final_output_tokens", 0.0)) + float(
28+
output_tokens
29+
)
30+
last_step_total = float(input_tokens) + float(output_tokens)
31+
usage["final_input_tokens"] = max(
32+
0.0, last_step_total - usage["final_output_tokens"]
33+
)
2134
state["usage"] = usage

0 commit comments

Comments
 (0)