|
4 | 4 | from scipy.integrate import odeint |
5 | 5 | from scipy.optimize import curve_fit |
6 | 6 | import matplotlib.pyplot as plt |
| 7 | +import plotly.graph_objs as go |
7 | 8 |
|
8 | 9 |
|
9 | 10 | class Model: |
@@ -186,6 +187,63 @@ def dlmos(self): |
186 | 187 | return self.cbt() - self.cbt_to_dlmo |
187 | 188 |
|
188 | 189 |
|
| 190 | + def plot(self, states=False, dlmo=False, cbtmin=False): |
| 191 | + # Create a new plotly figure |
| 192 | + fig = go.Figure() |
| 193 | + |
| 194 | + if states: |
| 195 | + # Calculate number of states available |
| 196 | + states = self.model_states.shape[1] |
| 197 | + |
| 198 | + if self.model_name == 'Forger' or self.model_name == 'Jewett': |
| 199 | + labels = ['x','xc', 'Light Drive'] |
| 200 | + elif self.model_name == 'HannaySP': |
| 201 | + labels = ['Amplitude','Phase', 'Light Drive'] |
| 202 | + else: |
| 203 | + labels = ['Ventral Amplitude', 'Dorsal Amplitude', 'Ventral Phase', 'Dorsal Phase', 'Light Drive'] |
| 204 | + # Iterate over states and plot them |
| 205 | + for i in range(states): |
| 206 | + fig.add_trace(go.Scatter( |
| 207 | + x=self.data.index.astype(str), |
| 208 | + y=self.model_states[:, i], |
| 209 | + name=f'{labels[i]}', |
| 210 | + )) |
| 211 | + fig.update_layout( |
| 212 | + title='Model States', |
| 213 | + xaxis=dict(title='Time'), |
| 214 | + yaxis=dict(title='Model States'), |
| 215 | + ) |
| 216 | + return fig |
| 217 | + |
| 218 | + if dlmo: |
| 219 | + # Plot daily predicted DLMO |
| 220 | + fig.add_trace(go.Scatter( |
| 221 | + x=pd.Series(self.data.index.date.astype(str)).unique(), |
| 222 | + y=self.dlmos() % 24, |
| 223 | + name='Predicted DLMO', |
| 224 | + )) |
| 225 | + fig.update_layout( |
| 226 | + title='Predicted DLMO', |
| 227 | + xaxis=dict(title='Day'), |
| 228 | + yaxis=dict(title='DLMO time'), |
| 229 | + ) |
| 230 | + return fig |
| 231 | + |
| 232 | + if cbtmin: |
| 233 | + # Plot daily predicted DLMO |
| 234 | + fig.add_trace(go.Scatter( |
| 235 | + x=pd.Series(self.data.index.date.astype(str)).unique(), |
| 236 | + y=self.cbt() % 24, |
| 237 | + name='Predicted CBTmin', |
| 238 | + )) |
| 239 | + fig.update_layout( |
| 240 | + title='Predicted CBTmin', |
| 241 | + xaxis=dict(title='Day'), |
| 242 | + yaxis=dict(title='CBTmin time'), |
| 243 | + ) |
| 244 | + return fig |
| 245 | + |
| 246 | + |
189 | 247 | class Forger(Model): |
190 | 248 | """ |
191 | 249 | Implements the mathematical model of human circadian rhythms developed by Forger, Jewett and Kronauer [1]. |
@@ -280,6 +338,7 @@ def __init__( |
280 | 338 | # self.initial_conditions = np.array([-0.0843259, -1.09607546, 0.45584306]) |
281 | 339 | # self.inputs = inputs |
282 | 340 | # self.time = time |
| 341 | + self.model_name = self.__class__.__name__ |
283 | 342 | self.taux = taux |
284 | 343 | self.mu = mu |
285 | 344 | self.g = g |
@@ -484,6 +543,7 @@ def __init__( |
484 | 543 | # self.initial_conditions= np.array([-0.10097101, -1.21985662, 0.50529415]) |
485 | 544 | # self.inputs = inputs |
486 | 545 | # self.time = time |
| 546 | + self.model_name = self.__class__.__name__ |
487 | 547 | self.taux = taux |
488 | 548 | self.mu = mu |
489 | 549 | self.g = g |
@@ -697,6 +757,7 @@ def __init__( |
697 | 757 | # self.initial_conditions = np.array([0.82041911, 1.71383697, 0.52318122]) |
698 | 758 | # self.inputs = inputs |
699 | 759 | # self.time = time |
| 760 | + self.model_name = self.__class__.__name__ |
700 | 761 | self.tau = tau |
701 | 762 | self.k = k |
702 | 763 | self.gamma = gamma |
@@ -947,6 +1008,7 @@ def __init__( |
947 | 1008 | # self.initial_conditions = np.array([0.82423745, 0.82304996, 1.75233424, 1.863457, 0.52318122]) |
948 | 1009 | # self.inputs = inputs |
949 | 1010 | # self.time = time |
| 1011 | + self.model_name = self.__class__.__name__ |
950 | 1012 | self.tauv = tauv |
951 | 1013 | self.taud = taud |
952 | 1014 | self.kvv = kvv |
@@ -1166,8 +1228,8 @@ def __init__( |
1166 | 1228 | # Extract time and light vector |
1167 | 1229 | if time is None or inputs is None: |
1168 | 1230 | if data is not None: |
1169 | | - self.time = np.asarray((data.index - data.index.min()).total_seconds() / 3600) |
1170 | | - self.inputs = np.asarray(data.values) |
| 1231 | + self.time_vector = np.asarray((data.index - data.index.min()).total_seconds() / 3600) |
| 1232 | + self.light_vector = np.asarray(data.values) |
1171 | 1233 | else: |
1172 | 1234 | raise ValueError("Must provide either light time series (data) or input and time.") |
1173 | 1235 | else: |
|
0 commit comments