Skip to content

Commit cad8bf7

Browse files
authored
Merge pull request #21 from Kev1CO/bioptim_update
Updating Cocofest to last Bioptim version
2 parents 3536d8e + e131cd2 commit cad8bf7

51 files changed

Lines changed: 988 additions & 2105 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/run_tests_win.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,16 @@ jobs:
4040
conda info
4141
conda list
4242
43+
- name: Install extra dependencies
44+
run: conda install pip pytest-cov black pytest pytest-cov codecov packaging -cconda-forge
45+
4346
- name: Install bioptim on Windows
4447
run: |
4548
pwd
4649
cd external
4750
./bioptim_install_windows.sh 4 ${{ env.PREFIX_WINDOWS }}
4851
cd ..
4952
50-
- name: Install extra dependencies
51-
run: conda install pytest-cov black pytest pytest-cov codecov packaging -cconda-forge
52-
5353
- name: Run tests with code coverage
5454
run: pytest -v --color=yes --cov-report term-missing --cov=cocofest --cov-report=xml:coverage.xml tests/shard${{ matrix.shard }}
5555

cocofest/custom_objectives.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def minimize_overall_muscle_fatigue(controller: PenaltyController) -> MX:
2323
-------
2424
The sum of each force scaling factor
2525
"""
26-
muscle_name_list = controller.model.bio_model.muscle_names
26+
muscle_name_list = controller.model.muscle_names
2727
muscle_model = controller.model.muscles_dynamics_model
2828
muscle_fatigue = vertcat(
2929
*[
@@ -47,7 +47,7 @@ def minimize_overall_muscle_force_production(controller: PenaltyController) -> M
4747
-------
4848
The sum of each force
4949
"""
50-
muscle_name_list = controller.model.bio_model.muscle_names
50+
muscle_name_list = controller.model.muscle_names
5151
muscle_model = controller.model.muscles_dynamics_model
5252
muscle_force = vertcat(
5353
*[
@@ -72,7 +72,7 @@ def minimize_overall_stimulation_charge(controller: PenaltyController) -> MX:
7272
The sum of each stimulation control
7373
"""
7474
if isinstance(controller.model, FesMskModel):
75-
muscle_name_list = controller.model.bio_model.muscle_names
75+
muscle_name_list = controller.model.muscle_names
7676
if isinstance(controller.model.muscles_dynamics_model[0], DingModelPulseWidthFrequency):
7777
stim_charge = vertcat(
7878
*[

cocofest/integration/ivp_fes.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import numpy as np
22
from bioptim import (
33
ControlType,
4-
DynamicsList,
54
InitialGuessList,
65
OdeSolver,
76
OptimalControlProgram,
87
ParameterList,
9-
PhaseDynamics,
108
BoundsList,
119
InterpolationType,
1210
Solution,
@@ -20,6 +18,7 @@
2018
from cocofest.models.ding2007.ding2007_with_fatigue import DingModelPulseWidthFrequencyWithFatigue
2119
from cocofest.models.hmed2018.hmed2018 import DingModelPulseIntensityFrequency
2220
from cocofest.models.hmed2018.hmed2018_with_fatigue import DingModelPulseIntensityFrequencyWithFatigue
21+
from cocofest.optimization.fes_ocp import OcpFes
2322

2423

2524
class IvpFes:
@@ -90,9 +89,10 @@ def __init__(
9089
numerical_data_time_series, stim_idx_at_node_list = self.model.get_numerical_data_time_series(
9190
self.n_shooting, self.final_time
9291
)
93-
9492
self.ode_solver = self.ivp_parameters["ode_solver"]
95-
self._declare_dynamics(numerical_data_time_series)
93+
self.dynamics_options = OcpFes.declare_dynamics_options(
94+
numerical_time_series=numerical_data_time_series, ode_solver=self.ode_solver
95+
)
9696

9797
(
9898
self.x_init,
@@ -285,7 +285,7 @@ def _prepare_fake_ocp(self):
285285

286286
return OptimalControlProgram(
287287
bio_model=[self.model],
288-
dynamics=self.dynamics,
288+
dynamics=self.dynamics_options,
289289
n_shooting=self.n_shooting,
290290
phase_time=self.final_time,
291291
x_init=self.x_init,
@@ -318,20 +318,6 @@ def integrate(
318318
duplicated_times=duplicated_times,
319319
)
320320

321-
def _declare_dynamics(self, numerical_data_time_series=None):
322-
323-
self.dynamics = DynamicsList()
324-
self.dynamics.add(
325-
self.model.declare_ding_variables,
326-
dynamic_function=self.model.dynamics,
327-
expand_dynamics=True,
328-
expand_continuity=False,
329-
phase=0,
330-
phase_dynamics=PhaseDynamics.SHARED_DURING_THE_PHASE,
331-
numerical_data_timeseries=numerical_data_time_series,
332-
ode_solver=self.ode_solver,
333-
)
334-
335321
def build_initial_guess_from_ocp(self, ocp, stim_idx_at_node_list=None):
336322
"""
337323
Build a state, control, parameters and stochastic initial guesses for each phases from a given ocp
@@ -341,9 +327,7 @@ def build_initial_guess_from_ocp(self, ocp, stim_idx_at_node_list=None):
341327
p = InitialGuessList()
342328
s = InitialGuessList()
343329

344-
muscle_name = "_" + ocp.model.muscle_name if ocp.model.muscle_name else ""
345-
for j in range(len(self.model.name_dof)):
346-
key = ocp.model.name_dof[j] + muscle_name
330+
for j, key in enumerate(self.model.name_dofs):
347331
x.add(key=key, initial_guess=ocp.model.standard_rest_values()[j], phase=0)
348332

349333
if ocp.controls_keys:

cocofest/models/ding2003/ding2003.py

Lines changed: 69 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
1-
from typing import Callable
1+
from typing import Callable, List
22
from math import gcd
33
from fractions import Fraction
44

55
import numpy as np
66
from casadi import MX, exp, vertcat
77

88
from bioptim import (
9-
ConfigureProblem,
109
DynamicsEvaluation,
1110
NonLinearProgram,
12-
OptimalControlProgram,
11+
StateDynamics,
12+
DynamicsFunctions,
13+
OdeSolver,
14+
States,
1315
)
1416

1517
from cocofest.models.state_configure import StateConfigure
1618
from cocofest.models.fes_model import FesModel
1719

1820

19-
class DingModelFrequency(FesModel):
21+
class DingModelFrequency(FesModel, StateDynamics):
2022
"""
2123
This is a custom model of the Bioptim package. As CustomModel, some methods are mandatory and must be implemented.
2224
to make it work with bioptim.
@@ -38,8 +40,9 @@ def __init__(
3840
stim_time: list[float] = None,
3941
previous_stim: dict = None,
4042
sum_stim_truncation: int = 20,
43+
**kwargs,
4144
):
42-
super().__init__()
45+
super().__init__(name=model_name, **kwargs)
4346
self._model_name = model_name
4447
self._muscle_name = muscle_name
4548
self.sum_stim_truncation = sum_stim_truncation
@@ -68,6 +71,30 @@ def __init__(
6871
self.km_rest = KM_REST_DEFAULT
6972
self.fmax = 315.5 # Maximum force (N) at 100 Hz
7073

74+
# ---- Muscle relationship ---- #
75+
self.fes_model = None
76+
self.force_length_relationship = 1
77+
self.force_velocity_relationship = 1
78+
self.passive_force_relationship = 0
79+
80+
# --- Configure variables --- #
81+
@property
82+
def state_configuration_functions(self) -> List[States | Callable]:
83+
return [StateConfigure().configure_all_muscle_states]
84+
85+
@property
86+
def control_configuration_functions(self) -> List[States | Callable]:
87+
return []
88+
89+
@property
90+
def algebraic_configuration_functions(self) -> List[States | Callable]:
91+
return []
92+
93+
@property
94+
def extra_configuration_functions(self) -> List[States | Callable]:
95+
return []
96+
97+
# --- Set model parameters --- #
7198
def set_a_rest(self, model, a_rest: MX | float):
7299
# models is required for bioptim compatibility
73100
self.a_rest = a_rest
@@ -106,8 +133,8 @@ def serialize(self) -> tuple[Callable, dict]:
106133

107134
# ---- Needed for the example ---- #
108135
@property
109-
def name_dof(self, with_muscle_name: bool = False) -> list[str]:
110-
muscle_name = "_" + self.muscle_name if self.muscle_name and with_muscle_name else ""
136+
def name_dofs(self) -> list[str]:
137+
muscle_name = "_" + self.muscle_name if self.muscle_name is not None else ""
111138
return ["Cn" + muscle_name, "F" + muscle_name]
112139

113140
@property
@@ -155,48 +182,41 @@ def get_lambda_i(nb_stim: int, pulse_intensity: MX | float) -> list[MX | float]:
155182
# ---- Model's dynamics ---- #
156183
def system_dynamics(
157184
self,
158-
cn: MX,
159-
f: MX,
160-
t: MX = None,
161-
t_stim_prev: list[MX] = None,
162-
force_length_relationship: MX | float = 1,
163-
force_velocity_relationship: MX | float = 1,
164-
passive_force_relationship: MX | float = 0,
185+
time: MX,
186+
states: MX,
187+
controls: MX,
188+
numerical_timeseries: MX,
165189
) -> MX:
166190
"""
167191
The system dynamics is the function that describes the models.
168192
169193
Parameters
170194
----------
171-
cn: MX
172-
The value of the ca_troponin_complex (unitless)
173-
f: MX
174-
The value of the force (N)
175-
t: MX
176-
The current time at which the dynamics is evaluated (s)
177-
t_stim_prev: list[MX]
178-
The time list of the previous stimulations (s)
179-
force_length_relationship: MX | float
180-
The force length relationship value (unitless)
181-
force_velocity_relationship: MX | float
182-
The force velocity relationship value (unitless)
183-
passive_force_relationship: MX | float
184-
The passive force relationship value (unitless)
195+
time: MX
196+
The system's current node time
197+
states: MX
198+
The state of the system CN, F
199+
controls: MX
200+
The controls of the system, none
201+
numerical_timeseries: MX
202+
The numerical timeseries of the system
185203
186204
Returns
187205
-------
188206
The value of the derivative of each state dx/dt at the current time t
189207
"""
208+
t = time
209+
cn = states[0]
210+
f = states[1]
211+
t_stim_prev = numerical_timeseries
212+
190213
cn_dot = self.calculate_cn_dot(cn, t, t_stim_prev)
191214
f_dot = self.f_dot_fun(
192215
cn,
193216
f,
194217
self.a_rest,
195218
self.tau1_rest,
196219
self.km_rest,
197-
force_length_relationship=force_length_relationship,
198-
force_velocity_relationship=force_velocity_relationship,
199-
passive_force_relationship=passive_force_relationship,
200220
) # Equation n°2
201221
return vertcat(cn_dot, f_dot)
202222

@@ -285,9 +305,6 @@ def f_dot_fun(
285305
a: MX | float,
286306
tau1: MX | float,
287307
km: MX | float,
288-
force_length_relationship: MX | float = 1,
289-
force_velocity_relationship: MX | float = 1,
290-
passive_force_relationship: MX | float = 0,
291308
) -> MX | float:
292309
"""
293310
Parameters
@@ -302,34 +319,23 @@ def f_dot_fun(
302319
The previous step value of time_state_force_no_cross_bridge (s)
303320
km: MX | float
304321
The previous step value of cross_bridges (unitless)
305-
force_length_relationship: MX | float
306-
The force length relationship value (unitless)
307-
force_velocity_relationship: MX | float
308-
The force velocity relationship value (unitless)
309-
passive_force_relationship: MX | float
310-
The passive force relationship value (unitless)
311-
312322
Returns
313323
-------
314324
The value of the derivative force (N)
315325
"""
316326
return (a * (cn / (km + cn)) - (f / (tau1 + self.tau2 * (cn / (km + cn))))) * (
317-
force_length_relationship * force_velocity_relationship + passive_force_relationship
327+
self.force_length_relationship * self.force_velocity_relationship + self.passive_force_relationship
318328
) # Equation n°2
319329

320-
@staticmethod
321330
def dynamics(
331+
self,
322332
time: MX,
323333
states: MX,
324334
controls: MX,
325335
parameters: MX,
326336
algebraic_states: MX,
327337
numerical_timeseries: MX,
328338
nlp: NonLinearProgram,
329-
fes_model=None,
330-
force_length_relationship: MX | float = 1,
331-
force_velocity_relationship: MX | float = 1,
332-
passive_force_relationship: MX | float = 0,
333339
) -> DynamicsEvaluation:
334340
"""
335341
Functional electrical stimulation dynamic
@@ -350,57 +356,31 @@ def dynamics(
350356
The numerical timeseries of the system
351357
nlp: NonLinearProgram
352358
A reference to the phase
353-
fes_model: DingModelFrequency
354-
The current phase fes model
355-
force_length_relationship: MX | float
356-
The force length relationship value (unitless)
357-
force_velocity_relationship: MX | float
358-
The force velocity relationship value (unitless)
359-
passive_force_relationship: MX | float
360-
The passive force relationship value (unitless)
361359
Returns
362360
-------
363361
The derivative of the states in the tuple[MX] format
364362
"""
365-
model = fes_model if fes_model else nlp.model
363+
model = self.fes_model if self.fes_model else nlp.model
366364
dxdt_fun = model.system_dynamics
365+
dxdt = dxdt_fun(
366+
time=time,
367+
states=states,
368+
controls=controls,
369+
numerical_timeseries=numerical_timeseries,
370+
)
371+
372+
defects = None
373+
if isinstance(nlp.dynamics_type.ode_solver, OdeSolver.COLLOCATION):
374+
states_dot_list = []
375+
for key in model.name_dofs:
376+
states_dot_list.append(DynamicsFunctions.get(nlp.states_dot[key], nlp.states_dot.scaled.cx))
377+
defects = vertcat(*states_dot_list) - dxdt
367378

368379
return DynamicsEvaluation(
369-
dxdt=dxdt_fun(
370-
cn=states[0],
371-
f=states[1],
372-
t=time,
373-
t_stim_prev=numerical_timeseries,
374-
force_length_relationship=force_length_relationship,
375-
force_velocity_relationship=force_velocity_relationship,
376-
passive_force_relationship=passive_force_relationship,
377-
),
380+
dxdt=dxdt,
381+
defects=defects,
378382
)
379383

380-
def declare_ding_variables(
381-
self,
382-
ocp: OptimalControlProgram,
383-
nlp: NonLinearProgram,
384-
numerical_data_timeseries: dict[str, np.ndarray] = None,
385-
contact_type: list = (),
386-
):
387-
"""
388-
Tell the program which variables are states and controls.
389-
The user is expected to use the ConfigureProblem.configure_xxx functions.
390-
Parameters
391-
----------
392-
ocp: OptimalControlProgram
393-
A reference to the ocp
394-
nlp: NonLinearProgram
395-
A reference to the phase
396-
numerical_data_timeseries: dict[str, np.ndarray]
397-
A list of values to pass to the dynamics at each node. Experimental external forces should be included here.
398-
contact_type: list
399-
A list of contact types. This is used to define the contact forces in the dynamics. Not used in this model.
400-
"""
401-
StateConfigure().configure_all_fes_model_states(ocp, nlp, fes_model=self)
402-
ConfigureProblem.configure_dynamics_function(ocp, nlp, dyn_func=self.dynamics)
403-
404384
def _get_additional_previous_stim_time(self):
405385
while len(self.previous_stim["time"]) < self.sum_stim_truncation:
406386
self.previous_stim["time"].insert(0, -10000000)

0 commit comments

Comments
 (0)