Skip to content

Commit 3e7eb8f

Browse files
committed
fix: ensure 't' keyword argument defaults to 0 in _call_integral and format code in build method
1 parent a41ef10 commit 3e7eb8f

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

brainpy/integrators/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,10 @@ def state_delays(self, value):
141141
raise ValueError('Cannot set "state_delays" by users.')
142142

143143
def _call_integral(self, *args, **kwargs):
144+
kwargs = dict(kwargs)
145+
t = kwargs.get('t', None)
146+
kwargs['t'] = 0. if t is None else t
147+
144148
if _during_compile:
145149
jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs)
146150
outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs))

brainpy/integrators/ode/explicit_rk.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ def __init__(self,
178178

179179
def build(self):
180180
# step stage
181-
common.step(self.variables, C.DT,
182-
self.A, self.C, self.code_lines, self.parameters)
181+
common.step(self.variables, C.DT, self.A, self.C, self.code_lines, self.parameters)
183182
# variable update
184183
return_args = common.update(self.variables, C.DT, self.B, self.code_lines)
185184
# returns
@@ -189,7 +188,8 @@ def build(self):
189188
code_scope={k: v for k, v in self.code_scope.items()},
190189
code_lines=self.code_lines,
191190
show_code=self.show_code,
192-
func_name=self.func_name)
191+
func_name=self.func_name
192+
)
193193

194194

195195
class Euler(ExplicitRKIntegrator):

0 commit comments

Comments
 (0)