|
4 | 4 | import jax.numpy as jnp |
5 | 5 |
|
6 | 6 | from amici import amiciModulePath |
7 | | - |
8 | | -from ..exporters.template import apply_template |
| 7 | +from amici.exporters.template import apply_template |
9 | 8 |
|
10 | 9 |
|
11 | 10 | class Flatten(eqx.Module): |
@@ -185,7 +184,7 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F |
185 | 184 | layer_map = { |
186 | 185 | "Dropout1d": "eqx.nn.Dropout", |
187 | 186 | "Dropout2d": "eqx.nn.Dropout", |
188 | | - "Flatten": "amici.jax.Flatten", |
| 187 | + "Flatten": "amici.export.jax.Flatten", |
189 | 188 | } |
190 | 189 |
|
191 | 190 | # mapping of keyword argument names in sciml yaml format to equinox/custom amici implementations |
@@ -321,9 +320,9 @@ def _process_activation_call(node: "Node") -> str: # noqa: F821 |
321 | 320 | "hardtanh": "jax.nn.hard_tanh", |
322 | 321 | "hardsigmoid": "jax.nn.hard_sigmoid", |
323 | 322 | "hardswish": "jax.nn.hard_swish", |
324 | | - "tanhshrink": "amici.jax.tanhshrink", |
| 323 | + "tanhshrink": "amici.export.jax.tanhshrink", |
325 | 324 | "softsign": "jax.nn.soft_sign", |
326 | | - "cat": "amici.jax.cat", |
| 325 | + "cat": "amici.export.jax.cat", |
327 | 326 | } |
328 | 327 |
|
329 | 328 | # Validate hardtanh parameters |
|
0 commit comments