|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | from dataclasses import dataclass |
| 6 | +from typing import Any |
6 | 7 |
|
7 | 8 | import numpy as np |
8 | 9 | import pandas as pd |
|
11 | 12 | FilterOperator, |
12 | 13 | TargetAggregation, |
13 | 14 | TargetFilter, |
| 15 | + TargetProvider, |
| 16 | + TargetQuery, |
14 | 17 | TargetSpec, |
15 | 18 | ) |
16 | 19 |
|
|
19 | 22 | TargetLevel, |
20 | 23 | TargetRegistry, |
21 | 24 | get_registry, |
| 25 | + target_available_in_cps, |
22 | 26 | target_category, |
| 27 | + target_group_name, |
| 28 | + target_level, |
23 | 29 | target_requires_imputation, |
24 | 30 | ) |
25 | 31 |
|
@@ -59,10 +65,50 @@ def summary(self) -> str: |
59 | 65 | class CalibrationHarness: |
60 | 66 | """Harness for calibration experiments over one entity frame at a time.""" |
61 | 67 |
|
def __init__(
    self,
    registry: TargetRegistry | None = None,
    *,
    target_provider: TargetProvider | None = None,
):
    """Configure the harness's target source and start with no results.

    Args:
        registry: Registry stored as ``self.registry``. When no
            ``target_provider`` is given and this is ``None``, the object
            returned by ``get_registry()`` is used instead.
        target_provider: Object stored as ``self.target_provider``. When
            omitted, the registry itself is used as the provider.

    Note:
        When an explicit ``target_provider`` is supplied, ``self.registry``
        is stored exactly as passed and may therefore be ``None``.
    """
    if target_provider is None:
        # Test against None explicitly: a truthiness check would silently
        # replace a caller-supplied registry that happens to evaluate as
        # falsy (e.g. an empty container-like registry) with the default.
        self.registry = registry if registry is not None else get_registry()
        self.target_provider = self.registry
    else:
        self.registry = registry
        self.target_provider = target_provider
    self._results: dict[str, CalibrationResult] = {}
65 | 81 |
|
def select_targets(
    self,
    *,
    categories: list[TargetCategory] | None = None,
    levels: list[TargetLevel] | None = None,
    groups: list[str] | None = None,
    only_available: bool = False,
    entity: EntityType | str | None = None,
    period: int | str | None = None,
    provider_filters: dict[str, Any] | None = None,
) -> list[TargetSpec]:
    """Load targets from the configured provider and filter them.

    ``period``, ``entity`` and a copy of ``provider_filters`` are packed
    into a ``TargetQuery`` and handed to
    ``self.target_provider.load_target_set``. Each target in the returned
    set is then kept only if ``_matches_us_target_filters`` accepts it for
    the given ``categories``, ``levels``, ``groups`` and ``only_available``
    arguments.

    Returns:
        The accepted targets, in the order the provider returned them.
    """
    # Hand the provider its own dict so the caller's mapping is not shared.
    filters_copy = dict(provider_filters) if provider_filters else {}
    target_set = self.target_provider.load_target_set(
        TargetQuery(
            period=period,
            entity=entity,
            provider_filters=filters_copy,
        )
    )

    selected: list[TargetSpec] = []
    for candidate in target_set.targets:
        if _matches_us_target_filters(
            candidate,
            categories=categories,
            levels=levels,
            groups=groups,
            only_available=only_available,
        ):
            selected.append(candidate)
    return selected
| 111 | + |
66 | 112 | def get_target_vector( |
67 | 113 | self, |
68 | 114 | df: pd.DataFrame, |
@@ -202,15 +248,19 @@ def run_experiment( |
202 | 248 | groups: list[str] | None = None, |
203 | 249 | only_available: bool = False, |
204 | 250 | entity: EntityType | str | None = None, |
| 251 | + period: int | str | None = None, |
| 252 | + provider_filters: dict[str, Any] | None = None, |
205 | 253 | **calibrate_kwargs, |
206 | 254 | ) -> CalibrationResult: |
207 | 255 | """Run a calibration experiment over a filtered target subset.""" |
208 | | - selected = self.registry.select_targets( |
| 256 | + selected = self.select_targets( |
209 | 257 | categories=categories, |
210 | 258 | levels=levels, |
211 | 259 | groups=groups, |
212 | 260 | only_available=only_available, |
213 | 261 | entity=entity, |
| 262 | + period=period, |
| 263 | + provider_filters=provider_filters, |
214 | 264 | ) |
215 | 265 | selected = [ |
216 | 266 | target |
@@ -262,7 +312,7 @@ def print_target_coverage( |
262 | 312 | print("TARGET COVERAGE ANALYSIS") |
263 | 313 | print("=" * 70) |
264 | 314 |
|
265 | | - all_targets = self.registry.select_targets(entity=entity) |
| 315 | + all_targets = self.select_targets(entity=entity) |
266 | 316 | columns = set(df.columns) |
267 | 317 |
|
268 | 318 | available: list[TargetSpec] = [] |
@@ -388,6 +438,25 @@ def _weight_stats(weights: np.ndarray) -> dict[str, float]: |
388 | 438 | } |
389 | 439 |
|
390 | 440 |
|
def _matches_us_target_filters(
    target: TargetSpec,
    *,
    categories: list[TargetCategory] | None = None,
    levels: list[TargetLevel] | None = None,
    groups: list[str] | None = None,
    only_available: bool = False,
) -> bool:
    """Return whether ``target`` passes every active filter.

    A filter is active only when its argument is truthy; ``None``, an empty
    list, or ``only_available=False`` disables that check. Checks run in
    the order category, level, group, availability, and stop at the first
    one that fails, so later lookups are skipped once a target is rejected.
    """
    return (
        (not categories or target_category(target) in categories)
        and (not levels or target_level(target) in levels)
        and (not groups or target_group_name(target) in groups)
        # bool() keeps the return value a strict True/False even if the
        # availability helper returns some other truthy/falsy object.
        and (not only_available or bool(target_available_in_cps(target)))
    )
| 458 | + |
| 459 | + |
391 | 460 | def _build_constraint_row(df: pd.DataFrame, spec: TargetSpec) -> np.ndarray: |
392 | 461 | if spec.aggregation is TargetAggregation.MEAN: |
393 | 462 | raise NotImplementedError("Mean targets are not supported by this harness") |
|
0 commit comments