Skip to content

Commit cacd567

Browse files
committed
Pydantic validation in the payload for events part 1
1 parent f977f2f commit cacd567

3 files changed

Lines changed: 112 additions & 5 deletions

File tree

src/asyncflow/schemas/event/injection.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class End(BaseModel):
5353
]
5454
t_end: PositiveFloat # strictly > 0
5555

56-
class Event(BaseModel):
56+
class EventInjection(BaseModel):
5757
"""Definition of the input structure to define an event in the simulation"""
5858

5959
event_id: str
@@ -64,8 +64,8 @@ class Event(BaseModel):
6464
@model_validator(mode="after") # type: ignore[arg-type]
6565
def ensure_start_end_compatibility(
6666
cls, # noqa: N805
67-
model: "Event",
68-
) -> "Event":
67+
model: "EventInjection",
68+
) -> "EventInjection":
6969
"""
7070
Check the compatibility between Start and End both at level
7171
of time interval and kind
@@ -78,7 +78,7 @@ def ensure_start_end_compatibility(
7878

7979
expected = start_to_end[model.start.kind]
8080
if model.end.kind != expected:
81-
msg = (f"The event {model.event_id} must have"
81+
msg = (f"The event {model.event_id} must have"
8282
f"as value of kind in end {expected}")
8383
raise ValueError(msg)
8484

@@ -88,4 +88,6 @@ def ensure_start_end_compatibility(
8888
"must be smaller than the ending time")
8989
raise ValueError(msg)
9090

91+
return model
92+
9193

src/asyncflow/schemas/payload.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Definition of the full input for the simulation"""
22

3-
from pydantic import BaseModel
3+
from pydantic import BaseModel, field_validator, model_validator
44

5+
from asyncflow.schemas.event.injection import EventInjection
56
from asyncflow.schemas.settings.simulation import SimulationSettings
67
from asyncflow.schemas.topology.graph import TopologyGraph
78
from asyncflow.schemas.workload.rqs_generator import RqsGenerator
@@ -13,3 +14,106 @@ class SimulationPayload(BaseModel):
1314
rqs_input: RqsGenerator
1415
topology_graph: TopologyGraph
1516
sim_settings: SimulationSettings
17+
events: list[EventInjection] | None = None
18+
19+
@field_validator("events", mode="after")
20+
def ensure_event_id_is_unique(
21+
cls, # noqa: N805
22+
v: list[EventInjection] | None,
23+
) -> list[EventInjection] | None:
24+
"""Ensure the id uniqueness of the events id"""
25+
if v is None:
26+
return v
27+
28+
event_id = [event.event_id for event in v]
29+
set_event_id = set(event_id)
30+
31+
if len(event_id) != len(set_event_id):
32+
msg = "The id's representing different events must be unique"
33+
raise ValueError(msg)
34+
return v
35+
36+
@model_validator(mode="after") # type: ignore[arg-type]
37+
def ensure_components_ids_is_compatible(
38+
cls, # noqa: N805
39+
model: "SimulationPayload",
40+
) -> "SimulationPayload":
41+
"""
42+
Ensure the id related to the target component of the event
43+
exist
44+
"""
45+
if model.events is None:
46+
return model
47+
48+
server_list = model.topology_graph.nodes.servers
49+
edges_list = model.topology_graph.edges
50+
valid_ids = (
51+
{server.id for server in server_list}
52+
| {edge.id for edge in edges_list}
53+
)
54+
55+
for event in model.events:
56+
if event.target_id not in valid_ids:
57+
msg = (f"The target id {event.target_id} related to"
58+
f"the event {event.event_id} does not exist")
59+
raise ValueError(msg)
60+
61+
return model
62+
63+
@model_validator(mode="after") # type: ignore[arg-type]
64+
def ensure_event_time_inside_simulatioon_horizon(
65+
cls, # noqa: N805
66+
model: "SimulationPayload",
67+
) -> "SimulationPayload":
68+
"""
69+
The interval of time associated to each events must be
70+
included in the simulation horizon
71+
"""
72+
if model.events is None:
73+
return model
74+
75+
horizon = float(model.sim_settings.total_simulation_time)
76+
77+
for ev in model.events:
78+
t_start = ev.start.t_start
79+
t_end = ev.end.t_end
80+
81+
if t_start < 0.0:
82+
msg = (
83+
f"Event '{ev.event_id}': start time t_start={t_start:.6f} "
84+
"must be >= 0.0"
85+
)
86+
raise ValueError(msg)
87+
88+
if t_start > horizon:
89+
msg = (
90+
f"Event '{ev.event_id}': start time t_start={t_start:.6f} "
91+
f"exceeds simulation horizon T={horizon:.6f}"
92+
)
93+
raise ValueError(msg)
94+
95+
# t_end is PositiveFloat by schema, but still guard the horizon.
96+
if t_end > horizon:
97+
msg = (
98+
f"Event '{ev.event_id}': end time t_end={t_end:.6f} "
99+
f"exceeds simulation horizon T={horizon:.6f}"
100+
)
101+
raise ValueError(msg)
102+
103+
return model
104+
105+
@model_validator(mode="after") # type: ignore[arg-type]
106+
def ensure_compatibility_event_kind_target_id(
107+
cls, # noqa: N805
108+
model: "SimulationPayload",
109+
) -> "SimulationPayload":
110+
"""
111+
The kind of the event must be compatible with the target id
112+
type
113+
"""
114+
if model.events is None:
115+
return model
116+
117+
118+
return model
119+

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,5 @@ def payload_base(
150150
rqs_input=rqs_input,
151151
topology_graph=topology_minimal,
152152
sim_settings=sim_settings,
153+
153154
)

0 commit comments

Comments
 (0)