1- from typing import Callable
1+ from typing import Callable , List
22from math import gcd
33from fractions import Fraction
44
55import numpy as np
66from casadi import MX , exp , vertcat
77
88from bioptim import (
9- ConfigureProblem ,
109 DynamicsEvaluation ,
1110 NonLinearProgram ,
12- OptimalControlProgram ,
11+ StateDynamics ,
12+ DynamicsFunctions ,
13+ OdeSolver ,
14+ States ,
1315)
1416
1517from cocofest .models .state_configure import StateConfigure
1618from cocofest .models .fes_model import FesModel
1719
1820
19- class DingModelFrequency (FesModel ):
21+ class DingModelFrequency (FesModel , StateDynamics ):
2022 """
2123 This is a custom model of the Bioptim package. As CustomModel, some methods are mandatory and must be implemented.
2224 to make it work with bioptim.
@@ -38,8 +40,9 @@ def __init__(
3840 stim_time : list [float ] = None ,
3941 previous_stim : dict = None ,
4042 sum_stim_truncation : int = 20 ,
43+ ** kwargs ,
4144 ):
42- super ().__init__ ()
45+ super ().__init__ (name = model_name , ** kwargs )
4346 self ._model_name = model_name
4447 self ._muscle_name = muscle_name
4548 self .sum_stim_truncation = sum_stim_truncation
@@ -68,6 +71,30 @@ def __init__(
6871 self .km_rest = KM_REST_DEFAULT
6972 self .fmax = 315.5 # Maximum force (N) at 100 Hz
7073
74+ # ---- Muscle relationship ---- #
75+ self .fes_model = None
76+ self .force_length_relationship = 1
77+ self .force_velocity_relationship = 1
78+ self .passive_force_relationship = 0
79+
80+ # --- Configure variables --- #
81+ @property
82+ def state_configuration_functions (self ) -> List [States | Callable ]:
83+ return [StateConfigure ().configure_all_muscle_states ]
84+
85+ @property
86+ def control_configuration_functions (self ) -> List [States | Callable ]:
87+ return []
88+
89+ @property
90+ def algebraic_configuration_functions (self ) -> List [States | Callable ]:
91+ return []
92+
93+ @property
94+ def extra_configuration_functions (self ) -> List [States | Callable ]:
95+ return []
96+
97+ # --- Set model parameters --- #
7198 def set_a_rest (self , model , a_rest : MX | float ):
7299 # models is required for bioptim compatibility
73100 self .a_rest = a_rest
@@ -106,8 +133,8 @@ def serialize(self) -> tuple[Callable, dict]:
106133
107134 # ---- Needed for the example ---- #
108135 @property
109- def name_dof (self , with_muscle_name : bool = False ) -> list [str ]:
110- muscle_name = "_" + self .muscle_name if self .muscle_name and with_muscle_name else ""
136+ def name_dofs (self ) -> list [str ]:
137+ muscle_name = "_" + self .muscle_name if self .muscle_name is not None else ""
111138 return ["Cn" + muscle_name , "F" + muscle_name ]
112139
113140 @property
@@ -155,48 +182,41 @@ def get_lambda_i(nb_stim: int, pulse_intensity: MX | float) -> list[MX | float]:
155182 # ---- Model's dynamics ---- #
156183 def system_dynamics (
157184 self ,
158- cn : MX ,
159- f : MX ,
160- t : MX = None ,
161- t_stim_prev : list [MX ] = None ,
162- force_length_relationship : MX | float = 1 ,
163- force_velocity_relationship : MX | float = 1 ,
164- passive_force_relationship : MX | float = 0 ,
185+ time : MX ,
186+ states : MX ,
187+ controls : MX ,
188+ numerical_timeseries : MX ,
165189 ) -> MX :
166190 """
167191 The system dynamics is the function that describes the models.
168192
169193 Parameters
170194 ----------
171- cn: MX
172- The value of the ca_troponin_complex (unitless)
173- f: MX
174- The value of the force (N)
175- t: MX
176- The current time at which the dynamics is evaluated (s)
177- t_stim_prev: list[MX]
178- The time list of the previous stimulations (s)
179- force_length_relationship: MX | float
180- The force length relationship value (unitless)
181- force_velocity_relationship: MX | float
182- The force velocity relationship value (unitless)
183- passive_force_relationship: MX | float
184- The passive force relationship value (unitless)
195+ time: MX
196+ The system's current node time
197+ states: MX
198+ The state of the system CN, F
199+ controls: MX
200+ The controls of the system, none
201+ numerical_timeseries: MX
202+ The numerical timeseries of the system
185203
186204 Returns
187205 -------
188206 The value of the derivative of each state dx/dt at the current time t
189207 """
208+ t = time
209+ cn = states [0 ]
210+ f = states [1 ]
211+ t_stim_prev = numerical_timeseries
212+
190213 cn_dot = self .calculate_cn_dot (cn , t , t_stim_prev )
191214 f_dot = self .f_dot_fun (
192215 cn ,
193216 f ,
194217 self .a_rest ,
195218 self .tau1_rest ,
196219 self .km_rest ,
197- force_length_relationship = force_length_relationship ,
198- force_velocity_relationship = force_velocity_relationship ,
199- passive_force_relationship = passive_force_relationship ,
200220 ) # Equation n°2
201221 return vertcat (cn_dot , f_dot )
202222
@@ -285,9 +305,6 @@ def f_dot_fun(
285305 a : MX | float ,
286306 tau1 : MX | float ,
287307 km : MX | float ,
288- force_length_relationship : MX | float = 1 ,
289- force_velocity_relationship : MX | float = 1 ,
290- passive_force_relationship : MX | float = 0 ,
291308 ) -> MX | float :
292309 """
293310 Parameters
@@ -302,34 +319,23 @@ def f_dot_fun(
302319 The previous step value of time_state_force_no_cross_bridge (s)
303320 km: MX | float
304321 The previous step value of cross_bridges (unitless)
305- force_length_relationship: MX | float
306- The force length relationship value (unitless)
307- force_velocity_relationship: MX | float
308- The force velocity relationship value (unitless)
309- passive_force_relationship: MX | float
310- The passive force relationship value (unitless)
311-
312322 Returns
313323 -------
314324 The value of the derivative force (N)
315325 """
316326 return (a * (cn / (km + cn )) - (f / (tau1 + self .tau2 * (cn / (km + cn ))))) * (
317- force_length_relationship * force_velocity_relationship + passive_force_relationship
327+ self . force_length_relationship * self . force_velocity_relationship + self . passive_force_relationship
318328 ) # Equation n°2
319329
320- @staticmethod
321330 def dynamics (
331+ self ,
322332 time : MX ,
323333 states : MX ,
324334 controls : MX ,
325335 parameters : MX ,
326336 algebraic_states : MX ,
327337 numerical_timeseries : MX ,
328338 nlp : NonLinearProgram ,
329- fes_model = None ,
330- force_length_relationship : MX | float = 1 ,
331- force_velocity_relationship : MX | float = 1 ,
332- passive_force_relationship : MX | float = 0 ,
333339 ) -> DynamicsEvaluation :
334340 """
335341 Functional electrical stimulation dynamic
@@ -350,57 +356,31 @@ def dynamics(
350356 The numerical timeseries of the system
351357 nlp: NonLinearProgram
352358 A reference to the phase
353- fes_model: DingModelFrequency
354- The current phase fes model
355- force_length_relationship: MX | float
356- The force length relationship value (unitless)
357- force_velocity_relationship: MX | float
358- The force velocity relationship value (unitless)
359- passive_force_relationship: MX | float
360- The passive force relationship value (unitless)
361359 Returns
362360 -------
363361 The derivative of the states in the tuple[MX] format
364362 """
365- model = fes_model if fes_model else nlp .model
363+ model = self . fes_model if self . fes_model else nlp .model
366364 dxdt_fun = model .system_dynamics
365+ dxdt = dxdt_fun (
366+ time = time ,
367+ states = states ,
368+ controls = controls ,
369+ numerical_timeseries = numerical_timeseries ,
370+ )
371+
372+ defects = None
373+ if isinstance (nlp .dynamics_type .ode_solver , OdeSolver .COLLOCATION ):
374+ states_dot_list = []
375+ for key in model .name_dofs :
376+ states_dot_list .append (DynamicsFunctions .get (nlp .states_dot [key ], nlp .states_dot .scaled .cx ))
377+ defects = vertcat (* states_dot_list ) - dxdt
367378
368379 return DynamicsEvaluation (
369- dxdt = dxdt_fun (
370- cn = states [0 ],
371- f = states [1 ],
372- t = time ,
373- t_stim_prev = numerical_timeseries ,
374- force_length_relationship = force_length_relationship ,
375- force_velocity_relationship = force_velocity_relationship ,
376- passive_force_relationship = passive_force_relationship ,
377- ),
380+ dxdt = dxdt ,
381+ defects = defects ,
378382 )
379383
380- def declare_ding_variables (
381- self ,
382- ocp : OptimalControlProgram ,
383- nlp : NonLinearProgram ,
384- numerical_data_timeseries : dict [str , np .ndarray ] = None ,
385- contact_type : list = (),
386- ):
387- """
388- Tell the program which variables are states and controls.
389- The user is expected to use the ConfigureProblem.configure_xxx functions.
390- Parameters
391- ----------
392- ocp: OptimalControlProgram
393- A reference to the ocp
394- nlp: NonLinearProgram
395- A reference to the phase
396- numerical_data_timeseries: dict[str, np.ndarray]
397- A list of values to pass to the dynamics at each node. Experimental external forces should be included here.
398- contact_type: list
399- A list of contact types. This is used to define the contact forces in the dynamics. Not used in this model.
400- """
401- StateConfigure ().configure_all_fes_model_states (ocp , nlp , fes_model = self )
402- ConfigureProblem .configure_dynamics_function (ocp , nlp , dyn_func = self .dynamics )
403-
404384 def _get_additional_previous_stim_time (self ):
405385 while len (self .previous_stim ["time" ]) < self .sum_stim_truncation :
406386 self .previous_stim ["time" ].insert (0 , - 10000000 )
0 commit comments