Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/run_tests_win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ jobs:
conda info
conda list

- name: Install extra dependencies
run: conda install pip pytest-cov black pytest pytest-cov codecov packaging -cconda-forge

- name: Install bioptim on Windows
run: |
pwd
cd external
./bioptim_install_windows.sh 4 ${{ env.PREFIX_WINDOWS }}
cd ..

- name: Install extra dependencies
run: conda install pytest-cov black pytest pytest-cov codecov packaging -cconda-forge

- name: Run tests with code coverage
run: pytest -v --color=yes --cov-report term-missing --cov=cocofest --cov-report=xml:coverage.xml tests/shard${{ matrix.shard }}

Expand Down
6 changes: 3 additions & 3 deletions cocofest/custom_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def minimize_overall_muscle_fatigue(controller: PenaltyController) -> MX:
-------
The sum of each force scaling factor
"""
muscle_name_list = controller.model.bio_model.muscle_names
muscle_name_list = controller.model.muscle_names
muscle_model = controller.model.muscles_dynamics_model
muscle_fatigue = vertcat(
*[
Expand All @@ -47,7 +47,7 @@ def minimize_overall_muscle_force_production(controller: PenaltyController) -> M
-------
The sum of each force
"""
muscle_name_list = controller.model.bio_model.muscle_names
muscle_name_list = controller.model.muscle_names
muscle_model = controller.model.muscles_dynamics_model
muscle_force = vertcat(
*[
Expand All @@ -72,7 +72,7 @@ def minimize_overall_stimulation_charge(controller: PenaltyController) -> MX:
The sum of each stimulation control
"""
if isinstance(controller.model, FesMskModel):
muscle_name_list = controller.model.bio_model.muscle_names
muscle_name_list = controller.model.muscle_names
if isinstance(controller.model.muscles_dynamics_model[0], DingModelPulseWidthFrequency):
stim_charge = vertcat(
*[
Expand Down
28 changes: 6 additions & 22 deletions cocofest/integration/ivp_fes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import numpy as np
from bioptim import (
ControlType,
DynamicsList,
InitialGuessList,
OdeSolver,
OptimalControlProgram,
ParameterList,
PhaseDynamics,
BoundsList,
InterpolationType,
Solution,
Expand All @@ -20,6 +18,7 @@
from cocofest.models.ding2007.ding2007_with_fatigue import DingModelPulseWidthFrequencyWithFatigue
from cocofest.models.hmed2018.hmed2018 import DingModelPulseIntensityFrequency
from cocofest.models.hmed2018.hmed2018_with_fatigue import DingModelPulseIntensityFrequencyWithFatigue
from cocofest.optimization.fes_ocp import OcpFes


class IvpFes:
Expand Down Expand Up @@ -90,9 +89,10 @@ def __init__(
numerical_data_time_series, stim_idx_at_node_list = self.model.get_numerical_data_time_series(
self.n_shooting, self.final_time
)

self.ode_solver = self.ivp_parameters["ode_solver"]
self._declare_dynamics(numerical_data_time_series)
self.dynamics_options = OcpFes.declare_dynamics_options(
numerical_time_series=numerical_data_time_series, ode_solver=self.ode_solver
)

(
self.x_init,
Expand Down Expand Up @@ -285,7 +285,7 @@ def _prepare_fake_ocp(self):

return OptimalControlProgram(
bio_model=[self.model],
dynamics=self.dynamics,
dynamics=self.dynamics_options,
n_shooting=self.n_shooting,
phase_time=self.final_time,
x_init=self.x_init,
Expand Down Expand Up @@ -318,20 +318,6 @@ def integrate(
duplicated_times=duplicated_times,
)

def _declare_dynamics(self, numerical_data_time_series=None):

self.dynamics = DynamicsList()
self.dynamics.add(
self.model.declare_ding_variables,
dynamic_function=self.model.dynamics,
expand_dynamics=True,
expand_continuity=False,
phase=0,
phase_dynamics=PhaseDynamics.SHARED_DURING_THE_PHASE,
numerical_data_timeseries=numerical_data_time_series,
ode_solver=self.ode_solver,
)

def build_initial_guess_from_ocp(self, ocp, stim_idx_at_node_list=None):
"""
Build a state, control, parameters and stochastic initial guesses for each phases from a given ocp
Expand All @@ -341,9 +327,7 @@ def build_initial_guess_from_ocp(self, ocp, stim_idx_at_node_list=None):
p = InitialGuessList()
s = InitialGuessList()

muscle_name = "_" + ocp.model.muscle_name if ocp.model.muscle_name else ""
for j in range(len(self.model.name_dof)):
key = ocp.model.name_dof[j] + muscle_name
for j, key in enumerate(self.model.name_dofs):
x.add(key=key, initial_guess=ocp.model.standard_rest_values()[j], phase=0)

if ocp.controls_keys:
Expand Down
158 changes: 69 additions & 89 deletions cocofest/models/ding2003/ding2003.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from typing import Callable
from typing import Callable, List
from math import gcd
from fractions import Fraction

import numpy as np
from casadi import MX, exp, vertcat

from bioptim import (
ConfigureProblem,
DynamicsEvaluation,
NonLinearProgram,
OptimalControlProgram,
StateDynamics,
DynamicsFunctions,
OdeSolver,
States,
)

from cocofest.models.state_configure import StateConfigure
from cocofest.models.fes_model import FesModel


class DingModelFrequency(FesModel):
class DingModelFrequency(FesModel, StateDynamics):
"""
This is a custom model of the Bioptim package. As CustomModel, some methods are mandatory and must be implemented.
to make it work with bioptim.
Expand All @@ -38,8 +40,9 @@ def __init__(
stim_time: list[float] = None,
previous_stim: dict = None,
sum_stim_truncation: int = 20,
**kwargs,
):
super().__init__()
super().__init__(name=model_name, **kwargs)
self._model_name = model_name
self._muscle_name = muscle_name
self.sum_stim_truncation = sum_stim_truncation
Expand Down Expand Up @@ -68,6 +71,30 @@ def __init__(
self.km_rest = KM_REST_DEFAULT
self.fmax = 315.5 # Maximum force (N) at 100 Hz

# ---- Muscle relationship ---- #
self.fes_model = None
self.force_length_relationship = 1
self.force_velocity_relationship = 1
self.passive_force_relationship = 0

# --- Configure variables --- #
@property
def state_configuration_functions(self) -> List[States | Callable]:
return [StateConfigure().configure_all_muscle_states]

@property
def control_configuration_functions(self) -> List[States | Callable]:
return []

@property
def algebraic_configuration_functions(self) -> List[States | Callable]:
return []

@property
def extra_configuration_functions(self) -> List[States | Callable]:
return []

# --- Set model parameters --- #
def set_a_rest(self, model, a_rest: MX | float):
# models is required for bioptim compatibility
self.a_rest = a_rest
Expand Down Expand Up @@ -106,8 +133,8 @@ def serialize(self) -> tuple[Callable, dict]:

# ---- Needed for the example ---- #
@property
def name_dof(self, with_muscle_name: bool = False) -> list[str]:
muscle_name = "_" + self.muscle_name if self.muscle_name and with_muscle_name else ""
def name_dofs(self) -> list[str]:
muscle_name = "_" + self.muscle_name if self.muscle_name is not None else ""
return ["Cn" + muscle_name, "F" + muscle_name]

@property
Expand Down Expand Up @@ -155,48 +182,41 @@ def get_lambda_i(nb_stim: int, pulse_intensity: MX | float) -> list[MX | float]:
# ---- Model's dynamics ---- #
def system_dynamics(
self,
cn: MX,
f: MX,
t: MX = None,
t_stim_prev: list[MX] = None,
force_length_relationship: MX | float = 1,
force_velocity_relationship: MX | float = 1,
passive_force_relationship: MX | float = 0,
time: MX,
states: MX,
controls: MX,
numerical_timeseries: MX,
) -> MX:
"""
The system dynamics is the function that describes the models.

Parameters
----------
cn: MX
The value of the ca_troponin_complex (unitless)
f: MX
The value of the force (N)
t: MX
The current time at which the dynamics is evaluated (s)
t_stim_prev: list[MX]
The time list of the previous stimulations (s)
force_length_relationship: MX | float
The force length relationship value (unitless)
force_velocity_relationship: MX | float
The force velocity relationship value (unitless)
passive_force_relationship: MX | float
The passive force relationship value (unitless)
time: MX
The system's current node time
states: MX
The state of the system CN, F
controls: MX
The controls of the system, none
numerical_timeseries: MX
The numerical timeseries of the system

Returns
-------
The value of the derivative of each state dx/dt at the current time t
"""
t = time
cn = states[0]
f = states[1]
t_stim_prev = numerical_timeseries

cn_dot = self.calculate_cn_dot(cn, t, t_stim_prev)
f_dot = self.f_dot_fun(
cn,
f,
self.a_rest,
self.tau1_rest,
self.km_rest,
force_length_relationship=force_length_relationship,
force_velocity_relationship=force_velocity_relationship,
passive_force_relationship=passive_force_relationship,
) # Equation n°2
return vertcat(cn_dot, f_dot)

Expand Down Expand Up @@ -285,9 +305,6 @@ def f_dot_fun(
a: MX | float,
tau1: MX | float,
km: MX | float,
force_length_relationship: MX | float = 1,
force_velocity_relationship: MX | float = 1,
passive_force_relationship: MX | float = 0,
) -> MX | float:
"""
Parameters
Expand All @@ -302,34 +319,23 @@ def f_dot_fun(
The previous step value of time_state_force_no_cross_bridge (s)
km: MX | float
The previous step value of cross_bridges (unitless)
force_length_relationship: MX | float
The force length relationship value (unitless)
force_velocity_relationship: MX | float
The force velocity relationship value (unitless)
passive_force_relationship: MX | float
The passive force relationship value (unitless)

Returns
-------
The value of the derivative force (N)
"""
return (a * (cn / (km + cn)) - (f / (tau1 + self.tau2 * (cn / (km + cn))))) * (
force_length_relationship * force_velocity_relationship + passive_force_relationship
self.force_length_relationship * self.force_velocity_relationship + self.passive_force_relationship
) # Equation n°2

@staticmethod
def dynamics(
self,
time: MX,
states: MX,
controls: MX,
parameters: MX,
algebraic_states: MX,
numerical_timeseries: MX,
nlp: NonLinearProgram,
fes_model=None,
force_length_relationship: MX | float = 1,
force_velocity_relationship: MX | float = 1,
passive_force_relationship: MX | float = 0,
) -> DynamicsEvaluation:
"""
Functional electrical stimulation dynamic
Expand All @@ -350,57 +356,31 @@ def dynamics(
The numerical timeseries of the system
nlp: NonLinearProgram
A reference to the phase
fes_model: DingModelFrequency
The current phase fes model
force_length_relationship: MX | float
The force length relationship value (unitless)
force_velocity_relationship: MX | float
The force velocity relationship value (unitless)
passive_force_relationship: MX | float
The passive force relationship value (unitless)
Returns
-------
The derivative of the states in the tuple[MX] format
"""
model = fes_model if fes_model else nlp.model
model = self.fes_model if self.fes_model else nlp.model
dxdt_fun = model.system_dynamics
dxdt = dxdt_fun(
time=time,
states=states,
controls=controls,
numerical_timeseries=numerical_timeseries,
)

defects = None
if isinstance(nlp.dynamics_type.ode_solver, OdeSolver.COLLOCATION):
states_dot_list = []
for key in model.name_dofs:
states_dot_list.append(DynamicsFunctions.get(nlp.states_dot[key], nlp.states_dot.scaled.cx))
defects = vertcat(*states_dot_list) - dxdt

return DynamicsEvaluation(
dxdt=dxdt_fun(
cn=states[0],
f=states[1],
t=time,
t_stim_prev=numerical_timeseries,
force_length_relationship=force_length_relationship,
force_velocity_relationship=force_velocity_relationship,
passive_force_relationship=passive_force_relationship,
),
dxdt=dxdt,
defects=defects,
)

def declare_ding_variables(
self,
ocp: OptimalControlProgram,
nlp: NonLinearProgram,
numerical_data_timeseries: dict[str, np.ndarray] = None,
contact_type: list = (),
):
"""
Tell the program which variables are states and controls.
The user is expected to use the ConfigureProblem.configure_xxx functions.
Parameters
----------
ocp: OptimalControlProgram
A reference to the ocp
nlp: NonLinearProgram
A reference to the phase
numerical_data_timeseries: dict[str, np.ndarray]
A list of values to pass to the dynamics at each node. Experimental external forces should be included here.
contact_type: list
A list of contact types. This is used to define the contact forces in the dynamics. Not used in this model.
"""
StateConfigure().configure_all_fes_model_states(ocp, nlp, fes_model=self)
ConfigureProblem.configure_dynamics_function(ocp, nlp, dyn_func=self.dynamics)

def _get_additional_previous_stim_time(self):
while len(self.previous_stim["time"]) < self.sum_stim_truncation:
self.previous_stim["time"].insert(0, -10000000)
Expand Down
Loading
Loading