Skip to content

Commit 05ebb10

Browse files
CopilotaZira371
andcommitted
Address review feedback: add unsorted axis warning, flatten_for_compatibility parameter, and early return guard clause
Co-authored-by: aZira371 <99824864+aZira371@users.noreply.github.com>
1 parent fe2052b commit 05ebb10

3 files changed

Lines changed: 188 additions & 86 deletions

File tree

rocketpy/mathutils/function.py

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4073,6 +4073,7 @@ def from_grid(
40734073
outputs=None,
40744074
interpolation="regular_grid",
40754075
extrapolation="constant",
4076+
flatten_for_compatibility=True,
40764077
**kwargs,
40774078
): # pylint: disable=too-many-statements #TODO: Refactor this method into smaller methods
40784079
"""Creates a Function from N-dimensional grid data.
@@ -4115,6 +4116,14 @@ def from_grid(
41154116
41164117
If an unsupported extrapolation value is supplied a ``ValueError``
41174118
is raised.
4119+
flatten_for_compatibility : bool, optional
4120+
If True (default), creates flattened ``_domain``, ``_image``, and
4121+
``source`` arrays for backward compatibility with existing Function
4122+
methods and serialization. For large N-dimensional grids (e.g.,
4123+
100x100x100 points), this requires O(n^d) additional memory where n
4124+
is the typical axis length and d is the number of dimensions.
4125+
Set to False to skip this flattening and reduce memory usage if
4126+
compatibility with legacy code paths is not required.
41184127
**kwargs : dict, optional
41194128
Additional arguments passed to the Function constructor.
41204129
@@ -4174,13 +4183,21 @@ def from_grid(
41744183
f"({grid_data.ndim})"
41754184
)
41764185

4177-
# Check each axis matches corresponding grid dimension
4186+
# Check each axis matches corresponding grid dimension and is sorted
41784187
for i, axis in enumerate(axes):
41794188
if len(axis) != grid_data.shape[i]:
41804189
raise ValueError(
41814190
f"Axis {i} has {len(axis)} points but grid dimension {i} "
41824191
f"has {grid_data.shape[i]} points"
41834192
)
4193+
# Check if axis is sorted in ascending order
4194+
if not np.all(np.diff(axis) > 0):
4195+
warnings.warn(
4196+
f"Axis {i} is not strictly sorted in ascending order. "
4197+
"RegularGridInterpolator requires sorted axes. "
4198+
"This may cause unexpected interpolation results.",
4199+
UserWarning,
4200+
)
41844201

41854202
# Set default inputs if not provided
41864203
if inputs is None:
@@ -4217,15 +4234,21 @@ def from_grid(
42174234
fill_value=None, # Linear extrapolation (will be overridden by manual handling)
42184235
)
42194236

4220-
# Create placeholder domain and image for compatibility
4221-
# This flattens the grid for any code expecting these attributes
4222-
mesh = np.meshgrid(*axes, indexing="ij")
4223-
domain_points = np.column_stack([m.ravel() for m in mesh])
4224-
func._domain = domain_points
4225-
func._image = grid_data.ravel()
4226-
4227-
# Set source as flattened data array (for compatibility with serialization, etc.)
4228-
func.source = np.column_stack([domain_points, func._image])
4237+
# Create placeholder domain and image for compatibility.
4238+
# For large grids this requires O(n^d) memory; set flatten_for_compatibility=False
4239+
# to skip this if legacy code compatibility is not required.
4240+
if flatten_for_compatibility:
4241+
mesh = np.meshgrid(*axes, indexing="ij")
4242+
domain_points = np.column_stack([m.ravel() for m in mesh])
4243+
func._domain = domain_points
4244+
func._image = grid_data.ravel()
4245+
# Set source as flattened data array (for compatibility with serialization)
4246+
func.source = np.column_stack([domain_points, func._image])
4247+
else:
4248+
# Minimal placeholders - grid interpolator is the primary data source
4249+
func._domain = None
4250+
func._image = None
4251+
func.source = None
42294252

42304253
# Initialize basic attributes
42314254
func.__inputs__ = inputs
@@ -4241,21 +4264,29 @@ def from_grid(
42414264
# Set basic array attributes for compatibility
42424265
func.x_array = axes[0]
42434266
func.x_initial, func.x_final = axes[0][0], axes[0][-1]
4244-
# For grid-based (N-D) functions, a 1-D `y_array` is not a meaningful
4245-
# representation of the function values. Some legacy code paths and
4246-
# serialization expect a `y_array` attribute to exist, so provide the
4247-
# full flattened image for compatibility rather than a truncated slice.
4248-
# Callers should avoid relying on `y_array` for multidimensional
4249-
# Functions; use the interpolator / `get_value_opt` instead.
4250-
func.y_array = func._image
4251-
# Use the global min/max of the flattened image as a sensible
4252-
# `y_initial`/`y_final` for compatibility with code that inspects
4253-
# scalar bounds. These describe the image range, not an ordering
4254-
# along any particular axis.
4255-
func.y_initial, func.y_final = (
4256-
float(func._image.min()),
4257-
float(func._image.max()),
4258-
)
4267+
if flatten_for_compatibility:
4268+
# For grid-based (N-D) functions, a 1-D `y_array` is not a meaningful
4269+
# representation of the function values. Some legacy code paths and
4270+
# serialization expect a `y_array` attribute to exist, so provide the
4271+
# full flattened image for compatibility rather than a truncated slice.
4272+
# Callers should avoid relying on `y_array` for multidimensional
4273+
# Functions; use the interpolator / `get_value_opt` instead.
4274+
func.y_array = func._image
4275+
# Use the global min/max of the flattened image as a sensible
4276+
# `y_initial`/`y_final` for compatibility with code that inspects
4277+
# scalar bounds. These describe the image range, not an ordering
4278+
# along any particular axis.
4279+
func.y_initial, func.y_final = (
4280+
float(func._image.min()),
4281+
float(func._image.max()),
4282+
)
4283+
else:
4284+
# Minimal placeholders when flattening is disabled
4285+
func.y_array = None
4286+
func.y_initial, func.y_final = (
4287+
float(grid_data.min()),
4288+
float(grid_data.max()),
4289+
)
42594290
if len(axes) > 2:
42604291
func.z_array = axes[2]
42614292
func.z_initial, func.z_final = axes[2][0], axes[2][-1]
@@ -4265,7 +4296,10 @@ def from_grid(
42654296

42664297
# Set interpolation and extrapolation functions
42674298
func.__set_interpolation_func()
4268-
func.__set_extrapolation_func()
4299+
# Only set extrapolation function if we have flattened data, otherwise
4300+
# extrapolation is handled by __get_value_opt_grid directly
4301+
if flatten_for_compatibility:
4302+
func.__set_extrapolation_func()
42694303

42704304
# Set inputs and outputs properly
42714305
func.set_inputs(inputs)

rocketpy/simulation/flight.py

Lines changed: 58 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,72 +1398,70 @@ def __get_drag_coefficient(self, drag_function, mach, z, freestream_velocity_bod
13981398
float
13991399
Drag coefficient value
14001400
"""
1401-
# Check if drag function is multi-dimensional using Function API
1402-
if isinstance(drag_function, Function) and getattr(
1401+
# Early return for 1D drag functions (only mach number)
1402+
if not isinstance(drag_function, Function) or not getattr(
14031403
drag_function, "is_multidimensional", False
14041404
):
1405-
# Multi-dimensional drag function - calculate additional parameters
1406-
1407-
# Calculate Reynolds number
1408-
# Re = rho * V * L / mu
1409-
# where L is characteristic length (rocket diameter)
1410-
rho = self.env.density.get_value_opt(z)
1411-
mu = self.env.dynamic_viscosity.get_value_opt(z)
1412-
freestream_speed = np.linalg.norm(freestream_velocity_body)
1413-
characteristic_length = 2 * self.rocket.radius # Diameter
1414-
# Defensive: avoid division by zero or non-finite viscosity values.
1415-
# Use a small epsilon fallback if `mu` is zero, negative, NaN or infinite.
1416-
try:
1417-
mu_val = float(mu)
1418-
except (TypeError, ValueError, OverflowError):
1419-
# Only catch errors related to invalid numeric conversion.
1420-
# Avoid catching broad Exception to satisfy linters and
1421-
# allow other unexpected errors to surface.
1422-
mu_val = 0.0
1423-
if not np.isfinite(mu_val) or mu_val <= 0.0:
1424-
mu_safe = 1e-10
1425-
else:
1426-
mu_safe = mu_val
1405+
return drag_function.get_value_opt(mach)
14271406

1428-
reynolds = rho * freestream_speed * characteristic_length / mu_safe
1407+
# Multi-dimensional drag function - calculate additional parameters
14291408

1430-
# Calculate angle of attack
1431-
# Angle between freestream velocity and rocket axis (z-axis in body frame)
1432-
# The z component of freestream velocity in body frame
1433-
if hasattr(freestream_velocity_body, "z"):
1434-
stream_vz_b = -freestream_velocity_body.z
1435-
else:
1436-
stream_vz_b = -freestream_velocity_body[2]
1437-
1438-
# Normalize and calculate angle
1439-
if freestream_speed > 1e-6:
1440-
cos_alpha = stream_vz_b / freestream_speed
1441-
# Clamp to [-1, 1] to avoid numerical issues
1442-
cos_alpha = np.clip(cos_alpha, -1.0, 1.0)
1443-
alpha_rad = np.arccos(cos_alpha)
1444-
alpha_deg = np.rad2deg(alpha_rad)
1445-
else:
1446-
alpha_deg = 0.0
1447-
1448-
# Determine which parameters to pass based on input names
1449-
input_names = [name.lower() for name in drag_function.__inputs__]
1450-
args = []
1451-
1452-
for name in input_names:
1453-
if "mach" in name or name == "m":
1454-
args.append(mach)
1455-
elif "reynolds" in name or name == "re":
1456-
args.append(reynolds)
1457-
elif "alpha" in name or name == "a" or "attack" in name:
1458-
args.append(alpha_deg)
1459-
else:
1460-
# Unknown parameter, default to mach
1461-
args.append(mach)
1409+
# Calculate Reynolds number: Re = rho * V * L / mu
1410+
# where L is characteristic length (rocket diameter)
1411+
rho = self.env.density.get_value_opt(z)
1412+
mu = self.env.dynamic_viscosity.get_value_opt(z)
1413+
freestream_speed = np.linalg.norm(freestream_velocity_body)
1414+
characteristic_length = 2 * self.rocket.radius # Diameter
1415+
# Defensive: avoid division by zero or non-finite viscosity values.
1416+
# Use a small epsilon fallback if `mu` is zero, negative, NaN or infinite.
1417+
try:
1418+
mu_val = float(mu)
1419+
except (TypeError, ValueError, OverflowError):
1420+
# Only catch errors related to invalid numeric conversion.
1421+
# Avoid catching broad Exception to satisfy linters and
1422+
# allow other unexpected errors to surface.
1423+
mu_val = 0.0
1424+
if not np.isfinite(mu_val) or mu_val <= 0.0:
1425+
mu_safe = 1e-10
1426+
else:
1427+
mu_safe = mu_val
14621428

1463-
return drag_function.get_value_opt(*args)
1429+
reynolds = rho * freestream_speed * characteristic_length / mu_safe
1430+
1431+
# Calculate angle of attack
1432+
# Angle between freestream velocity and rocket axis (z-axis in body frame)
1433+
# The z component of freestream velocity in body frame
1434+
if hasattr(freestream_velocity_body, "z"):
1435+
stream_vz_b = -freestream_velocity_body.z
14641436
else:
1465-
# 1D drag function - only mach number
1466-
return drag_function.get_value_opt(mach)
1437+
stream_vz_b = -freestream_velocity_body[2]
1438+
1439+
# Normalize and calculate angle
1440+
if freestream_speed > 1e-6:
1441+
cos_alpha = stream_vz_b / freestream_speed
1442+
# Clamp to [-1, 1] to avoid numerical issues
1443+
cos_alpha = np.clip(cos_alpha, -1.0, 1.0)
1444+
alpha_rad = np.arccos(cos_alpha)
1445+
alpha_deg = np.rad2deg(alpha_rad)
1446+
else:
1447+
alpha_deg = 0.0
1448+
1449+
# Determine which parameters to pass based on input names
1450+
input_names = [name.lower() for name in drag_function.__inputs__]
1451+
args = []
1452+
1453+
for name in input_names:
1454+
if "mach" in name or name == "m":
1455+
args.append(mach)
1456+
elif "reynolds" in name or name == "re":
1457+
args.append(reynolds)
1458+
elif "alpha" in name or name == "a" or "attack" in name:
1459+
args.append(alpha_deg)
1460+
else:
1461+
# Unknown parameter, default to mach
1462+
args.append(mach)
1463+
1464+
return drag_function.get_value_opt(*args)
14671465

14681466
def udot_rail1(self, t, u, post_processing=False):
14691467
"""Calculates derivative of u state vector with respect to time

tests/unit/mathutils/test_function_grid.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,73 @@ def test_shepard_fallback_at_exact_data_points():
280280
assert func(1, 0) == 20
281281
assert func(0, 1) == 30
282282
assert func(1, 1) == 40
283+
284+
285+
def test_from_grid_unsorted_axis_warns():
286+
"""Test that from_grid warns when axes are not sorted in ascending order."""
287+
y_data = np.array([0.0, 1.0, 4.0])
288+
289+
# Test with unsorted axis (descending order)
290+
unsorted_axis = np.array([2.0, 1.0, 0.0])
291+
292+
with warnings.catch_warnings(record=True) as w:
293+
warnings.simplefilter("always")
294+
Function.from_grid(y_data, [unsorted_axis], inputs=["x"], outputs="y")
295+
296+
# Check that a warning was issued
297+
assert len(w) == 1
298+
assert "not strictly sorted in ascending order" in str(w[0].message)
299+
300+
301+
def test_from_grid_repeated_values_warns():
302+
"""Test that from_grid warns when axes have repeated values.
303+
304+
Note: RegularGridInterpolator requires strictly ascending or descending
305+
axes. Repeated values will cause scipy to raise a ValueError after our
306+
warning is issued.
307+
"""
308+
y_data = np.array([0.0, 1.0, 4.0])
309+
310+
# Test with repeated values (not strictly ascending)
311+
repeated_axis = np.array([0.0, 1.0, 1.0])
312+
313+
with warnings.catch_warnings(record=True) as w:
314+
warnings.simplefilter("always")
315+
# Scipy will raise ValueError after our warning, so we expect both
316+
try:
317+
Function.from_grid(y_data, [repeated_axis], inputs=["x"], outputs="y")
318+
except ValueError as e:
319+
# scipy raises this error for non-strictly-sorted axes
320+
assert "strictly ascending" in str(e).lower() or "dimension 0" in str(e)
321+
322+
# Check that a warning was issued before the error
323+
assert len(w) == 1
324+
assert "not strictly sorted in ascending order" in str(w[0].message)
325+
326+
327+
def test_from_grid_flatten_for_compatibility_false():
328+
"""Test that flatten_for_compatibility=False skips flattening."""
329+
x = np.array([0.0, 1.0, 2.0])
330+
y = np.array([0.0, 1.0])
331+
332+
X, Y = np.meshgrid(x, y, indexing="ij")
333+
z_data = X + Y
334+
335+
func = Function.from_grid(
336+
z_data,
337+
[x, y],
338+
inputs=["x", "y"],
339+
outputs="z",
340+
flatten_for_compatibility=False,
341+
)
342+
343+
# Check that flattened attributes are None
344+
assert func._domain is None
345+
assert func._image is None
346+
assert func.source is None
347+
assert func.y_array is None
348+
349+
# But the function should still work correctly
350+
assert func(0.0, 0.0) == 0.0
351+
assert func(1.0, 1.0) == 2.0
352+
assert func(2.0, 1.0) == 3.0

0 commit comments

Comments
 (0)