Skip to content

Commit da64562

Browse files
committed
ENH: Avoid circular imports
1 parent 340ca6f commit da64562

2 files changed

Lines changed: 125 additions & 51 deletions

File tree

rocketpy/mathutils/function.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
carefully as it may impact all the rest of the project.
66
"""
77

8+
import base64
9+
import functools
810
import operator
911
import warnings
1012
from bisect import bisect_left
@@ -15,6 +17,7 @@
1517
from inspect import signature
1618
from pathlib import Path
1719

20+
import dill
1821
import matplotlib.pyplot as plt
1922
import numpy as np
2023
from scipy import integrate, linalg, optimize
@@ -26,7 +29,6 @@
2629
)
2730

2831
from rocketpy.plots.plot_helpers import show_or_save_plot
29-
from rocketpy.tools import deprecated, from_hex_decode, to_hex_encode
3032

3133
# Numpy 1.x compatibility,
3234
# TODO: remove these lines when all dependencies support numpy>=2.0.0
@@ -49,6 +51,46 @@
4951
EXTRAPOLATION_TYPES = {"zero": 0, "natural": 1, "constant": 2}
5052

5153

54+
def deprecated(reason=None, version=None, alternative=None):
55+
"""Decorator to mark functions or methods as deprecated.
56+
57+
This decorator issues a DeprecationWarning when the decorated function
58+
is called, indicating that it will be removed in future versions.
59+
"""
60+
61+
def decorator(func):
62+
@functools.wraps(func)
63+
def wrapper(*args, **kwargs):
64+
if reason:
65+
message = reason
66+
else:
67+
message = f"The function `{func.__name__}` is deprecated"
68+
69+
if version:
70+
message += f" and will be removed in {version}"
71+
72+
if alternative:
73+
message += f". Use `{alternative}` instead"
74+
75+
message += "."
76+
warnings.warn(message, DeprecationWarning, stacklevel=2)
77+
return func(*args, **kwargs)
78+
79+
return wrapper
80+
81+
return decorator
82+
83+
84+
def to_hex_encode(obj, encoder=base64.b85encode):
85+
"""Converts an object to hex representation using dill."""
86+
return encoder(dill.dumps(obj)).hex()
87+
88+
89+
def from_hex_decode(obj_bytes, decoder=base64.b85decode):
90+
"""Converts an object from hex representation using dill."""
91+
return dill.loads(decoder(bytes.fromhex(obj_bytes)))
92+
93+
5294
class SourceType(Enum):
5395
"""Enumeration of the source types for the Function class.
5496
The source can be either a callable or an array.
@@ -162,6 +204,74 @@ def __init__(
162204
self.set_outputs(self.__outputs__)
163205
self.set_title(self.title)
164206

207+
@classmethod
208+
def from_regular_grid_csv(
209+
cls, csv_source, variable_names, coeff_name, extrapolation
210+
):
211+
"""Create a regular-grid Function from CSV samples when possible.
212+
213+
Parameters
214+
----------
215+
csv_source : str
216+
Path to the CSV file.
217+
variable_names : list[str]
218+
Ordered independent variable names present in the CSV.
219+
coeff_name : str
220+
Name of the output coefficient.
221+
extrapolation : str
222+
Extrapolation method passed to the Function constructor.
223+
224+
Returns
225+
-------
226+
Function or None
227+
A ``Function`` configured with ``regular_grid`` interpolation when
228+
the CSV forms a strict Cartesian grid, otherwise ``None``.
229+
"""
230+
try:
231+
data = np.loadtxt(csv_source, delimiter=",", skiprows=1, dtype=np.float64)
232+
except (OSError, ValueError):
233+
return None
234+
235+
data = np.atleast_2d(data)
236+
expected_columns = len(variable_names) + 1
237+
if data.shape[1] != expected_columns:
238+
return None
239+
240+
coordinates = data[:, :-1]
241+
values = data[:, -1]
242+
243+
if np.unique(coordinates, axis=0).shape[0] != coordinates.shape[0]:
244+
return None
245+
246+
axes = [np.unique(coordinates[:, i]) for i in range(len(variable_names))]
247+
expected_size = int(np.prod([axis.size for axis in axes]))
248+
if expected_size != coordinates.shape[0]:
249+
return None
250+
251+
sorting_keys = [
252+
coordinates[:, i] for i in range(len(variable_names) - 1, -1, -1)
253+
]
254+
sorted_indices = np.lexsort(tuple(sorting_keys))
255+
sorted_coordinates = coordinates[sorted_indices]
256+
sorted_values = values[sorted_indices]
257+
258+
expected_coordinates = np.column_stack(
259+
[axis_values.ravel() for axis_values in np.meshgrid(*axes, indexing="ij")]
260+
)
261+
if not np.allclose(
262+
sorted_coordinates, expected_coordinates, rtol=0, atol=1e-12
263+
):
264+
return None
265+
266+
grid_data = sorted_values.reshape(tuple(axis.size for axis in axes))
267+
return cls(
268+
(axes, grid_data),
269+
inputs=variable_names,
270+
outputs=[coeff_name],
271+
interpolation="regular_grid",
272+
extrapolation=extrapolation,
273+
)
274+
165275
# Define all set methods
166276
def set_inputs(self, inputs):
167277
"""Set the name and number of the incoming arguments of the Function.

rocketpy/tools.py

Lines changed: 14 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,6 @@ def tuple_handler(value):
117117
raise ValueError("value must be a list or tuple of length 1 or 2.")
118118

119119

120-
def _get_function_class():
121-
"""Return ``Function`` using lazy import to avoid cyclic imports."""
122-
return importlib.import_module("rocketpy.mathutils.function").Function
123-
124-
125120
def create_regular_grid_function(
126121
csv_source,
127122
variable_names,
@@ -147,44 +142,13 @@ def create_regular_grid_function(
147142
A ``Function`` configured with ``regular_grid`` interpolation when the
148143
CSV data forms a strict Cartesian grid, otherwise ``None``.
149144
"""
150-
function_class = _get_function_class()
151-
152-
data = np.loadtxt(csv_source, delimiter=",", skiprows=1, dtype=float)
153-
154-
data = np.atleast_2d(data)
155-
expected_columns = len(variable_names) + 1
156-
if data.shape[1] != expected_columns:
157-
return None
158-
159-
coordinates = data[:, :-1]
160-
values = data[:, -1]
145+
from rocketpy.mathutils.function import Function
161146

162-
if np.unique(coordinates, axis=0).shape[0] != coordinates.shape[0]:
163-
return None
164-
165-
axes = [np.unique(coordinates[:, i]) for i in range(len(variable_names))]
166-
expected_size = int(np.prod([axis.size for axis in axes]))
167-
if expected_size != coordinates.shape[0]:
168-
return None
169-
170-
sorting_keys = [coordinates[:, i] for i in range(len(variable_names) - 1, -1, -1)]
171-
sorted_indices = np.lexsort(tuple(sorting_keys))
172-
sorted_coordinates = coordinates[sorted_indices]
173-
sorted_values = values[sorted_indices]
174-
175-
expected_coordinates = np.column_stack(
176-
[axis_values.ravel() for axis_values in np.meshgrid(*axes, indexing="ij")]
177-
)
178-
if not np.allclose(sorted_coordinates, expected_coordinates, rtol=0, atol=1e-12):
179-
return None
180-
181-
grid_data = sorted_values.reshape(tuple(axis.size for axis in axes))
182-
return function_class(
183-
(axes, grid_data),
184-
inputs=variable_names,
185-
outputs=[coeff_name],
186-
interpolation="regular_grid",
187-
extrapolation=extrapolation,
147+
return Function.from_regular_grid_csv(
148+
csv_source,
149+
variable_names,
150+
coeff_name,
151+
extrapolation,
188152
)
189153

190154

@@ -195,7 +159,7 @@ def load_generic_surface_csv(file_path, coeff_name): # pylint: disable=too-many
195159
variables among: alpha, beta, mach, reynolds, pitch_rate, yaw_rate,
196160
roll_rate.
197161
"""
198-
function_class = _get_function_class()
162+
from rocketpy.mathutils.function import Function
199163

200164
independent_vars = [
201165
"alpha",
@@ -247,7 +211,7 @@ def load_generic_surface_csv(file_path, coeff_name): # pylint: disable=too-many
247211
extrapolation="natural",
248212
)
249213
if csv_func is None:
250-
csv_func = function_class(
214+
csv_func = Function(
251215
file_path,
252216
interpolation="linear",
253217
extrapolation="natural",
@@ -266,7 +230,7 @@ def wrapper(alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate):
266230
selected_args = [args_by_name[col] for col in ordered_present_columns]
267231
return csv_func(*selected_args)
268232

269-
return function_class(
233+
return Function(
270234
wrapper,
271235
independent_vars,
272236
[coeff_name],
@@ -281,7 +245,7 @@ def load_rocket_drag_csv(file_path, coeff_name): # pylint: disable=too-many-sta
281245
Supports either headerless two-column (mach, coefficient) tables or
282246
header-based multi-variable CSV tables.
283247
"""
284-
function_class = _get_function_class()
248+
from rocketpy.mathutils.function import Function
285249

286250
independent_vars = [
287251
"alpha",
@@ -321,7 +285,7 @@ def _is_numeric(value):
321285
)
322286

323287
if is_headerless_two_column:
324-
csv_func = function_class(
288+
csv_func = Function(
325289
file_path,
326290
interpolation="linear",
327291
extrapolation="constant",
@@ -338,7 +302,7 @@ def mach_wrapper(
338302
):
339303
return csv_func(mach)
340304

341-
return function_class(
305+
return Function(
342306
mach_wrapper,
343307
independent_vars,
344308
[coeff_name],
@@ -374,7 +338,7 @@ def mach_wrapper(
374338
extrapolation="constant",
375339
)
376340
if csv_func is None:
377-
csv_func = function_class(
341+
csv_func = Function(
378342
file_path,
379343
interpolation="linear",
380344
extrapolation="constant",
@@ -393,7 +357,7 @@ def wrapper(alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate):
393357
selected_args = [args_by_name[col] for col in ordered_present_columns]
394358
return csv_func(*selected_args)
395359

396-
return function_class(
360+
return Function(
397361
wrapper,
398362
independent_vars,
399363
[coeff_name],

0 commit comments

Comments
 (0)