Skip to content

Commit 486cf2e

Browse files
juliomachad0Copilot
andcommitted
ENH: __verify_trigger method and tests implementation. adaptation of out_of_rail_trigger to pass in __verify_trigger
Co-authored-by: Copilot <copilot@github.com>
1 parent a927248 commit 486cf2e

3 files changed

Lines changed: 112 additions & 4 deletions

File tree

rocketpy/simulation/events.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import inspect
2+
from typing import get_type_hints
3+
import warnings
14
class Event:
25
# TODO: should "sensors" arg of the trigger function be a dictionary instead
36
# of a list? It would be more intuitive to access the sensors by name
@@ -89,11 +92,10 @@ def __init__(self, trigger, action, name, event_context=None):
8992
passed to subsequent calls. Defaults to an empty dictionary if not
9093
provided.
9194
"""
95+
self.name = name
9296
self.trigger = self.__verify_trigger(trigger)
9397
self.action = self.__verify_action(action)
94-
self.name = name
9598
self.event_context = event_context if event_context is not None else {}
96-
9799
# TODO: implement tracking for whether this event is currently enabled
98100
# or disabled. The disable_event flag from the action return value should
99101
# control whether this event continues to be checked for triggering.
@@ -122,8 +124,89 @@ def __verify_trigger(self, trigger):
122124
# 2. Return type annotation is bool or can be tested to return bool
123125
# 3. Consider allowing signature to be flexible (accepts **kwargs)
124126
# to accommodate user-defined custom event_context keys
127+
# verify if the return type is bool when annotated
128+
return_annotation = get_type_hints(trigger).get('return', None)
129+
if return_annotation is not None and return_annotation is not bool:
130+
raise ValueError(f"Trigger function {self.name} must return a boolean value.")
131+
# verify if the trigger function accepts **kwargs and therefore can
132+
# receive standard event arguments plus custom event_context keys
133+
s = inspect.signature(trigger)
134+
if not any(p.kind == inspect.Parameter.VAR_KEYWORD for p in s.parameters.values()):
135+
raise ValueError(
136+
f"Trigger function {self.name} must accept **kwargs to receive event context "
137+
f"and simulation state."
138+
)
139+
if any(p.kind == inspect.Parameter.POSITIONAL_ONLY for p in s.parameters.values()):
140+
raise ValueError(
141+
f"Trigger function {self.name} must accept keyword arguments; "
142+
"positional-only parameters are not supported."
143+
)
144+
# Helper function to generate dummy values based on type annotations
145+
# of parameters, allowing to test the function without real values
146+
def _placeholder_for_parameter(parameter):
147+
annotation = parameter.annotation
148+
if annotation is inspect.Parameter.empty:
149+
warnings.warn(f"Trigger function {self.name}: Test with parameters skipped due "
150+
f"to missing type annotation for parameter '{parameter.name}'. \n"
151+
f"Is highly recommended that parameters have type annotations "
152+
f"(var: type). Parameter '{parameter.name}' has no annotation.")
153+
skip_test = True
154+
return None, skip_test
155+
if annotation in (int, float):
156+
return 0, False
157+
if annotation is bool:
158+
return False, False
159+
if annotation is str:
160+
return "", False
161+
if annotation in (list, tuple, set, dict):
162+
return annotation(), False
163+
origin = getattr(annotation, "__origin__", None)
164+
if origin in (list, tuple, set, dict):
165+
return origin(), False
166+
return None, False
167+
# Build a dictionary with dummy values to test if function accepts **kwargs
168+
# Include an unexpected argument to validate the function doesn't complain
169+
test_kwargs = {"unexpected_kwarg": 123}
170+
skip_test = False
171+
# Iterate through function parameters to generate appropriate test values
172+
for name, parameter in s.parameters.items():
173+
if parameter.kind in (
174+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
175+
inspect.Parameter.KEYWORD_ONLY,
176+
):
177+
if parameter.default is inspect.Parameter.empty:
178+
annotation = parameter.annotation
179+
if annotation in (list, tuple, set, dict):
180+
skip_test = True
181+
elif hasattr(annotation, "__origin__") and getattr(annotation, "__origin__", None) in (list, tuple, set, dict):
182+
skip_test = True
183+
else:
184+
test_kwargs[name], skip_test = _placeholder_for_parameter(parameter)
185+
# Execute the trigger function with test values to validate compatibility
186+
# If TypeError occurs, the function doesn't properly accept **kwargs
187+
if not skip_test:
188+
try:
189+
trigger(**test_kwargs)
190+
except TypeError as e:
191+
raise ValueError(
192+
f"Trigger function {self.name} must accept arbitrary kwargs without raising "
193+
"a TypeError."
194+
) from e
195+
except Exception as e:
196+
raise ValueError(
197+
f"Trigger function {self.name} must accept arbitrary kwargs without raising "
198+
f"an error: {e}"
199+
) from e
200+
else:
201+
# Test was skipped due to complex types; warn user to validate manually
202+
warnings.warn(
203+
f"Trigger function {self.name}: Test with parameters "
204+
f"skipped for parameters with complex types "
205+
f"(list, tuple, set, dict). Ensure the function handles "
206+
f"arbitrary inputs gracefully."
207+
)
125208
return trigger
126-
209+
127210
def __verify_action(self, action):
128211
"""Verifies that the action function is valid.
129212

rocketpy/simulation/flight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
619619
self.ode_solver = ode_solver
620620

621621
# Events
622-
def out_of_rail_trigger(state):
622+
def out_of_rail_trigger(state, **kwargs) -> bool:
623623
return (
624624
state[0] ** 2 + state[1] ** 2 + (state[2] - self.env.elevation) ** 2
625625
>= self.effective_1rl**2
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
3+
from rocketpy.simulation.events import Event
4+
5+
6+
def test_verify_trigger_accepts_required_args_with_kwargs():
7+
def trigger(a: int, b: float, **kwargs) -> bool:
8+
return True
9+
10+
def action(**kwargs):
11+
return None
12+
13+
event = Event(trigger=trigger, action=action, name="test")
14+
assert event.trigger is trigger
15+
16+
17+
def test_verify_trigger_rejects_missing_kwargs():
18+
def trigger(a, b) -> bool:
19+
return True
20+
21+
def action(**kwargs):
22+
return None
23+
24+
with pytest.raises(ValueError, match=r"must accept \*\*kwargs"):
25+
Event(trigger=trigger, action=action, name="test")

0 commit comments

Comments
 (0)