Skip to content

Commit 2575950

Browse files
committed
ENH: Pass Acceleration Data to Parachute Trigger Functions (RocketPy-Team#XXX)
* ENH: add u_dot parameter computation inside parachute trigger evaluation. * ENH: add acceleration_noise_function parameter to Flight class for realistic IMU simulation. * ENH: implement automatic detection of trigger signature to compute derivatives only when needed. * TST: add unit tests for parachute trigger with acceleration data and noise injection. * TST: add test runner for trigger acceleration validation without full test suite dependencies.
1 parent 9cf3dd4 commit 2575950

3 files changed

Lines changed: 378 additions & 19 deletions

File tree

rocketpy/simulation/flight.py

Lines changed: 118 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from functools import cached_property
66

77
import numpy as np
8+
import inspect
89
from scipy.integrate import BDF, DOP853, LSODA, RK23, RK45, OdeSolver, Radau
910

1011
from rocketpy.simulation.flight_data_exporter import FlightDataExporter
@@ -94,6 +95,12 @@ class Flight:
9495
function evaluation points and then interpolation is used to
9596
calculate them and feed the triggers. Can greatly improve run
9697
time in some cases.
98+
Note
99+
----
100+
Calculating the derivative `u_dot` inside parachute trigger checks will
101+
cause additional calls to the equations of motion (extra physics
102+
evaluations). This increases CPU cost but enables realistic avionics
103+
algorithms that rely on accelerometer data.
97104
Flight.terminate_on_apogee : bool
98105
Whether to terminate simulation when rocket reaches apogee.
99106
Flight.solver : scipy.integrate.LSODA
@@ -487,6 +494,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
487494
name="Flight",
488495
equations_of_motion="standard",
489496
ode_solver="LSODA",
497+
acceleration_noise_function=None,
490498
):
491499
"""Run a trajectory simulation.
492500
@@ -600,6 +608,13 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
600608
self.name = name
601609
self.equations_of_motion = equations_of_motion
602610
self.ode_solver = ode_solver
611+
# Function that returns accelerometer noise vector [ax_noise, ay_noise, az_noise]
612+
# It should be a callable that returns an array-like of length 3.
613+
self.acceleration_noise_function = (
614+
acceleration_noise_function
615+
if acceleration_noise_function is not None
616+
else (lambda: np.zeros(3))
617+
)
603618

604619
# Controller initialization
605620
self.__init_controllers()
@@ -739,11 +754,14 @@ def __simulate(self, verbose):
739754
) = self.__calculate_and_save_pressure_signals(
740755
parachute, node.t, self.y_sol[2]
741756
)
742-
if parachute.triggerfunc(
757+
if self._evaluate_parachute_trigger(
758+
parachute,
743759
noisy_pressure,
744760
height_above_ground_level,
745761
self.y_sol,
746762
self.sensors,
763+
phase.derivative,
764+
self.t,
747765
):
748766
# Remove parachute from flight parachutes
749767
self.parachutes.remove(parachute)
@@ -772,8 +790,7 @@ def __simulate(self, verbose):
772790
lambda self, parachute_porosity=parachute.porosity: setattr(
773791
self, "parachute_porosity", parachute_porosity
774792
),
775-
lambda self,
776-
added_mass_coefficient=parachute.added_mass_coefficient: setattr(
793+
lambda self, added_mass_coefficient=parachute.added_mass_coefficient: setattr(
777794
self,
778795
"parachute_added_mass_coefficient",
779796
added_mass_coefficient,
@@ -998,11 +1015,14 @@ def __simulate(self, verbose):
9981015
)
9991016

10001017
# Check for parachute trigger
1001-
if parachute.triggerfunc(
1018+
if self._evaluate_parachute_trigger(
1019+
parachute,
10021020
noisy_pressure,
10031021
height_above_ground_level,
10041022
overshootable_node.y_sol,
10051023
self.sensors,
1024+
phase.derivative,
1025+
overshootable_node.t,
10061026
):
10071027
# Remove parachute from flight parachutes
10081028
self.parachutes.remove(parachute)
@@ -1020,30 +1040,25 @@ def __simulate(self, verbose):
10201040
i += 1
10211041
# Create flight phase for time after inflation
10221042
callbacks = [
1023-
lambda self,
1024-
parachute_cd_s=parachute.cd_s: setattr(
1043+
lambda self, parachute_cd_s=parachute.cd_s: setattr(
10251044
self, "parachute_cd_s", parachute_cd_s
10261045
),
1027-
lambda self,
1028-
parachute_radius=parachute.radius: setattr(
1046+
lambda self, parachute_radius=parachute.radius: setattr(
10291047
self,
10301048
"parachute_radius",
10311049
parachute_radius,
10321050
),
1033-
lambda self,
1034-
parachute_height=parachute.height: setattr(
1051+
lambda self, parachute_height=parachute.height: setattr(
10351052
self,
10361053
"parachute_height",
10371054
parachute_height,
10381055
),
1039-
lambda self,
1040-
parachute_porosity=parachute.porosity: setattr(
1056+
lambda self, parachute_porosity=parachute.porosity: setattr(
10411057
self,
10421058
"parachute_porosity",
10431059
parachute_porosity,
10441060
),
1045-
lambda self,
1046-
added_mass_coefficient=parachute.added_mass_coefficient: setattr(
1061+
lambda self, added_mass_coefficient=parachute.added_mass_coefficient: setattr(
10471062
self,
10481063
"parachute_added_mass_coefficient",
10491064
added_mass_coefficient,
@@ -1124,6 +1139,88 @@ def __calculate_and_save_pressure_signals(self, parachute, t, z):
11241139

11251140
return noisy_pressure, height_above_ground_level
11261141

1142+
def _evaluate_parachute_trigger(
1143+
self, parachute, pressure, height, y, sensors, derivative_func, t
1144+
):
1145+
"""Evaluate parachute trigger, optionally computing u_dot (acceleration).
1146+
1147+
This helper preserves backward compatibility with existing trigger
1148+
signatures and will compute ``u_dot`` only if the original user
1149+
provided trigger function expects an acceleration argument (detected
1150+
by parameter name such as 'u_dot', 'udot', 'acc', or 'acceleration').
1151+
1152+
Parameters
1153+
----------
1154+
parachute : Parachute
1155+
Parachute object.
1156+
pressure : float
1157+
Noisy pressure value passed to trigger.
1158+
height : float
1159+
Height above ground level passed to trigger.
1160+
y : array
1161+
State vector at evaluation time.
1162+
sensors : list
1163+
Sensors list passed to trigger.
1164+
derivative_func : callable
1165+
Function to compute derivatives: derivative_func(t, y)
1166+
t : float
1167+
Time at which to evaluate derivatives.
1168+
1169+
Returns
1170+
-------
1171+
bool
1172+
True if trigger condition met, False otherwise.
1173+
"""
1174+
# If original trigger is not callable (e.g. numeric or 'apogee'),
1175+
# use the prepared wrapper in Parachute
1176+
trig_original = parachute.trigger
1177+
if not callable(trig_original):
1178+
return parachute.triggerfunc(pressure, height, y, sensors)
1179+
1180+
try:
1181+
sig = inspect.signature(trig_original)
1182+
params = list(sig.parameters.values())
1183+
except (ValueError, TypeError):
1184+
return parachute.triggerfunc(pressure, height, y, sensors)
1185+
1186+
# Detect whether the user-provided trigger expects acceleration
1187+
acc_names = {"u_dot", "udot", "acc", "acceleration"}
1188+
wants_u_dot = any(p.name in acc_names for p in params)
1189+
wants_sensors = any("sensor" in p.name for p in params)
1190+
1191+
if wants_u_dot:
1192+
# Compute derivative and add optional accelerometer noise
1193+
u_dot = np.array(derivative_func(t, y), dtype=float)
1194+
try:
1195+
noise = np.asarray(self.acceleration_noise_function())
1196+
if noise.size == 3:
1197+
# u_dot layout: [vx, vy, vz, ax, ay, az, ...]
1198+
u_dot[3:6] = u_dot[3:6] + noise
1199+
except Exception:
1200+
# If noise function fails, ignore and continue
1201+
pass
1202+
1203+
# Call user function according to detected signature
1204+
try:
1205+
if wants_sensors:
1206+
# common case: (p, h, y, sensors, u_dot)
1207+
return trig_original(pressure, height, y, sensors, u_dot)
1208+
# fallback by arg count
1209+
if len(params) == 4:
1210+
# could be (p, h, y, u_dot)
1211+
return trig_original(pressure, height, y, u_dot)
1212+
if len(params) == 5:
1213+
# could be (p, h, y, sensors, u_dot)
1214+
return trig_original(pressure, height, y, sensors, u_dot)
1215+
# try calling with u_dot as kwarg
1216+
return trig_original(pressure, height, y, u_dot=u_dot)
1217+
except TypeError:
1218+
# If calling the original fails, fallback to wrapper
1219+
return parachute.triggerfunc(pressure, height, y, sensors)
1220+
1221+
# Default: don't compute u_dot and use existing wrapper
1222+
return parachute.triggerfunc(pressure, height, y, sensors)
1223+
11271224
def __init_solution_monitors(self):
11281225
# Initialize solution monitors
11291226
self.out_of_rail_time = 0
@@ -1439,7 +1536,9 @@ def udot_rail2(self, t, u, post_processing=False): # pragma: no cover
14391536
# Hey! We will finish this function later, now we just can use u_dot
14401537
return self.u_dot_generalized(t, u, post_processing=post_processing)
14411538

1442-
def u_dot(self, t, u, post_processing=False): # pylint: disable=too-many-locals,too-many-statements
1539+
def u_dot(
1540+
self, t, u, post_processing=False
1541+
): # pylint: disable=too-many-locals,too-many-statements
14431542
"""Calculates derivative of u state vector with respect to time
14441543
when rocket is flying in 6 DOF motion during ascent out of rail
14451544
and descent without parachute.
@@ -1759,7 +1858,9 @@ def u_dot(self, t, u, post_processing=False): # pylint: disable=too-many-locals
17591858

17601859
return u_dot
17611860

1762-
def u_dot_generalized(self, t, u, post_processing=False): # pylint: disable=too-many-locals,too-many-statements
1861+
def u_dot_generalized(
1862+
self, t, u, post_processing=False
1863+
): # pylint: disable=too-many-locals,too-many-statements
17631864
"""Calculates derivative of u state vector with respect to time when the
17641865
rocket is flying in 6 DOF motion in space and significant mass variation
17651866
effects exist. Typical flight phases include powered ascent after launch
@@ -3571,9 +3672,7 @@ def add(self, flight_phase, index=None): # TODO: quite complex method
35713672
new_index = (
35723673
index - 1
35733674
if flight_phase.t < previous_phase.t
3574-
else index + 1
3575-
if flight_phase.t > next_phase.t
3576-
else index
3675+
else index + 1 if flight_phase.t > next_phase.t else index
35773676
)
35783677
flight_phase.t += adjust
35793678
self.add(flight_phase, new_index)
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import numpy as np
2+
import traceback
3+
4+
from rocketpy.simulation.flight import Flight
5+
from rocketpy.rocket.parachute import Parachute
6+
7+
8+
def _test_trigger_receives_u_dot_and_noise():
9+
def derivative_func(t, y):
10+
return np.array([0, 0, 0, 1.0, 2.0, 3.0, 0, 0, 0, 0, 0, 0, 0])
11+
12+
recorded = {}
13+
14+
def user_trigger(p, h, y, u_dot):
15+
recorded["u_dot"] = np.array(u_dot)
16+
return True
17+
18+
parachute = Parachute(
19+
name="test",
20+
cd_s=1.0,
21+
trigger=user_trigger,
22+
sampling_rate=100,
23+
)
24+
25+
dummy = type("D", (), {})()
26+
dummy.acceleration_noise_function = lambda: np.array([0.1, -0.2, 0.3])
27+
28+
res = Flight._evaluate_parachute_trigger(
29+
dummy,
30+
parachute,
31+
pressure=0.0,
32+
height=10.0,
33+
y=np.zeros(13),
34+
sensors=[],
35+
derivative_func=derivative_func,
36+
t=0.0,
37+
)
38+
39+
assert res is True
40+
assert "u_dot" in recorded
41+
assert np.allclose(recorded["u_dot"][3:6], np.array([1.1, 1.8, 3.3]))
42+
43+
44+
def _test_trigger_with_sensors_and_u_dot():
45+
def derivative_func(t, y):
46+
return np.array([0, 0, 0, -1.0, -2.0, -3.0, 0, 0, 0, 0, 0, 0, 0])
47+
48+
recorded = {}
49+
50+
def user_trigger(p, h, y, sensors, u_dot):
51+
recorded["sensors"] = sensors
52+
recorded["u_dot"] = np.array(u_dot)
53+
return False
54+
55+
parachute = Parachute(
56+
name="test2",
57+
cd_s=1.0,
58+
trigger=user_trigger,
59+
sampling_rate=100,
60+
)
61+
62+
dummy = type("D", (), {})()
63+
dummy.acceleration_noise_function = lambda: np.array([0.0, 0.0, 0.0])
64+
65+
res = Flight._evaluate_parachute_trigger(
66+
dummy,
67+
parachute,
68+
pressure=0.0,
69+
height=5.0,
70+
y=np.zeros(13),
71+
sensors=["s1"],
72+
derivative_func=derivative_func,
73+
t=1.234,
74+
)
75+
76+
assert res is False
77+
assert recorded["sensors"] == ["s1"]
78+
assert np.allclose(recorded["u_dot"][3:6], np.array([-1.0, -2.0, -3.0]))
79+
80+
81+
def _test_legacy_trigger_does_not_compute_u_dot():
82+
def derivative_func(t, y):
83+
raise RuntimeError("derivative should not be called for legacy triggers")
84+
85+
called = {}
86+
87+
def legacy_trigger(p, h, y):
88+
called["ok"] = True
89+
return True
90+
91+
parachute = Parachute(
92+
name="legacy",
93+
cd_s=1.0,
94+
trigger=legacy_trigger,
95+
sampling_rate=100,
96+
)
97+
98+
dummy = type("D", (), {})()
99+
dummy.acceleration_noise_function = lambda: np.zeros(3)
100+
101+
res = Flight._evaluate_parachute_trigger(
102+
dummy,
103+
parachute,
104+
pressure=0.0,
105+
height=0.0,
106+
y=np.zeros(13),
107+
sensors=[],
108+
derivative_func=derivative_func,
109+
t=0.0,
110+
)
111+
112+
assert res is True
113+
assert called.get("ok", False) is True
114+
115+
116+
def run_all():
117+
tests = [
118+
_test_trigger_receives_u_dot_and_noise,
119+
_test_trigger_with_sensors_and_u_dot,
120+
_test_legacy_trigger_does_not_compute_u_dot,
121+
]
122+
failures = 0
123+
for t in tests:
124+
name = t.__name__
125+
try:
126+
t()
127+
print(f"[PASS] {name}")
128+
except Exception:
129+
failures += 1
130+
print(f"[FAIL] {name}")
131+
traceback.print_exc()
132+
if failures:
133+
print(f"{failures} test(s) failed")
134+
raise SystemExit(1)
135+
print("All tests passed")
136+
137+
138+
if __name__ == "__main__":
139+
run_all()

0 commit comments

Comments
 (0)