|
1 | | -import os |
2 | | -from typing import Union, Optional, Tuple, List, Dict, Any |
| 1 | +from typing import Any, Dict, List, Optional, Tuple, Union |
3 | 2 |
|
4 | 3 | import flax.linen as nn |
5 | 4 | import jax |
6 | | - |
7 | | -jax.config.update("jax_enable_x64", True) |
8 | 5 | import jax.numpy as jnp |
9 | 6 | import numpy as np |
10 | 7 | import pandas as pd |
|
14 | 11 |
|
15 | 12 | from csde.optimization import _zstat_generic2, optimize_ppi, optimize_ppi_gd |
16 | 13 |
|
| 14 | +jax.config.update("jax_enable_x64", True) |
| 15 | + |
17 | 16 |
|
18 | 17 | class PPIAbstractClass: |
19 | 18 | """ |
@@ -413,14 +412,14 @@ def likelihood(model_params, x, y): |
413 | 412 |
|
414 | 413 | all_grads = np.zeros((n_obs, self.n_params)) |
415 | 414 | for i in tqdm(range(0, n_obs, batch_size), desc="Gradient computation"): |
416 | | - x_batch = x[i : i + batch_size] |
417 | | - y_batch = y[i : i + batch_size] |
| 415 | + x_batch = x[i:i+batch_size] |
| 416 | + y_batch = y[i:i+batch_size] |
418 | 417 | n_obs_batch = x_batch.shape[0] |
419 | 418 | score = self.jit(jax.jacfwd(likelihood)) |
420 | 419 | grads = score(self.model_params, x_batch, y_batch) |
421 | 420 | grad_mu = np.array(grads["params"]["mu"].reshape(n_obs_batch, -1)) |
422 | 421 | grad_mu0 = np.array(grads["params"]["mu0"].reshape(n_obs_batch, -1)) |
423 | | - all_grads[i : i + batch_size] = np.hstack([grad_mu, grad_mu0]) |
| 422 | + all_grads[i:i+batch_size] = np.hstack([grad_mu, grad_mu0]) |
424 | 423 | return np.array(all_grads) |
425 | 424 |
|
426 | 425 | def _construct_contrast(self, feature_id: int, idx_a: int) -> np.ndarray: |
|
0 commit comments