Skip to content

Commit 27abdea

Browse files
committed
Replace v0.2.1 units validators with a declarative table
Adds UnitsValidatedModel to base.py: a model base whose subclasses declare a UNIT_RULES table (field name -> allowed units) instead of hand-writing a check_units method. A single inherited validator merges the rules declared across the MRO, so each class only states the fields it introduces and shared fields are validated once. Migrates the six v0.2.1 unit validators (the MediaPayloadStats hierarchy plus AggregateRequests/AggregateThroughput) to this mechanism. The two aggregate classes multiply-inherit UnitsValidatedModel alongside their v0.2 base; the shared validator is named distinctly from v0.2's check_units so both co-run rather than one shadowing the other. Behavior is unchanged and covered by the existing compat and multimodal guardrail tests; the only schema-artifact diff is an expanded AggregateRequests description. Signed-off-by: Brendan Slabe <slabe@google.com>
1 parent 0aa7d20 commit 27abdea

3 files changed

Lines changed: 78 additions & 77 deletions

File tree

llmdbenchmark/analysis/benchmark_report/base.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
import json
66
from enum import StrEnum, auto
7-
from typing import Any
7+
from typing import Any, ClassVar
88

9-
from pydantic import BaseModel
9+
from pydantic import BaseModel, model_validator
1010
import yaml
1111

1212
###############################################################################
@@ -161,6 +161,46 @@ class Units(StrEnum):
161161
UNITS_MEDIA_THROUGHPUT = [Units.IMAGE_PER_S, Units.VIDEO_PER_S, Units.AUDIO_PER_S]
162162
UNITS_POWER = [Units.WATTS]
163163

164+
165+
###############################################################################
166+
# Declarative units validation
167+
###############################################################################
168+
169+
170+
class UnitsValidatedModel(BaseModel):
171+
"""Base model that validates ``Statistics`` field units declaratively.
172+
173+
Instead of hand-writing a ``check_units`` method per class, a subclass sets
174+
``UNIT_RULES``, mapping a field name to the list of units allowed for that
175+
field's ``Statistics.units``. A single inherited validator checks every
176+
rule declared anywhere in the class's MRO, so each subclass only declares
177+
the fields it introduces; ``None`` (unset Optional) fields are skipped.
178+
179+
The validator is named distinctly from any ``check_units`` so it co-runs
180+
with, rather than shadows, a hand-written validator inherited from another
181+
base under multiple inheritance.
182+
"""
183+
184+
# field name -> allowed Units. Merged across the MRO at validation time.
185+
UNIT_RULES: ClassVar[dict[str, list[Units]]] = {}
186+
187+
@model_validator(mode="after")
188+
def validate_declared_units(self):
189+
merged: dict[str, list[Units]] = {}
190+
# Most-derived first, so a subclass rule overrides a base rule.
191+
for klass in type(self).__mro__:
192+
for field, allowed in klass.__dict__.get("UNIT_RULES", {}).items():
193+
merged.setdefault(field, allowed)
194+
for field, allowed in merged.items():
195+
stat = getattr(self, field, None)
196+
if stat is not None and stat.units not in allowed:
197+
raise ValueError(
198+
f'Invalid units "{stat.units}" for "{field}", must be one'
199+
f" of: {' '.join(allowed)}"
200+
)
201+
return self
202+
203+
164204
###############################################################################
165205
# Base benchmark report class
166206
###############################################################################

llmdbenchmark/analysis/benchmark_report/br_v0_2_1_json_schema.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@
114114
},
115115
"AggregateRequests": {
116116
"additionalProperties": false,
117-
"description": "v0.2 request statistics, plus multi-modal payload details.",
117+
"description": "v0.2 request statistics, plus multi-modal payload details.\n\nInherits the v0.2 input/output-length unit checks and adds a declarative\nrule for the new request_size field.",
118118
"properties": {
119119
"total": {
120120
"description": "Total number of requests sent.",

llmdbenchmark/analysis/benchmark_report/schema_v0_2_1.py

Lines changed: 35 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@
1717
the PR description.
1818
"""
1919

20-
from pydantic import BaseModel, model_validator
20+
from typing import ClassVar
21+
22+
from pydantic import BaseModel
2123

2224
from .base import (
2325
UNITS_MEDIA_THROUGHPUT,
2426
UNITS_MEMORY,
2527
UNITS_QUANTITY,
2628
UNITS_RATIO,
2729
UNITS_TIME,
30+
Units,
31+
UnitsValidatedModel,
2832
)
2933
from .schema_v0_2 import (
3034
MODEL_CONFIG,
@@ -66,7 +70,7 @@
6670
###############################################################################
6771

6872

69-
class MediaPayloadStats(BaseModel):
73+
class MediaPayloadStats(UnitsValidatedModel):
7074
"""Payload statistics shared by every media modality.
7175
7276
All fields are distributions over the individual media instances the client
@@ -75,50 +79,32 @@ class MediaPayloadStats(BaseModel):
7579

7680
model_config = MODEL_CONFIG.copy()
7781

82+
UNIT_RULES: ClassVar[dict[str, list[Units]]] = {
83+
"count": UNITS_QUANTITY,
84+
"bytes": UNITS_MEMORY,
85+
}
86+
7887
count: Statistics | None = None
7988
"""Number of media instances of this modality per request."""
8089
bytes: Statistics | None = None
8190
"""Encoded size per media instance."""
8291

83-
@model_validator(mode="after")
84-
def check_media_units(self):
85-
if self.count and self.count.units not in UNITS_QUANTITY:
86-
raise ValueError(
87-
f'Invalid units "{self.count.units}", must be one of:'
88-
f" {' '.join(UNITS_QUANTITY)}"
89-
)
90-
if self.bytes and self.bytes.units not in UNITS_MEMORY:
91-
raise ValueError(
92-
f'Invalid units "{self.bytes.units}", must be one of:'
93-
f" {' '.join(UNITS_MEMORY)}"
94-
)
95-
return self
96-
9792

9893
class VisualPayloadStats(MediaPayloadStats):
9994
"""Payload statistics common to pixel-based modalities (image and video)."""
10095

10196
model_config = MODEL_CONFIG.copy()
10297

98+
UNIT_RULES: ClassVar[dict[str, list[Units]]] = {
99+
"pixels": UNITS_QUANTITY,
100+
"aspect_ratio": UNITS_RATIO,
101+
}
102+
103103
pixels: Statistics | None = None
104104
"""Pixel count per media instance (height x width, summed over frames)."""
105105
aspect_ratio: Statistics | None = None
106106
"""Aspect ratio (width / height) per media instance."""
107107

108-
@model_validator(mode="after")
109-
def check_visual_units(self):
110-
if self.pixels and self.pixels.units not in UNITS_QUANTITY:
111-
raise ValueError(
112-
f'Invalid units "{self.pixels.units}", must be one of:'
113-
f" {' '.join(UNITS_QUANTITY)}"
114-
)
115-
if self.aspect_ratio and self.aspect_ratio.units not in UNITS_RATIO:
116-
raise ValueError(
117-
f'Invalid units "{self.aspect_ratio.units}", must be one of:'
118-
f" {' '.join(UNITS_RATIO)}"
119-
)
120-
return self
121-
122108

123109
class ImagePayloadStats(VisualPayloadStats):
124110
"""Image payload statistics."""
@@ -131,36 +117,22 @@ class VideoPayloadStats(VisualPayloadStats):
131117

132118
model_config = MODEL_CONFIG.copy()
133119

120+
UNIT_RULES: ClassVar[dict[str, list[Units]]] = {"frames": UNITS_QUANTITY}
121+
134122
frames: Statistics | None = None
135123
"""Number of frames per video instance."""
136124

137-
@model_validator(mode="after")
138-
def check_video_units(self):
139-
if self.frames and self.frames.units not in UNITS_QUANTITY:
140-
raise ValueError(
141-
f'Invalid units "{self.frames.units}", must be one of:'
142-
f" {' '.join(UNITS_QUANTITY)}"
143-
)
144-
return self
145-
146125

147126
class AudioPayloadStats(MediaPayloadStats):
148127
"""Audio payload statistics."""
149128

150129
model_config = MODEL_CONFIG.copy()
151130

131+
UNIT_RULES: ClassVar[dict[str, list[Units]]] = {"seconds": UNITS_TIME}
132+
152133
seconds: Statistics | None = None
153134
"""Duration per audio instance."""
154135

155-
@model_validator(mode="after")
156-
def check_audio_units(self):
157-
if self.seconds and self.seconds.units not in UNITS_TIME:
158-
raise ValueError(
159-
f'Invalid units "{self.seconds.units}", must be one of:'
160-
f" {' '.join(UNITS_TIME)}"
161-
)
162-
return self
163-
164136

165137
class MultiModalRequests(BaseModel):
166138
"""Per-modality request payload statistics for multi-modal workloads."""
@@ -180,52 +152,41 @@ class MultiModalRequests(BaseModel):
180152
###############################################################################
181153

182154

183-
class AggregateRequests(AggregateRequestsV02):
184-
"""v0.2 request statistics, plus multi-modal payload details."""
155+
class AggregateRequests(AggregateRequestsV02, UnitsValidatedModel):
156+
"""v0.2 request statistics, plus multi-modal payload details.
157+
158+
Inherits the v0.2 input/output-length unit checks and adds a declarative
159+
rule for the new request_size field.
160+
"""
185161

186162
model_config = MODEL_CONFIG.copy()
187163

164+
UNIT_RULES: ClassVar[dict[str, list[Units]]] = {"request_size": UNITS_MEMORY}
165+
188166
request_size: Statistics | None = None
189167
"""Total encoded request size, including all media payloads."""
190168
multimodal: MultiModalRequests | None = None
191169
"""Per-modality payload statistics."""
192170

193-
@model_validator(mode="after")
194-
def check_request_size_units(self):
195-
if self.request_size and self.request_size.units not in UNITS_MEMORY:
196-
raise ValueError(
197-
f'Invalid units "{self.request_size.units}", must be one of:'
198-
f" {' '.join(UNITS_MEMORY)}"
199-
)
200-
return self
201171

202-
203-
class AggregateThroughput(AggregateThroughputV02):
172+
class AggregateThroughput(AggregateThroughputV02, UnitsValidatedModel):
204173
"""v0.2 throughput metrics, plus per-modality payload rates."""
205174

206175
model_config = MODEL_CONFIG.copy()
207176

177+
UNIT_RULES: ClassVar[dict[str, list[Units]]] = {
178+
"image_rate": UNITS_MEDIA_THROUGHPUT,
179+
"video_rate": UNITS_MEDIA_THROUGHPUT,
180+
"audio_rate": UNITS_MEDIA_THROUGHPUT,
181+
}
182+
208183
image_rate: Statistics | None = None
209184
"""Image delivery rate."""
210185
video_rate: Statistics | None = None
211186
"""Video delivery rate."""
212187
audio_rate: Statistics | None = None
213188
"""Audio delivery rate."""
214189

215-
@model_validator(mode="after")
216-
def check_media_rate_units(self):
217-
for name, stat in (
218-
("image_rate", self.image_rate),
219-
("video_rate", self.video_rate),
220-
("audio_rate", self.audio_rate),
221-
):
222-
if stat and stat.units not in UNITS_MEDIA_THROUGHPUT:
223-
raise ValueError(
224-
f'Invalid units "{stat.units}" for {name}, must be one of:'
225-
f" {' '.join(UNITS_MEDIA_THROUGHPUT)}"
226-
)
227-
return self
228-
229190

230191
###############################################################################
231192
# Containment shims: re-thread the extended aggregates up to a new report root.

0 commit comments

Comments
 (0)