Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions brainpy/integrators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def state_delays(self, value):
raise ValueError('Cannot set "state_delays" by users.')

def _call_integral(self, *args, **kwargs):
kwargs = dict(kwargs)
t = kwargs.get('t', None)
kwargs['t'] = 0. if t is None else t

if _during_compile:
jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs)
outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs))
Expand Down
6 changes: 3 additions & 3 deletions brainpy/integrators/ode/explicit_rk.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ def __init__(self,

def build(self):
# step stage
common.step(self.variables, C.DT,
self.A, self.C, self.code_lines, self.parameters)
common.step(self.variables, C.DT, self.A, self.C, self.code_lines, self.parameters)
# variable update
return_args = common.update(self.variables, C.DT, self.B, self.code_lines)
# returns
Expand All @@ -189,7 +188,8 @@ def build(self):
code_scope={k: v for k, v in self.code_scope.items()},
code_lines=self.code_lines,
show_code=self.show_code,
func_name=self.func_name)
func_name=self.func_name
)


class Euler(ExplicitRKIntegrator):
Expand Down
Loading