|
1 | 1 | """Tests for OptimizationClient.""" |
2 | 2 |
|
| 3 | +import dataclasses |
3 | 4 | import json |
4 | 5 | from typing import Any, Dict |
5 | 6 | from unittest.mock import AsyncMock, MagicMock, patch |
@@ -5176,6 +5177,52 @@ def test_skips_gracefully_when_units_differ_across_model_switch(self): |
5176 | 5177 | assert self.client._evaluate_cost(self._ctx(None)) is True |
5177 | 5178 |
|
5178 | 5179 |
|
| 5180 | +# --------------------------------------------------------------------------- |
| 5181 | +# _record_baseline_from_batch |
| 5182 | +# --------------------------------------------------------------------------- |
| 5183 | + |
| 5184 | + |
| 5185 | +class TestRecordBaselineFromBatch: |
| 5186 | + def setup_method(self): |
| 5187 | + self.client = _make_client() |
| 5188 | + self.client._initialize_class_members_from_config(_make_agent_config()) |
| 5189 | + |
| 5190 | + def _ctx(self, duration_ms=None, cost=None): |
| 5191 | + ctx = self.client._create_optimization_context(iteration=1, variables={}) |
| 5192 | + return dataclasses.replace(ctx, duration_ms=duration_ms, estimated_cost_usd=cost) |
| 5193 | + |
| 5194 | + def test_averages_duration_across_batch(self): |
| 5195 | + results = [self._ctx(duration_ms=1000), self._ctx(duration_ms=2000), self._ctx(duration_ms=3000)] |
| 5196 | + self.client._record_baseline_from_batch(results) |
| 5197 | + assert self.client._baseline_duration_ms == 2000.0 |
| 5198 | + |
| 5199 | + def test_averages_cost_across_batch(self): |
| 5200 | + results = [self._ctx(cost=0.01), self._ctx(cost=0.02), self._ctx(cost=0.03)] |
| 5201 | + self.client._record_baseline_from_batch(results) |
| 5202 | + assert abs(self.client._baseline_cost_usd - 0.02) < 1e-9 |
| 5203 | + |
| 5204 | + def test_skips_none_values_in_average(self): |
| 5205 | + results = [self._ctx(duration_ms=1000), self._ctx(duration_ms=None), self._ctx(duration_ms=3000)] |
| 5206 | + self.client._record_baseline_from_batch(results) |
| 5207 | + assert self.client._baseline_duration_ms == 2000.0 |
| 5208 | + |
| 5209 | + def test_noop_when_already_set(self): |
| 5210 | + self.client._baseline_duration_ms = 999.0 |
| 5211 | + results = [self._ctx(duration_ms=1000), self._ctx(duration_ms=2000)] |
| 5212 | + self.client._record_baseline_from_batch(results) |
| 5213 | + assert self.client._baseline_duration_ms == 999.0 |
| 5214 | + |
| 5215 | + def test_noop_on_empty_list(self): |
| 5216 | + self.client._record_baseline_from_batch([]) |
| 5217 | + assert self.client._baseline_duration_ms is None |
| 5218 | + assert self.client._baseline_cost_usd is None |
| 5219 | + |
| 5220 | + def test_noop_when_all_values_none(self): |
| 5221 | + results = [self._ctx(duration_ms=None), self._ctx(duration_ms=None)] |
| 5222 | + self.client._record_baseline_from_batch(results) |
| 5223 | + assert self.client._baseline_duration_ms is None |
| 5224 | + |
| 5225 | + |
5179 | 5226 | # --------------------------------------------------------------------------- |
5180 | 5227 | # _apply_duration_gate |
5181 | 5228 | # --------------------------------------------------------------------------- |
|
0 commit comments