diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 0f0b1b4f..eda23a0e 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,146 +1,146 @@ -# -*- coding: utf-8 -*- -# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -__version__ = "2.7.7" -__version_info__ = tuple(map(int, __version__.split("."))) - -from brainpy import _errors as errors -# fundamental supporting modules -from brainpy import check, tools -# Part: Math Foundation # -# ----------------------- # -# math foundation -from brainpy import math -from brainpy import mixin -# Part: Toolbox # -# --------------- # -# modules of toolbox -from . 
import ( - connect, # synaptic connection - initialize, # weight initialization - optim, # gradient descent optimizers - losses, # loss functions - measure, # methods for data analysis - inputs, # methods for generating input currents - encoding, # encoding schema - checkpoints, # checkpoints - check, # error checking - algorithms, # online or offline training algorithms -) -from .math import BrainPyObject - -# convenient alias -conn = connect -init = initialize - -# numerical integrators -from brainpy import integrators -from brainpy.integrators import ode, sde, fde -from brainpy.integrators.base import (Integrator as Integrator) -from brainpy.integrators.joint_eq import (JointEq as JointEq) -from brainpy.integrators.runner import (IntegratorRunner as IntegratorRunner) -from brainpy.integrators.ode.generic import (odeint as odeint) -from brainpy.integrators.sde.generic import (sdeint as sdeint) -from brainpy.integrators.fde.generic import (fdeint as fdeint) - -# Part: Models # -# -------------- # - -# base classes -from brainpy.dynsys import ( - DynamicalSystem as DynamicalSystem, - DynSysGroup as DynSysGroup, # collectors - Sequential as Sequential, - Dynamic as Dynamic, # category - Projection as Projection, - receive_update_input, # decorators - receive_update_output, - not_receive_update_input, - not_receive_update_output, -) - -DynamicalSystemNS = DynamicalSystem -Network = DynSysGroup -# delays -from brainpy.delay import ( - VarDelay as VarDelay, -) - -# building blocks -from brainpy import ( - dnn, layers, # module for dnn layers - dyn, # module for modeling dynamics -) - -NeuGroup = NeuGroupNS = dyn.NeuDyn -dyn.DynamicalSystem = DynamicalSystem - -# common tools -from brainpy.context import (share as share) -from brainpy.helpers import ( - reset_level as reset_level, - reset_state as reset_state, - save_state as save_state, - load_state as load_state, - clear_input as clear_input -) - -# Part: Running # -# --------------- # -from brainpy.runners import 
(DSRunner as DSRunner) -from brainpy.transform import (LoopOverTime as LoopOverTime, ) -from brainpy import (running as running) - -# Part: Training # -# ---------------- # -from brainpy.train.base import (DSTrainer as DSTrainer, ) -from brainpy.train.back_propagation import (BPTT as BPTT, - BPFF as BPFF, ) -from brainpy.train.online import (OnlineTrainer as OnlineTrainer, - ForceTrainer as ForceTrainer, ) -from brainpy.train.offline import (OfflineTrainer as OfflineTrainer, - RidgeTrainer as RidgeTrainer, ) - -# Part: Analysis # -# ---------------- # -from brainpy import (analysis as analysis) - -# Part: Others # -# ---------------- # -import brainpy.visualization as visualize - -# Part: Deprecations # -# -------------------- # -from brainpy import train -from brainpy import ( - channels, # channel models - neurons, # neuron groups - synapses, # synapses - rates, # rate models - synouts, # synaptic output - synplast, # synaptic plasticity -) -from brainpy.math.object_transform.base import Base as Base - -from brainpy.math.object_transform.collectors import ( - ArrayCollector as ArrayCollector, - Collector as Collector, -) - -optimizers = optim - -# New package -from brainpy import state +# -*- coding: utf-8 -*- +# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +__version__ = "2.7.8" +__version_info__ = tuple(map(int, __version__.split("."))) + +from brainpy import _errors as errors +# fundamental supporting modules +from brainpy import check, tools +# Part: Math Foundation # +# ----------------------- # +# math foundation +from brainpy import math +from brainpy import mixin +# Part: Toolbox # +# --------------- # +# modules of toolbox +from . import ( + connect, # synaptic connection + initialize, # weight initialization + optim, # gradient descent optimizers + losses, # loss functions + measure, # methods for data analysis + inputs, # methods for generating input currents + encoding, # encoding schema + checkpoints, # checkpoints + check, # error checking + algorithms, # online or offline training algorithms +) +from .math import BrainPyObject + +# convenient alias +conn = connect +init = initialize + +# numerical integrators +from brainpy import integrators +from brainpy.integrators import ode, sde, fde +from brainpy.integrators.base import (Integrator as Integrator) +from brainpy.integrators.joint_eq import (JointEq as JointEq) +from brainpy.integrators.runner import (IntegratorRunner as IntegratorRunner) +from brainpy.integrators.ode.generic import (odeint as odeint) +from brainpy.integrators.sde.generic import (sdeint as sdeint) +from brainpy.integrators.fde.generic import (fdeint as fdeint) + +# Part: Models # +# -------------- # + +# base classes +from brainpy.dynsys import ( + DynamicalSystem as DynamicalSystem, + DynSysGroup as DynSysGroup, # collectors + Sequential as Sequential, + Dynamic as Dynamic, # category + Projection as Projection, + receive_update_input, # decorators + receive_update_output, + not_receive_update_input, + not_receive_update_output, +) + +DynamicalSystemNS = DynamicalSystem +Network = DynSysGroup +# delays +from brainpy.delay import ( + VarDelay as VarDelay, +) + +# building blocks +from brainpy import ( 
+ dnn, layers, # module for dnn layers + dyn, # module for modeling dynamics +) + +NeuGroup = NeuGroupNS = dyn.NeuDyn +dyn.DynamicalSystem = DynamicalSystem + +# common tools +from brainpy.context import (share as share) +from brainpy.helpers import ( + reset_level as reset_level, + reset_state as reset_state, + save_state as save_state, + load_state as load_state, + clear_input as clear_input +) + +# Part: Running # +# --------------- # +from brainpy.runners import (DSRunner as DSRunner) +from brainpy.transform import (LoopOverTime as LoopOverTime, ) +from brainpy import (running as running) + +# Part: Training # +# ---------------- # +from brainpy.train.base import (DSTrainer as DSTrainer, ) +from brainpy.train.back_propagation import (BPTT as BPTT, + BPFF as BPFF, ) +from brainpy.train.online import (OnlineTrainer as OnlineTrainer, + ForceTrainer as ForceTrainer, ) +from brainpy.train.offline import (OfflineTrainer as OfflineTrainer, + RidgeTrainer as RidgeTrainer, ) + +# Part: Analysis # +# ---------------- # +from brainpy import (analysis as analysis) + +# Part: Others # +# ---------------- # +import brainpy.visualization as visualize + +# Part: Deprecations # +# -------------------- # +from brainpy import train +from brainpy import ( + channels, # channel models + neurons, # neuron groups + synapses, # synapses + rates, # rate models + synouts, # synaptic output + synplast, # synaptic plasticity +) +from brainpy.math.object_transform.base import Base as Base + +from brainpy.math.object_transform.collectors import ( + ArrayCollector as ArrayCollector, + Collector as Collector, +) + +optimizers = optim + +# New package +from brainpy import state diff --git a/brainpy/integrators/_jaxpr_to_source_code.py b/brainpy/integrators/_jaxpr_to_source_code.py deleted file mode 100644 index a00bbe9c..00000000 --- a/brainpy/integrators/_jaxpr_to_source_code.py +++ /dev/null @@ -1,1137 +0,0 @@ -# Modified from: https://github.com/dlwh/jax_sourceror -# -# Licensed under the 
Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import ast -import enum -import warnings -from collections.abc import MutableMapping, MutableSet -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Callable, Union - -import jax -import jax.numpy as jnp -import numpy as np -from jax._src.sharding_impls import UNSPECIFIED - -if jax.__version__ >= '0.5.0': - from jax.extend.core import Literal, Var, Jaxpr -else: - from jax.core import Primitive, Literal, Var, Jaxpr - -__all__ = [ - 'fn_to_python_code', - 'jaxpr_to_python_code', -] - - -class IdentitySet(MutableSet): - """Set that compares objects by identity. - - This is a set that compares objects by identity instead of equality. It is - useful for storing objects that are not hashable or that should be compared - by identity. - - This is a mutable set, but it does not support the ``__hash__`` method and - therefore cannot be used as a dictionary key or as an element of another set. 
- """ - - def __init__(self, iterable=None): - self._data = {} - if iterable is not None: - self.update(iterable) - - def __contains__(self, value): - return id(value) in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self): - return len(self._data) - - def add(self, value): - self._data[id(value)] = value - - def discard(self, value): - self._data.pop(id(value), None) - - def __repr__(self): - return f"IdentitySet({list(repr(x) for x in self._data.values())})" - - def __str__(self): - return f"IdentitySet({list(str(x) for x in self._data.values())})" - - -class IdentityMap(MutableMapping): - """Map that compares keys by identity. - - This is a map that compares keys by identity instead of equality. It is - useful for storing objects that are not hashable or that should be compared - by identity. - - This is a mutable mapping, but it does not support the ``__hash__`` method - and therefore cannot be used as a dictionary key or as an element of another - set. - """ - - def __init__(self, iterable=None): - self._data = {} - if iterable is not None: - self.update(iterable) - - def __contains__(self, key): - return id(key) in self._data - - def __getitem__(self, key): - return self._data[id(key)] - - def __setitem__(self, key, value): - self._data[id(key)] = value - - def __delitem__(self, key): - del self._data[id(key)] - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self): - return len(self._data) - - def __repr__(self): - return f"IdentityMap({list(repr(x) for x in self._data.values())})" - - def __str__(self): - return f"IdentityMap({list(str(x) for x in self._data.values())})" - - -@dataclass -class SourcerorState: - """State for the auto-minimizer. 
Basically just in charge of naming variables.""" - _var_names: IdentityMap[Var, str] = field(default_factory=IdentityMap) - _skolem_count: int = 0 - - def name(self, var, ctx=ast.Load()) -> ast.Name: - return ast.Name(id=self.str_name(var), ctx=ctx) - - def str_name(self, var: Var): - # Names things in a way vaguely compatible with - # JAX's naming scheme, which is 'a'-'z' followed - # by 'aa'-'az' etc. - if var in self._var_names: - return self._var_names[var] - else: - cur_count = len(self._var_names) - name = "" - while cur_count >= 26: - name += chr(ord('a') + cur_count % 26) - cur_count //= 26 - - name += chr(ord('a') + cur_count) - - name = name[::-1] - - self._var_names[var] = name - - return name - - def skolem(self, prefix: str): - self._skolem_count += 1 - return f"{prefix}_{self._skolem_count}" - - -prefix_imports = set() - - -@contextmanager -def catch_imports(): - try: - prefix_imports.clear() - yield - finally: - prefix_imports.clear() - - -def fn_to_python_code(fn, *args, **kwargs): - """ - Given a function which is defined by jax primitives and the function arguments, - return the Python code that would be generated by JAX for that function. 
- - :param fn: The function to generate code for - :param args: The positional arguments to the function - :param kwargs: The keyword arguments to the function - :return: The Python code that would be generated by JAX for that function - """ - closed_jaxpr = jax.make_jaxpr(fn)(*args, **kwargs) - jaxpr = constant_fold_jaxpr(closed_jaxpr.jaxpr) - state = SourcerorState() - try: - name = fn.__name__ - except AttributeError: - name = "unknown" - with catch_imports(): - node = jaxpr_to_py_ast(state, jaxpr, fn_name=name) - node = _maybe_wrap_fn_for_leaves(node, fn, len(args) + len(kwargs)) - ast.fix_missing_locations(node) - source = ast.unparse(node) - if len(prefix_imports): - source = "\n".join(prefix_imports) + "\n\n" + source - return source - - -def jaxpr_to_python_code(jaxpr: Jaxpr, - fn_name: str = "generated_function"): - """ - Given a JAX jaxpr, return the Python code that would be generated by JAX for that jaxpr. - - :param jaxpr: The jaxpr to generate code. - :param fn_name: The name of the function to generate code. - :return: The Python code that would be generated by JAX for that jaxpr - """ - jaxpr = constant_fold_jaxpr(jaxpr) - state = SourcerorState() - with catch_imports(): - node = jaxpr_to_py_ast(state, jaxpr, fn_name=fn_name) - ast.fix_missing_locations(node) - source = ast.unparse(node) - if len(prefix_imports): - source = "\n".join(prefix_imports) + "\n\n" + source - return source - - -def register_prim_handler(prim_name, handler): - """ - Register a handler for a primitive for automin - :param prim_name: - :param handler: - :return: - """ - if prim_name in prim_to_python: - warnings.warn(f"Overwriting handler for primitive {prim_name}") - prim_to_python[prim_name] = handler - - -def register_prim_as(prim_name): - """ - Decorator to register a handler for a primitive. 
- - :param prim_name: - :return: - """ - - def decorator(fn): - register_prim_handler(prim_name, fn) - return fn - - return decorator - - -def _assign_stmt(call_expr: Callable): - """ - Create a handler for a primitive that is a simple assignment. - :param call_expr: - :return: - """ - - def binop_fn(state, eqn): - invars = [_astify_atom(state, v) for v in eqn.invars] - outvars = _astify_outvars(state, eqn.outvars) - return ast.Assign( - outvars, - call_expr( - *invars, - **{k: _astify_value(v) for k, v in eqn.params.items()} - ) - ) - - return binop_fn - - -def _binop_fn(op: ast.operator): - return _assign_stmt(lambda x, y: ast.BinOp(left=x, op=op, right=y)) - - -def _cmpop_fn(op: ast.cmpop): - return _assign_stmt(lambda x, y: ast.Compare(left=x, ops=[op], comparators=[y])) - - -def normal_fn(fn_name): - """ - Create a handler for a normal function call. - :param fn_name: - :return: - """ - return _assign_stmt( - lambda *args, **kwargs: ast.Call( - func=ast.Name(id=fn_name, ctx=ast.Load()), - args=list(args), - keywords=[ast.keyword(arg=k, value=v) for k, v in kwargs.items()] - ) - ) - - -def _reduce_fn(fn_name: str): - def reduce_fn_inner(state: SourcerorState, eqn): - invars = [_astify_atom(state, v) for v in eqn.invars] - outvars = _astify_outvars(state, eqn.outvars) - if eqn.params: - params = eqn.params.copy() - params['axis'] = tuple(params['axes']) - del params['axes'] - call_op = ast.Call( - func=ast.Name(id=fn_name, ctx=ast.Load()), - args=invars, - keywords=[ast.keyword(arg=k, value=_astify_value(v)) for k, v in params.items()] - ) - else: - call_op = ast.Call( - func=ast.Name(id=fn_name, ctx=ast.Load()), - args=invars, - keywords=[] - ) - - return ast.Assign(outvars, call_op) - - return reduce_fn_inner - - -prim_to_python = dict() - -register_prim_handler('add', _binop_fn(ast.Add())) -register_prim_handler('sub', _binop_fn(ast.Sub())) -register_prim_handler('mul', _binop_fn(ast.Mult())) -register_prim_handler('div', _binop_fn(ast.Div())) 
-register_prim_handler('neg', normal_fn('jax.lax.neg')) -register_prim_handler('lt', _cmpop_fn(ast.Lt())) -register_prim_handler('gt', _cmpop_fn(ast.Gt())) -register_prim_handler('le', _cmpop_fn(ast.LtE())) -register_prim_handler('ge', _cmpop_fn(ast.GtE())) -register_prim_handler('eq', _cmpop_fn(ast.Eq())) -register_prim_handler('ne', _cmpop_fn(ast.NotEq())) -register_prim_handler('min', normal_fn('jax.lax.min')) -register_prim_handler('max', normal_fn('jax.lax.max')) -register_prim_handler('select_n', normal_fn('jax.lax.select_n')) -register_prim_handler('squeeze', normal_fn('jax.lax.squeeze')) -register_prim_handler('broadcast', normal_fn('jax.lax.broadcast')) -register_prim_handler('reduce_sum', _reduce_fn('jax.numpy.sum')) -register_prim_handler('transpose', normal_fn('jax.lax.transpose')) - - -def _maybe_wrap_fn_for_leaves(node, f, num_args): - if len(node.args.args) == num_args: - return node - - wrapped_node = ast.FunctionDef( - name=f.__name__, - args=ast.arguments( - args=[], - vararg=ast.arg(arg="args", annotation=None), - kwarg=ast.arg(arg="kwargs", annotation=None), - kwonlyargs=[], kw_defaults=[], defaults=[], - posonlyargs=[] - ), - body=[ - node, - ast.Return( - ast.Call( - func=ast.Name(id=node.name, ctx=ast.Load()), - args=[ - ast.Starred( - ast.Call( - func=ast.Attribute(value=ast.Name(id="jax", ctx=ast.Load()), - attr="tree_leaves", - ctx=ast.Load()), - args=[ast.Tuple(elts=[ast.Name(id="args", ctx=ast.Load()), - ast.Name(id="kwargs", ctx=ast.Load())], - ctx=ast.Load())], - keywords=[] - ) - ) - ], - keywords=[] - ) - ), - ], - decorator_list=[] - ) - - return wrapped_node - - -def jaxpr_to_py_ast(state: SourcerorState, - jaxpr: Jaxpr, - fn_name: str = "function"): - # Generate argument declarations - ast_args = [ast.arg(arg=state.str_name(var), annotation=None) - for var in jaxpr.invars] - ast_args = ast.arguments(args=ast_args, - vararg=None, - kwonlyargs=[], - kw_defaults=[], - kwarg=None, - defaults=[], - posonlyargs=[]) - - stmts = [] - - # 
Generate body of the function - for eqn in jaxpr.eqns: - prim = str(eqn.primitive) - if prim in prim_to_python: - eqn_stmts = prim_to_python[prim](state, eqn) - else: - eqn_stmts = normal_fn(prim)(state, eqn) - - if isinstance(eqn_stmts, list): - stmts.extend(eqn_stmts) - else: - stmts.append(eqn_stmts) - - # Generate return statement - if len(jaxpr.outvars) == 1: - returns = state.name(jaxpr.outvars[0]) - else: - returns = ast.Tuple(elts=[state.name(var) for var in jaxpr.outvars], ctx=ast.Load()) - stmts.append(ast.Return(value=returns)) - - return ast.FunctionDef(name=fn_name, args=ast_args, body=stmts, decorator_list=[]) - - -def constant_fold_jaxpr(jaxpr: Jaxpr): - """ - Given a jaxpr, return a new jaxpr with all constant folding done. - """ - return partial_eval_jaxpr(jaxpr, {}) - - -def partial_eval_jaxpr(jaxpr, env): - env = env.copy() - new_eqns = [] - - def read(var): - if isinstance(var, Literal): - return var.val - else: - return env.get(var, None) - - def read_or_self(var): - out = read(var) - if out is None: - return var - elif isinstance(out, Var): - return out - elif isinstance(out, Literal): - return Literal(out.val, var.aval) - else: - assert not isinstance(out, Jaxpr) - return Literal(out, var.aval) - - for eqn in jaxpr.eqns: - vals = [read(var) for var in eqn.invars] - if eqn.primitive.name in constant_fold_blacklist: - new_eqns.append(eqn) - elif all(val is not None for val in vals): - # go ahead and eval it - out = _eval_eqn(eqn, vals) - - # two options: either it's a jaxpr result (partial eval) or it's a value or a list of values - if isinstance(out, Jaxpr): - # we need to inline this - new_eqns.extend(out.eqns) - out = out.outvars - elif not isinstance(out, tuple) and not isinstance(out, list): - out = (out,) - - for var, val in zip(eqn.outvars, out): - assert not isinstance(val, Jaxpr) - if isinstance(val, Literal): - env[var] = val.val - else: - env[var] = val - else: - new_eqns.append(eqn) - - # now that we've evaled everything, inline all 
the constants - out_eqns = [] - for eqn in new_eqns: - eqn = eqn.replace(invars=tuple(read_or_self(var) for var in eqn.invars)) - out_eqns.append(eqn) - - invars_still_used = IdentitySet() - for eqn in out_eqns: - for var in eqn.invars: - invars_still_used.add(var) - - invars = tuple(var for var in jaxpr.invars if var in invars_still_used) - - # sub in any constants for outvars - outvars = tuple(read_or_self(var) for var in jaxpr.outvars) - - return jaxpr.replace(eqns=out_eqns, outvars=outvars, invars=invars) - - -def _eval_eqn(eqn, vals) -> Union[Jaxpr, tuple, list, jnp.ndarray]: - if eqn.primitive.name == "closed_call": - assert eqn.primitive.call_primitive == True - assert eqn.primitive.map_primitive == False - - out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, - {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) - elif eqn.primitive.name == "scan": - out = eqn.primitive.bind(*vals, **eqn.params) - else: - out = eqn.primitive.bind(*vals, **eqn.params) - return out - - -@register_prim_as('dot_general') -def _astify_dot_general(state, eqn): - x, y = eqn.invars - d = eqn.params['dimension_numbers'] - precision = eqn.params['precision'] - preferred_element_type = eqn.params['preferred_element_type'] - - has_dtype = preferred_element_type is None or x.aval.dtype == y.aval.dtype == preferred_element_type - - # recognize simple matmul case - if d == (((1,), (0,)), ((), ())) and precision == None: - invars = [_astify_atom(state, x), _astify_atom(state, y)] - outvars = _astify_outvars(state, eqn.outvars) - out = ast.Assign(targets=outvars, value=ast.Call( - func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), attr='matmul', ctx=ast.Load()), - args=invars, - keywords=[])) - if not has_dtype: - out = ast.Assign(targets=outvars, - value=ast.Call(func=ast.Attribute(value=out.value, attr='astype', ctx=ast.Load()), - args=[_astify_value(preferred_element_type)], keywords=[])) - - return out - - # TODO: convert to einsum? 
- - invars = [_astify_atom(state, x), - _astify_atom(state, y), - _astify_value(d), - _astify_value(precision), - _astify_value(preferred_element_type)] - outvars = _astify_outvars(state, eqn.outvars) - return ast.Assign( - targets=outvars, - value=ast.Call( - func=ast.Attribute(value=ast.Name(id='jax.lax', ctx=ast.Load()), attr='dot_general', ctx=ast.Load()), - args=invars, - keywords=[] - ) - ) - - -@register_prim_as('dynamic_slice') -def _sourcify_dynamic_slice(state, eqn): - sliced = eqn.invars[0] - invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load()) - outvars = _astify_outvars(state, eqn.outvars) - params = [ast.keyword(arg=k, value=_astify_value(v)) for k, v in eqn.params.items()] - return ast.Assign( - targets=outvars, - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id='jax.lax', ctx=ast.Load()), - attr='dynamic_slice', - ctx=ast.Load() - ), - args=[_astify_atom(state, sliced), invars], - keywords=params - ) - ) - - -@register_prim_as('slice') -def _sourcify_slice(state, eqn): - sliced = eqn.invars[0] - # invars = ast.Tuple(elts=[_astify_atom(state, var) for var in eqn.invars[1:]], ctx=ast.Load()) - outvars = _astify_outvars(state, eqn.outvars) - start_indices = eqn.params['start_indices'] - limit_indices = eqn.params['limit_indices'] - strides = eqn.params['strides'] - if strides is None: - strides = (None,) * len(start_indices) - indices = [_astify_value(slice(s, e, stride)) - for s, e, stride in zip(start_indices, limit_indices, strides)] - # params = [ast.keyword(arg=k, value=_astify_value(v)) for k, v in eqn.params.items()] - return ast.Assign( - targets=outvars, - value=ast.Subscript( - value=_astify_atom(state, sliced), - slice=ast.Tuple(elts=indices, ctx=ast.Load()), - ctx=ast.Load() - ) - ) - - -@register_prim_as('dynamic_update_slice') -def _sourcify_dynamic_update_slice(state, eqn): - sliced = eqn.invars[0] - # the first two arguments are the sliced array and the update array - # the remaining are 
start indices and should be packaged into a tuple - target = _astify_atom(state, eqn.invars[0]) - update = _astify_atom(state, eqn.invars[1]) - start_indices = maybe_tuple_vars([_astify_atom(state, var) for var in eqn.invars[2:]]) - outvars = _astify_outvars(state, eqn.outvars) - - return ast.Assign(targets=outvars, value=ast.Call( - func=ast.Attribute( - value=ast.Name(id='jax.lax', ctx=ast.Load()), - attr='dynamic_update_slice', - ctx=ast.Load() - ), - args=[target, update, start_indices], - keywords=[] - )) - - -@register_prim_as('convert_element_type') -def _astify_convert_element_type(state, eqn): - # now we use ast - outvars = _astify_outvars(state, eqn.outvars) - assert len(eqn.invars) == 1 - invar = _astify_atom(state, eqn.invars[0]) - dtype = _astify_value(eqn.params['new_dtype']) - return ast.Assign(targets=outvars, value=ast.Call( - func=ast.Attribute( - value=invar, - attr='astype', - ctx=ast.Load() - ), - args=[dtype], - keywords=[] - )) - - -def is_array(arr): - return isinstance(arr, (np.ndarray, np.generic, jnp.ndarray)) - - -def _astify_array(value): - assert is_array(value) - if isinstance(value, np.int64): - return ast.Constant(value=int(value)) - - if value.ndim == 0 and value.dtype in (jnp.float32, jnp.int32, jnp.bool_, jnp.int64): - return ast.Constant(value=value.item()) - - if value.ndim == 0: - dtype_value = _astify_value(value.dtype) - return ast.Call( - dtype_value, - args=[ast.Constant(value=value.item())], - keywords=[], - ) - - values = value.tolist() - - def rec_astify_list(values): - if isinstance(values, list): - return ast.List(elts=[rec_astify_list(val) for val in values], ctx=ast.Load()) - else: - return ast.Constant(value=values) - - return ast.Call( - func=ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='array', - ctx=ast.Load() - ), - args=[rec_astify_list(values)], - keywords=[ast.keyword(arg='dtype', - value=_astify_value(value.dtype))] - ) - - -def _astify_atom(state: SourcerorState, var: 
Union[Literal, Var]): - if isinstance(var, Literal): - return _astify_value(var.val) - elif isinstance(var, Var): - return state.name(var) - else: - raise NotImplementedError() - - -def _astify_value(value): - assert not isinstance(value, (Literal, Var)) - - if is_array(value): - return _astify_array(value) - elif isinstance(value, (int, bool, float, str, type(None))): - return ast.Constant(value=value) - elif isinstance(value, (tuple, list)): - return ast.Tuple(elts=[_astify_value(v) for v in value], ctx=ast.Load()) - elif isinstance(value, jnp.dtype): - # return ast.Call(func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), attr='dtype', ctx=ast.Load()), args=[ast.Constant(value=str(value))], keywords=[]) - if value.name in ('float32', 'float64', 'int32', 'int64', 'bfloat16', 'float16'): - # return ast.Constant(value=getattr(jnp, value.name)) - return ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr=value.name, - ctx=ast.Load() - ) - elif value.name == 'bool': - return ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='bool_', - ctx=ast.Load() - ) - else: - return ast.Call( - func=ast.Attribute(value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='dtype', - ctx=ast.Load()), - args=[ast.Constant(value=str(value))], - keywords=[] - ) - elif value is UNSPECIFIED: - prefix_imports.add('from jax._src.sharding_impls import UNSPECIFIED') - return ast.Name(id='UNSPECIFIED', ctx=ast.Load()) - elif isinstance(value, enum.Enum): - return ast.Attribute( - value=ast.Name(id=value.__class__.__qualname__, ctx=ast.Load()), - attr=value.name, - ctx=ast.Load() - ) - - else: - warnings.warn(f"Unknown value type {type(value)}") - return ast.parse(repr(value)).body[0] - - -def _astify_outvars(state, outvars): - out = [state.name(v, ctx=ast.Store()) for v in outvars] - if len(out) == 1: - return out - else: - return [ast.Tuple(elts=out, ctx=ast.Store())] - - -def maybe_tuple_vars(vars): - if len(vars) == 1: - return vars[0] - 
else: - return ast.Tuple(elts=vars, ctx=ast.Load()) - - -def maybe_untuple_vars(var, is_tuple): - if is_tuple: - return ast.Starred(value=var, ctx=ast.Load()) - else: - return var - - -@register_prim_as('scan') -def _astify_scan(state, eqn): - assert eqn.primitive.name == 'scan' - - # the args to scan are [constants, carry, xs] - # constants aren't exposed in the Python API, so we need to handle them specially (we use a lambda) - num_consts = eqn.params['num_consts'] - num_carry = eqn.params['num_carry'] - - # TODO: bring back map - # if num_carry == 0: - # this is a map - # return _astify_map(eqn) - - constant_args = eqn.invars[:num_consts] - carries = eqn.invars[num_consts:num_consts + num_carry] - xs = eqn.invars[num_consts + num_carry:] - - jaxpr = eqn.params['jaxpr'] - - if num_consts != 0: - # we want to construct an environment where we partial eval the function using the constants as the env - env = dict(zip(jaxpr.jaxpr.invars, constant_args)) - jaxpr = partial_eval_jaxpr(jaxpr.jaxpr, env) - else: - jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) - - fn_name = state.skolem('fn') - fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name) - - length = _astify_value(eqn.params['length']) - unroll = _astify_value(eqn.params['unroll']) - reverse = _astify_value(eqn.params['reverse']) - - stmts = [] - - if num_carry != 1 or len(jaxpr.invars) != 2: - # what we want is something like: - # fn_name = lambda carry, xs: fn_name(*carry, *xs) - # jax.lax.scan(fn_name, (carries...), (xs...)) - - modified_signature = ast.arguments( - args=[ast.arg(arg='carry'), ast.arg(arg='x')], - vararg=None, - kwonlyargs=[], - kw_defaults=[], - kwarg=None, - defaults=[], - posonlyargs=[] - ) - - initial_assign = ast.Assign( - targets=[ast.Tuple(elts=[ast.Name(a.arg) for a in fn_ast.args.args], - ctx=ast.Store())], - value=ast.Tuple( - elts=[maybe_untuple_vars(ast.Name(id='carry', ctx=ast.Load()), num_carry != 1), - maybe_untuple_vars(ast.Name(id='x', ctx=ast.Load()), len(xs) != 1)] - ) - ) - - 
fn_return = fn_ast.body[-1] - assert isinstance(fn_return, ast.Return) - - fn_return_value = fn_return.value - - if isinstance(fn_return_value, ast.Tuple): - fn_return_value = fn_return_value.elts - ret_carries = maybe_tuple_vars(fn_return_value[:num_carry]) - ret_ys = maybe_tuple_vars(fn_return_value[num_carry:]) - elif num_carry == 0: - ret_carries = _astify_value(()) - ret_ys = fn_return_value - else: - ret_carries = fn_return_value - ret_ys = _astify_value(()) - - scan_return = ast.Return( - value=ast.Tuple(elts=[ret_carries, ret_ys], ctx=ast.Load()) - ) - - new_body = [initial_assign] + list(fn_ast.body[:-1]) + [scan_return] - - fn_ast = ast.FunctionDef( - name=fn_name, - args=modified_signature, - body=new_body, - decorator_list=[] - ) - - stmts.append(fn_ast) - - scan_call = ast.Assign( - # targets=_astify_outvars(eqn.outvars), - targets=[ - ast.Tuple( - elts=[ast.Name(id='final_carry', ctx=ast.Store()), - ast.Name(id='ys', ctx=ast.Store())], - ctx=ast.Store() - ) - ], - value=ast.Call( - func=ast.Name(id='jax.lax.scan', ctx=ast.Load()), - args=[ast.Name(id=fn_name, ctx=ast.Load()), - maybe_tuple_vars([_astify_atom(state, v) for v in carries]), - maybe_tuple_vars([_astify_atom(state, v) for v in xs])], - keywords=[ast.keyword(arg='length', value=length), - ast.keyword(arg='unroll', value=unroll), - ast.keyword(arg='reverse', value=reverse)] - ) - ) - stmts.append(scan_call) - - if num_carry > 0: - assign_carry = ast.Assign( - targets=_astify_outvars(state, eqn.outvars[:num_carry]), - value=ast.Name(id='final_carry', ctx=ast.Load()) - ) - - stmts.append(assign_carry) - - if num_carry < len(eqn.outvars): - assign_ys = ast.Assign( - targets=_astify_outvars(state, eqn.outvars[num_carry:]), - value=ast.Name(id='ys', ctx=ast.Load()) - ) - - stmts.append(assign_ys) - else: - stmts.append(fn_ast) - - scan_call = ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=ast.Call( - func=ast.Name(id='jax.lax.scan', ctx=ast.Load()), - 
args=[ast.Name(id=fn_name, ctx=ast.Load())] + [_astify_atom(state, v) for v in eqn.invars], - keywords=[ast.keyword(arg='length', value=length), - ast.keyword(arg='unroll', value=unroll), - ast.keyword(arg='reverse', value=reverse)] - ) - ) - - stmts.append(scan_call) - - return stmts - - -def _astify_map(state, eqn): - assert eqn.primitive.name == 'scan' - assert eqn.params['num_carry'] == 0 - - jaxpr = eqn.params['jaxpr'] - jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) - - fn_name = state.skolem('fn') - fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name) - - # map is a bit funny, because the jaxpr takes K args, but the jax.lax.map function takes a single tuple arg - # so we need to use a lambda to redirect the call - lam = ast.parse(f"lambda args: {fn_name}(*args)").body[0] - - assign = ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=ast.Call( - func=ast.Name(id='jax.lax.map', ctx=ast.Load()), - args=[lam, - ast.Tuple(elts=[_astify_atom(state, v) for v in eqn.invars], - ctx=ast.Load())], - keywords=[] - ) - ) - - return [fn_ast, assign] - - -@register_prim_as('closed_call') -def _astify_closed_call(state, eqn): - # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, - # {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) - raw_jaxpr = eqn.params['call_jaxpr'].jaxpr - literal_args = {k: v.val - for k, v in zip(raw_jaxpr.invars, eqn.invars) - if isinstance(v, Literal)} - call_japr = partial_eval_jaxpr(raw_jaxpr, literal_args) - fn_name = state.skolem('fn') - - fn_ast = jaxpr_to_py_ast(state, call_japr, fn_name) - - invars = [_astify_atom(state, v) - for v in eqn.invars - if not isinstance(v, Literal)] - outvars = _astify_outvars(state, eqn.outvars) - - assign = ast.Assign( - targets=outvars, - value=ast.Call( - func=ast.Name(id=fn_name, ctx=ast.Load()), - args=invars, - keywords=[] - ) - ) - - return [fn_ast, assign] - - -@register_prim_as('pjit') -def _astify_pjit(state, eqn): - # this one's a real pain. 
- # pjit's params are : - # jaxpr - # donated_invars: - # in_shardings, out_shardings - # resource env - # name (yay) - # keep_unused, inline (which we won't use) - - jaxpr = eqn.params['jaxpr'] - donated_invars = eqn.params['donated_invars'] - in_shardings = eqn.params['in_shardings'] - out_shardings = eqn.params['out_shardings'] - resource_env = eqn.params['resource_env'] - name = eqn.params['name'] - - can_ignore_donated = not any(donated_invars) - - # preprocess the function - jaxpr = constant_fold_jaxpr(jaxpr.jaxpr) - fn_name = state.skolem(name) - fn_ast = jaxpr_to_py_ast(state, jaxpr, fn_name) - - in_shardings = _astify_value(in_shardings) - out_shardings = _astify_value(out_shardings) - - keywords = [ - ast.keyword(arg='in_shardings', value=in_shardings), - ast.keyword(arg='out_shardings', value=out_shardings), - ] - - if not can_ignore_donated: - donated_invars = _astify_value(donated_invars) - keywords.append(ast.keyword(arg='donated_invars', value=donated_invars)) - - jitted_fn = ast.Call( - func=ast.Attribute( - ast.Name(id='jax', ctx=ast.Load()), - attr='jit' - ), - args=[ast.Name(id=fn_name, ctx=ast.Load())], - keywords=keywords - ) - - assign = ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=ast.Call( - func=jitted_fn, - args=[_astify_atom(state, v) for v in eqn.invars], - keywords=[] - ) - ) - - return [fn_ast, assign] - - -@register_prim_as('remat2') -def _astify_remat(state: SourcerorState, eqn): - # out = partial_eval_jaxpr(eqn.params['call_jaxpr'].jaxpr, - # {var: val for var, val in zip(eqn.params['call_jaxpr'].jaxpr.invars, vals)}) - call_japr = constant_fold_jaxpr(eqn.params['jaxpr']) - fn_name = state.skolem('fn') - - fn_ast = jaxpr_to_py_ast(state, call_japr, fn_name) - - invars = [_astify_atom(state, v) for v in eqn.invars] - outvars = _astify_outvars(state, eqn.outvars) - - lam = ast.Assign( - targets=[ast.Name(id=f"ckpt_{fn_name}", ctx=ast.Store())], - # value=ast.parse(f"jax.checkpoint({fn_name})").body[0] - 
value=ast.Call( - func=ast.Name(id='jax.checkpoint', ctx=ast.Load()), - args=[ast.Name(id=fn_name, ctx=ast.Load())], - keywords=[]) - ) - - assign = ast.Assign( - targets=outvars, - value=ast.Call( - func=ast.Name(id=f"ckpt_{fn_name}"), - args=invars, - keywords=[] - )) - - return [fn_ast, lam, assign] - - -@register_prim_as('reshape') -def _astify_reshape(state, eqn): - # the lax reshape is a bit different, because it can combine a transpose and reshape into one. - # np.reshape(np.transpose(operand, dimensions), new_sizes) - dimensions = eqn.params['dimensions'] - new_sizes = eqn.params['new_sizes'] - - source = _astify_atom(state, eqn.invars[0]) - - if dimensions is not None: - source = ast.Call( - func=ast.Name(id='jax.numpy.transpose', ctx=ast.Load()), - args=[source, _astify_value(dimensions)], - keywords=[] - ) - - assign = ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=ast.Call( - func=ast.Name(id='jax.numpy.reshape', ctx=ast.Load()), - args=[source, _astify_value(new_sizes)], - keywords=[] - )) - - return [assign] - - -@register_prim_as('add_any') -def _astify_add_any(state, eqn): - # add_any is a weird undocumented jax primitive. best guess is it adds? 
- return _binop_fn(ast.Add())(state, eqn) - - -@register_prim_as('broadcast_in_dim') -def _astify_broadcast_in_dim(state, eqn): - # broadcast_in_dim is how zeros, ones, full, etc are implemented, - # so we prefer to use those where possible - assert len(eqn.invars) == 1 - value = eqn.invars[0] - shape = eqn.params['shape'] - broadcast_dimensions = eqn.params['broadcast_dimensions'] - - if not isinstance(value, Literal) or broadcast_dimensions != (): - return normal_fn('jax.lax.broadcast_in_dim')(state, eqn) - - if not isinstance(value.val, np.ndarray) or value.val.ndim != 0: - return normal_fn('jax.lax.broadcast_in_dim')(state, eqn) - else: - constant_value = value.val.item() - if constant_value == 0: - call = ast.Call( - ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='zeros', - ctx=ast.Load() - ), - args=[_astify_value(shape), - _astify_value(value.val.dtype)], - keywords=[] - ) - elif constant_value == 1: - call = ast.Call( - ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='ones', - ctx=ast.Load() - ), - args=[_astify_value(shape), - _astify_value(value.val.dtype)], - keywords=[] - ) - else: - call = ast.Call( - ast.Attribute( - value=ast.Name(id='jax.numpy', ctx=ast.Load()), - attr='full', - ctx=ast.Load() - ), - args=[_astify_value(shape), - _astify_value(constant_value), - _astify_value(value.val.dtype)], - keywords=[] - ) - - return [ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=call - )] - - -@register_prim_as('random_wrap') -def _astify_random_wrap(state, eqn): - # we treat this as a noop - return ast.Assign( - targets=_astify_outvars(state, eqn.outvars), - value=_astify_atom(state, eqn.invars[0]) - ) - - -constant_fold_blacklist = { - 'broadcast_in_dim', - 'broadcast', -} diff --git a/brainpy/integrators/base.py b/brainpy/integrators/base.py index bdf444f7..6debbb1f 100644 --- a/brainpy/integrators/base.py +++ b/brainpy/integrators/base.py @@ -1,213 +1,213 @@ -# -*- coding: utf-8 -*- -# 
Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from contextlib import contextmanager -from typing import Dict, Sequence, Union, Callable - -import jax - -from brainpy._errors import DiffEqError -from brainpy.check import is_float, is_dict_data -from brainpy.math import TimeDelay, LengthDelay -from brainpy.math.object_transform.base import BrainPyObject -from ._jaxpr_to_source_code import jaxpr_to_python_code -from .constants import DT - -__all__ = [ - 'AbstractIntegrator', - 'Integrator', - 'compile_integrators', -] - - -class AbstractIntegrator(BrainPyObject): - """Basic Integrator Class.""" - - # func_name - # derivative - # code_scope - # - - def __call__(self, *args, **kwargs): - raise NotImplementedError - - -class Integrator(AbstractIntegrator): - """Basic Integrator Class.""" - - def __init__( - self, - variables: Sequence[str], - parameters: Sequence[str], - arguments: Sequence[str], - dt: float, - name: str = None, - state_delays: Dict[str, Union[TimeDelay, LengthDelay]] = None, - ): - super(Integrator, self).__init__(name=name) - - self._dt = dt - is_float(dt, 'dt', allow_none=False, allow_int=True) - self._variables = list(variables) # variables - self._parameters = list(parameters) # parameters - self._arguments = list(arguments) + [f'{DT}={self._dt}', ] # arguments - self._integral = None # integral function - 
self.arg_names = self._variables + self._parameters + [DT] - - # state delays - self._state_delays = dict() - if state_delays is not None: - is_dict_data(state_delays, key_type=str, val_type=(TimeDelay, LengthDelay)) - for key, delay in state_delays.items(): - if key not in self.variables: - raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}') - self._state_delays[key] = delay - self.register_implicit_nodes(self._state_delays) - - # math expression - self._math_expr = None - - @property - def dt(self): - """The numerical integration precision.""" - return self._dt - - @dt.setter - def dt(self, value): - raise ValueError('Cannot set "dt" by users.') - - @property - def variables(self): - """The variables defined in the differential equation.""" - return self._variables - - @variables.setter - def variables(self, values): - raise ValueError('Cannot set "variables" by users.') - - @property - def parameters(self): - """The parameters defined in the differential equation.""" - return self._parameters - - @parameters.setter - def parameters(self, values): - raise ValueError('Cannot set "parameters" by users.') - - @property - def arguments(self): - """All arguments when calling the numer integrator of the differential equation.""" - return self._arguments - - @arguments.setter - def arguments(self, values): - raise ValueError('Cannot set "arguments" by users.') - - @property - def integral(self): - """The integral function.""" - return self._integral - - @integral.setter - def integral(self, f): - self.set_integral(f) - - def set_integral(self, f): - """Set the integral function.""" - if not callable(f): - raise ValueError(f'integral function must be a callable function, ' - f'but we got {type(f)}: {f}') - self._integral = f - - @property - def state_delays(self): - """State delays.""" - return self._state_delays - - @state_delays.setter - def state_delays(self, value): - raise ValueError('Cannot set "state_delays" by users.') - - def 
_call_integral(self, *args, **kwargs): - kwargs = dict(kwargs) - t = kwargs.get('t', None) - kwargs['t'] = 0. if t is None else t - - if _during_compile: - jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs) - outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs)) - _, tree = jax.tree.flatten(out_shapes) - new_vars = tree.unflatten(outs) - self._math_expr = jaxpr_to_python_code(jaxpr.jaxpr) - - else: - new_vars = self.integral(**kwargs) - return new_vars - - def __call__(self, *args, **kwargs): - assert self.integral is not None, 'Please build the integrator first.' - - # check arguments - for i, arg in enumerate(args): - kwargs[self.arg_names[i]] = arg - - # integral - new_vars = self._call_integral(**kwargs) - - # post-process - if len(self.variables) == 1: - dict_vars = {self.variables[0]: new_vars} - else: - dict_vars = {k: new_vars[i] for i, k in enumerate(self.variables)} - - # update state delay variables - dt = kwargs.pop(DT, self.dt) - for key, delay in self.state_delays.items(): - if isinstance(delay, LengthDelay): - delay.update(dict_vars[key]) - elif isinstance(delay, TimeDelay): - delay.update(dict_vars[key]) - else: - raise ValueError('Unknown delay variable. We only supports ' - 'brainpy.math.LengthDelay, brainpy.math.TimeDelay. ' - f'While we got {delay}') - - return new_vars - - def to_math_expr(self): - if self._math_expr is None: - raise ValueError('Please call ``brainpy.integrators.compile_integrators`` first.') - return self._math_expr - - -_during_compile = False - - -@contextmanager -def _during_compile_context(): - global _during_compile - try: - _during_compile = True - yield - finally: - _during_compile = False - - -def compile_integrators(f: Callable, *args, **kwargs): - """ - Compile integrators in the given function. - """ - with _during_compile_context(): - return f(*args, **kwargs) +# -*- coding: utf-8 -*- +# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from contextlib import contextmanager +from typing import Dict, Sequence, Union, Callable + +import jax +from brainstate.transform import jaxpr_to_python_code + +from brainpy._errors import DiffEqError +from brainpy.check import is_float, is_dict_data +from brainpy.math import TimeDelay, LengthDelay +from brainpy.math.object_transform.base import BrainPyObject +from .constants import DT + +__all__ = [ + 'AbstractIntegrator', + 'Integrator', + 'compile_integrators', +] + + +class AbstractIntegrator(BrainPyObject): + """Basic Integrator Class.""" + + # func_name + # derivative + # code_scope + # + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class Integrator(AbstractIntegrator): + """Basic Integrator Class.""" + + def __init__( + self, + variables: Sequence[str], + parameters: Sequence[str], + arguments: Sequence[str], + dt: float, + name: str = None, + state_delays: Dict[str, Union[TimeDelay, LengthDelay]] = None, + ): + super(Integrator, self).__init__(name=name) + + self._dt = dt + is_float(dt, 'dt', allow_none=False, allow_int=True) + self._variables = list(variables) # variables + self._parameters = list(parameters) # parameters + self._arguments = list(arguments) + [f'{DT}={self._dt}', ] # arguments + self._integral = None # integral function + self.arg_names = self._variables + self._parameters + [DT] + + # 
state delays + self._state_delays = dict() + if state_delays is not None: + is_dict_data(state_delays, key_type=str, val_type=(TimeDelay, LengthDelay)) + for key, delay in state_delays.items(): + if key not in self.variables: + raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}') + self._state_delays[key] = delay + self.register_implicit_nodes(self._state_delays) + + # math expression + self._math_expr = None + + @property + def dt(self): + """The numerical integration precision.""" + return self._dt + + @dt.setter + def dt(self, value): + raise ValueError('Cannot set "dt" by users.') + + @property + def variables(self): + """The variables defined in the differential equation.""" + return self._variables + + @variables.setter + def variables(self, values): + raise ValueError('Cannot set "variables" by users.') + + @property + def parameters(self): + """The parameters defined in the differential equation.""" + return self._parameters + + @parameters.setter + def parameters(self, values): + raise ValueError('Cannot set "parameters" by users.') + + @property + def arguments(self): + """All arguments when calling the numer integrator of the differential equation.""" + return self._arguments + + @arguments.setter + def arguments(self, values): + raise ValueError('Cannot set "arguments" by users.') + + @property + def integral(self): + """The integral function.""" + return self._integral + + @integral.setter + def integral(self, f): + self.set_integral(f) + + def set_integral(self, f): + """Set the integral function.""" + if not callable(f): + raise ValueError(f'integral function must be a callable function, ' + f'but we got {type(f)}: {f}') + self._integral = f + + @property + def state_delays(self): + """State delays.""" + return self._state_delays + + @state_delays.setter + def state_delays(self, value): + raise ValueError('Cannot set "state_delays" by users.') + + def _call_integral(self, *args, **kwargs): + kwargs = dict(kwargs) + t = 
kwargs.get('t', None) + kwargs['t'] = 0. if t is None else t + + if _during_compile: + jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs) + outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs)) + _, tree = jax.tree.flatten(out_shapes) + new_vars = tree.unflatten(outs) + self._math_expr = jaxpr_to_python_code(jaxpr.jaxpr) + + else: + new_vars = self.integral(**kwargs) + return new_vars + + def __call__(self, *args, **kwargs): + assert self.integral is not None, 'Please build the integrator first.' + + # check arguments + for i, arg in enumerate(args): + kwargs[self.arg_names[i]] = arg + + # integral + new_vars = self._call_integral(**kwargs) + + # post-process + if len(self.variables) == 1: + dict_vars = {self.variables[0]: new_vars} + else: + dict_vars = {k: new_vars[i] for i, k in enumerate(self.variables)} + + # update state delay variables + dt = kwargs.pop(DT, self.dt) + for key, delay in self.state_delays.items(): + if isinstance(delay, LengthDelay): + delay.update(dict_vars[key]) + elif isinstance(delay, TimeDelay): + delay.update(dict_vars[key]) + else: + raise ValueError('Unknown delay variable. We only supports ' + 'brainpy.math.LengthDelay, brainpy.math.TimeDelay. ' + f'While we got {delay}') + + return new_vars + + def to_math_expr(self): + if self._math_expr is None: + raise ValueError('Please call ``brainpy.integrators.compile_integrators`` first.') + return self._math_expr + + +_during_compile = False + + +@contextmanager +def _during_compile_context(): + global _during_compile + try: + _during_compile = True + yield + finally: + _during_compile = False + + +def compile_integrators(f: Callable, *args, **kwargs): + """ + Compile integrators in the given function. 
+ """ + with _during_compile_context(): + return f(*args, **kwargs) diff --git a/brainpy/integrators/tests/test_to_math_expr.py b/brainpy/integrators/tests/test_to_math_expr.py index 5bd1203d..d69bf20d 100644 --- a/brainpy/integrators/tests/test_to_math_expr.py +++ b/brainpy/integrators/tests/test_to_math_expr.py @@ -1,48 +1,55 @@ -# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -import brainpy as bp - - -class EINet3(bp.DynSysGroup): - def __init__(self): - super().__init__() - self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=4000, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.N) - self.I = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=4000, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(spk[:3200]) - self.I(spk[3200:]) - self.delay(self.N(input)) - return self.N.spike.value - - -def test1(): - model = EINet3() - - bp.integrators.compile_integrators(model.step_run, 0, 0.) 
- for intg in model.nodes().subset(bp.Integrator).values(): - print(intg.to_math_expr()) +# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import brainstate +import pytest +from packaging import version + +import brainpy as bp + + +class EINet3(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=4000, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=4000, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:3200]) + self.I(spk[3200:]) + self.delay(self.N(input)) + return self.N.spike.value + + +@pytest.mark.skipif( + tuple(map(int, brainstate.__version__.split('.'))) < (0, 3, 1), + reason='The `to_math_expr` method is only available in BrainState v0.3.1 and above.' +) +def test1(): + model = EINet3() + + bp.integrators.compile_integrators(model.step_run, 0, 0.) 
+ for intg in model.nodes().subset(bp.Integrator).values(): + print(intg.to_math_expr()) diff --git a/brainpy/math/compat_numpy.py b/brainpy/math/compat_numpy.py index 250a447d..113a35d3 100644 --- a/brainpy/math/compat_numpy.py +++ b/brainpy/math/compat_numpy.py @@ -1,806 +1,803 @@ -# -*- coding: utf-8 -*- -# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import jax -import jax.numpy as jnp -import numpy as np -from jax.tree_util import tree_flatten, tree_unflatten, tree_map - -from ._utils import _compatible_with_brainpy_array, _as_jax_array_ -from .interoperability import * -from .ndarray import Array - -__all__ = [ - 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', - 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', - - # math funcs - 'real', 'imag', 'conj', 'conjugate', 'ndim', 'isreal', 'isscalar', - 'add', 'reciprocal', 'negative', 'positive', 'multiply', 'divide', - 'power', 'subtract', 'true_divide', 'floor_divide', 'float_power', - 'fmod', 'mod', 'modf', 'divmod', 'remainder', 'abs', 'exp', 'exp2', - 'expm1', 'log', 'log10', 'log1p', 'log2', 'logaddexp', 'logaddexp2', - 'lcm', 'gcd', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', - 'arctan2', 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', - 
'tanh', 'deg2rad', 'hypot', 'rad2deg', 'degrees', 'radians', 'round', - 'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'prod', - 'sum', 'diff', 'median', 'nancumprod', 'nancumsum', 'nanprod', 'nansum', - 'cumprod', 'cumsum', 'ediff1d', 'cross', 'isfinite', 'isinf', - 'isnan', 'signbit', 'copysign', 'nextafter', 'ldexp', 'frexp', 'convolve', - 'sqrt', 'cbrt', 'square', 'absolute', 'fabs', 'sign', 'heaviside', - 'maximum', 'minimum', 'fmax', 'fmin', 'interp', 'clip', 'angle', - - # Elementwise bit operations - 'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor', - 'invert', 'left_shift', 'right_shift', - - # logic funcs - 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', - 'array_equal', 'isclose', 'allclose', 'logical_and', 'logical_not', - 'logical_or', 'logical_xor', 'all', 'any', "alltrue", 'sometrue', - - # array manipulation - 'shape', 'size', 'reshape', 'ravel', 'moveaxis', 'transpose', 'swapaxes', - 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', - 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', - 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', - 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', - 'argwhere', 'nonzero', 'flatnonzero', 'where', 'searchsorted', 'extract', - 'count_nonzero', 'max', 'min', 'amax', 'amin', - - # array creation - 'array_split', 'meshgrid', 'vander', - - # indexing funcs - 'nonzero', 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', - 'triu_indices_from', 'take', 'select', - - # statistic funcs - 'nanmin', 'nanmax', 'ptp', 'percentile', 'nanpercentile', 'quantile', 'nanquantile', - 'median', 'average', 'mean', 'std', 'var', 'nanmedian', 'nanmean', 'nanstd', 'nanvar', - 'corrcoef', 'correlate', 'cov', 'histogram', 'bincount', 'digitize', - - # window funcs - 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', - - # constants - 'e', 'pi', 'inf', - - # linear algebra - 'dot', 'vdot', 
'inner', 'outer', 'kron', 'matmul', 'trace', - - # data types - 'dtype', 'finfo', 'iinfo', - - # more - 'product', 'row_stack', 'apply_over_axes', 'apply_along_axis', 'array_equiv', - 'array_repr', 'array_str', 'block', 'broadcast_arrays', 'broadcast_shapes', - 'broadcast_to', 'compress', 'cumproduct', 'diag_indices', 'diag_indices_from', - 'diagflat', 'diagonal', 'einsum', 'einsum_path', 'geomspace', 'gradient', - 'histogram2d', 'histogram_bin_edges', 'histogramdd', 'i0', 'in1d', 'indices', - 'insert', 'intersect1d', 'iscomplex', 'isin', 'ix_', 'lexsort', 'load', - 'save', 'savez', 'mask_indices', 'msort', 'nan_to_num', 'nanargmax', 'setdiff1d', - 'nanargmin', 'pad', 'poly', 'polyadd', 'polyder', 'polyfit', 'polyint', - 'polymul', 'polysub', 'polyval', 'resize', 'rollaxis', 'roots', 'rot90', - 'setxor1d', 'tensordot', 'trim_zeros', 'union1d', 'unravel_index', 'unwrap', - 'take_along_axis', 'can_cast', 'choose', 'copy', 'frombuffer', 'fromfile', - 'fromfunction', 'fromiter', 'fromstring', 'get_printoptions', 'iscomplexobj', - 'isneginf', 'isposinf', 'isrealobj', 'issubdtype', 'issubsctype', 'iterable', - 'packbits', 'piecewise', 'printoptions', 'set_printoptions', 'promote_types', - 'ravel_multi_index', 'result_type', 'sort_complex', 'unpackbits', 'delete', - - # unique - 'asanyarray', 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'genfromtxt', - 'loadtxt', 'info', 'place', 'polydiv', 'put', 'putmask', 'safe_eval', - 'savetxt', 'savez_compressed', 'show_config', 'typename', 'copyto', 'matrix', 'asmatrix', 'mat', - -] - -_min = min -_max = max - - -def _return(a): - return Array(a) - - -def fill_diagonal(a, val, inplace=True): - if a.ndim < 2: - raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') - if not isinstance(a, Array) and inplace: - raise ValueError('``fill_diagonal()`` is used in in-place updating, therefore ' - 'it requires a brainpy Array. 
If you want to disable ' - 'inplace updating, use ``fill_diagonal(inplace=False)``.') - val = val.value if isinstance(val, Array) else val - i, j = jnp.diag_indices(_min(a.shape[-2:])) - r = as_jax(a).at[..., i, j].set(val) - if inplace: - a.value = r - else: - return r - - -def zeros(shape, dtype=None): - return _return(jnp.zeros(shape, dtype=dtype)) - - -def ones(shape, dtype=None): - return _return(jnp.ones(shape, dtype=dtype)) - - -def empty(shape, dtype=None): - return _return(jnp.zeros(shape, dtype=dtype)) - - -def zeros_like(a, dtype=None, shape=None): - a = _as_jax_array_(a) - return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) - - -def ones_like(a, dtype=None, shape=None): - a = _as_jax_array_(a) - return _return(jnp.ones_like(a, dtype=dtype, shape=shape)) - - -def empty_like(a, dtype=None, shape=None): - a = _as_jax_array_(a) - return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) - - -def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: - a = _as_jax_array_(a) - try: - res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) - except TypeError: - leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, Array)) - leaves = [_as_jax_array_(l) for l in leaves] - a = tree_unflatten(tree, leaves) - res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) - return _return(res) - - -def asarray(a, dtype=None, order=None): - a = _as_jax_array_(a) - try: - res = jnp.asarray(a=a, dtype=dtype, order=order) - except TypeError: - leaves, tree = tree_flatten(a, is_leaf=lambda a: isinstance(a, Array)) - leaves = [_as_jax_array_(l) for l in leaves] - arrays = tree_unflatten(tree, leaves) - res = jnp.asarray(a=arrays, dtype=dtype, order=order) - return _return(res) - - -def arange(*args, **kwargs): - args = [_as_jax_array_(a) for a in args] - kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return _return(jnp.arange(*args, **kwargs)) - - -def linspace(*args, **kwargs): - args = [_as_jax_array_(a) for a 
in args] - kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - res = jnp.linspace(*args, **kwargs) - if isinstance(res, tuple): - return _return(res[0]), res[1] - else: - return _return(res) - - -def logspace(*args, **kwargs): - args = [_as_jax_array_(a) for a in args] - kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return _return(jnp.logspace(*args, **kwargs)) - - -def asanyarray(a, dtype=None, order=None): - return asarray(a, dtype=dtype, order=order) - - -def ascontiguousarray(a, dtype=None, order=None): - return asarray(a, dtype=dtype, order=order) - - -def asfarray(a, dtype=None): - if not np.issubdtype(dtype, np.inexact): - dtype = np.float64 - return asarray(a, dtype=dtype) - - -def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array: - del assume_unique - ar1_flat = ravel(ar1) - ar2_flat = ravel(ar2) - # Note: an algorithm based on searchsorted has better scaling, but in practice - # is very slow on accelerators because it relies on lax control flow. 
If XLA - # ever supports binary search natively, we should switch to this: - # ar2_flat = jnp.sort(ar2_flat) - # ind = jnp.searchsorted(ar2_flat, ar1_flat) - # if invert: - # return ar1_flat != ar2_flat[ind] - # else: - # return ar1_flat == ar2_flat[ind] - if invert: - return asarray((ar1_flat[:, None] != ar2_flat[None, :]).all(-1)) - else: - return asarray((ar1_flat[:, None] == ar2_flat[None, :]).any(-1)) - - -# Others -# ------ -meshgrid = _compatible_with_brainpy_array(jnp.meshgrid) -vander = _compatible_with_brainpy_array(jnp.vander) -full = _compatible_with_brainpy_array(jnp.full) -full_like = _compatible_with_brainpy_array(jnp.full_like) -eye = _compatible_with_brainpy_array(jnp.eye) -identity = _compatible_with_brainpy_array(jnp.identity) -diag = _compatible_with_brainpy_array(jnp.diag) -tri = _compatible_with_brainpy_array(jnp.tri) -tril = _compatible_with_brainpy_array(jnp.tril) -triu = _compatible_with_brainpy_array(jnp.triu) -delete = _compatible_with_brainpy_array(jnp.delete) -take_along_axis = _compatible_with_brainpy_array(jnp.take_along_axis) -block = _compatible_with_brainpy_array(jnp.block) -broadcast_arrays = _compatible_with_brainpy_array(jnp.broadcast_arrays) -broadcast_shapes = _compatible_with_brainpy_array(jnp.broadcast_shapes) -broadcast_to = _compatible_with_brainpy_array(jnp.broadcast_to) -compress = _compatible_with_brainpy_array(jnp.compress) -diag_indices = _compatible_with_brainpy_array(jnp.diag_indices) -diag_indices_from = _compatible_with_brainpy_array(jnp.diag_indices_from) -diagflat = _compatible_with_brainpy_array(jnp.diagflat) -diagonal = _compatible_with_brainpy_array(jnp.diagonal) -einsum = _compatible_with_brainpy_array(jnp.einsum) -einsum_path = _compatible_with_brainpy_array(jnp.einsum_path) -geomspace = _compatible_with_brainpy_array(jnp.geomspace) -gradient = _compatible_with_brainpy_array(jnp.gradient) -histogram2d = _compatible_with_brainpy_array(jnp.histogram2d) -histogram_bin_edges = 
_compatible_with_brainpy_array(jnp.histogram_bin_edges) -histogramdd = _compatible_with_brainpy_array(jnp.histogramdd) -i0 = _compatible_with_brainpy_array(jnp.i0) -indices = _compatible_with_brainpy_array(jnp.indices) -insert = _compatible_with_brainpy_array(jnp.insert) -intersect1d = _compatible_with_brainpy_array(jnp.intersect1d) -iscomplex = _compatible_with_brainpy_array(jnp.iscomplex) -isin = _compatible_with_brainpy_array(jnp.isin) -ix_ = _compatible_with_brainpy_array(jnp.ix_) -lexsort = _compatible_with_brainpy_array(jnp.lexsort) -load = _compatible_with_brainpy_array(jnp.load) -save = _compatible_with_brainpy_array(jnp.save) -savez = _compatible_with_brainpy_array(jnp.savez) -mask_indices = _compatible_with_brainpy_array(jnp.mask_indices) - - -def msort(a): - """ - Return a copy of an array sorted along the first axis. - - Parameters:: - - a : array_like - Array to be sorted. - - Returns:: - - sorted_array : ndarray - Array of the same type and shape as `a`. - - See Also:: - - sort - - Notes:: - - ``brainpy.math.msort(a)`` is equivalent to ``brainpy.math.sort(a, axis=0)``. 
- - """ - return sort(a, axis=0) - - -nan_to_num = _compatible_with_brainpy_array(jnp.nan_to_num) -nanargmax = _compatible_with_brainpy_array(jnp.nanargmax) -nanargmin = _compatible_with_brainpy_array(jnp.nanargmin) -pad = _compatible_with_brainpy_array(jnp.pad) -poly = _compatible_with_brainpy_array(jnp.poly) -polyadd = _compatible_with_brainpy_array(jnp.polyadd) -polyder = _compatible_with_brainpy_array(jnp.polyder) -polyfit = _compatible_with_brainpy_array(jnp.polyfit) -polyint = _compatible_with_brainpy_array(jnp.polyint) -polymul = _compatible_with_brainpy_array(jnp.polymul) -polysub = _compatible_with_brainpy_array(jnp.polysub) -polyval = _compatible_with_brainpy_array(jnp.polyval) -resize = _compatible_with_brainpy_array(jnp.resize) -rollaxis = _compatible_with_brainpy_array(jnp.rollaxis) -roots = _compatible_with_brainpy_array(jnp.roots) -rot90 = _compatible_with_brainpy_array(jnp.rot90) -setdiff1d = _compatible_with_brainpy_array(jnp.setdiff1d) -setxor1d = _compatible_with_brainpy_array(jnp.setxor1d) -tensordot = _compatible_with_brainpy_array(jnp.tensordot) -trim_zeros = _compatible_with_brainpy_array(jnp.trim_zeros) -union1d = _compatible_with_brainpy_array(jnp.union1d) -unravel_index = _compatible_with_brainpy_array(jnp.unravel_index) -unwrap = _compatible_with_brainpy_array(jnp.unwrap) - -# math funcs -# ---------- -isreal = _compatible_with_brainpy_array(jnp.isreal) -isscalar = _compatible_with_brainpy_array(jnp.isscalar) -real = _compatible_with_brainpy_array(jnp.real) -imag = _compatible_with_brainpy_array(jnp.imag) -conj = _compatible_with_brainpy_array(jnp.conj) -conjugate = _compatible_with_brainpy_array(jnp.conjugate) -ndim = _compatible_with_brainpy_array(jnp.ndim) -add = _compatible_with_brainpy_array(jnp.add) -reciprocal = _compatible_with_brainpy_array(jnp.reciprocal) -negative = _compatible_with_brainpy_array(jnp.negative) -positive = _compatible_with_brainpy_array(jnp.positive) -multiply = _compatible_with_brainpy_array(jnp.multiply) 
-divide = _compatible_with_brainpy_array(jnp.divide) -power = _compatible_with_brainpy_array(jnp.power) -subtract = _compatible_with_brainpy_array(jnp.subtract) -true_divide = _compatible_with_brainpy_array(jnp.true_divide) -floor_divide = _compatible_with_brainpy_array(jnp.floor_divide) -float_power = _compatible_with_brainpy_array(jnp.float_power) -fmod = _compatible_with_brainpy_array(jnp.fmod) -mod = _compatible_with_brainpy_array(jnp.mod) -divmod = _compatible_with_brainpy_array(jnp.divmod) -remainder = _compatible_with_brainpy_array(jnp.remainder) -modf = _compatible_with_brainpy_array(jnp.modf) -abs = _compatible_with_brainpy_array(jnp.abs) -absolute = _compatible_with_brainpy_array(jnp.absolute) -exp = _compatible_with_brainpy_array(jnp.exp) -exp2 = _compatible_with_brainpy_array(jnp.exp2) -expm1 = _compatible_with_brainpy_array(jnp.expm1) -log = _compatible_with_brainpy_array(jnp.log) -log10 = _compatible_with_brainpy_array(jnp.log10) -log1p = _compatible_with_brainpy_array(jnp.log1p) -log2 = _compatible_with_brainpy_array(jnp.log2) -logaddexp = _compatible_with_brainpy_array(jnp.logaddexp) -logaddexp2 = _compatible_with_brainpy_array(jnp.logaddexp2) -lcm = _compatible_with_brainpy_array(jnp.lcm) -gcd = _compatible_with_brainpy_array(jnp.gcd) -arccos = _compatible_with_brainpy_array(jnp.arccos) -arccosh = _compatible_with_brainpy_array(jnp.arccosh) -arcsin = _compatible_with_brainpy_array(jnp.arcsin) -arcsinh = _compatible_with_brainpy_array(jnp.arcsinh) -arctan = _compatible_with_brainpy_array(jnp.arctan) -arctan2 = _compatible_with_brainpy_array(jnp.arctan2) -arctanh = _compatible_with_brainpy_array(jnp.arctanh) -cos = _compatible_with_brainpy_array(jnp.cos) -cosh = _compatible_with_brainpy_array(jnp.cosh) -sin = _compatible_with_brainpy_array(jnp.sin) -sinc = _compatible_with_brainpy_array(jnp.sinc) -sinh = _compatible_with_brainpy_array(jnp.sinh) -tan = _compatible_with_brainpy_array(jnp.tan) -tanh = _compatible_with_brainpy_array(jnp.tanh) -deg2rad = 
# Rounding, angle-conversion, and reduction helpers: each wraps the
# corresponding `jax.numpy` function so brainpy ``Array`` arguments and
# results are handled transparently by `_compatible_with_brainpy_array`.
deg2rad = _compatible_with_brainpy_array(jnp.deg2rad)
rad2deg = _compatible_with_brainpy_array(jnp.rad2deg)
degrees = _compatible_with_brainpy_array(jnp.degrees)
radians = _compatible_with_brainpy_array(jnp.radians)
hypot = _compatible_with_brainpy_array(jnp.hypot)
round = _compatible_with_brainpy_array(jnp.round)
around = round
round_ = round
rint = _compatible_with_brainpy_array(jnp.rint)
floor = _compatible_with_brainpy_array(jnp.floor)
ceil = _compatible_with_brainpy_array(jnp.ceil)
trunc = _compatible_with_brainpy_array(jnp.trunc)
prod = _compatible_with_brainpy_array(jnp.prod)

sum = _compatible_with_brainpy_array(jnp.sum)

diff = _compatible_with_brainpy_array(jnp.diff)
median = _compatible_with_brainpy_array(jnp.median)
nancumprod = _compatible_with_brainpy_array(jnp.nancumprod)
nancumsum = _compatible_with_brainpy_array(jnp.nancumsum)
cumprod = _compatible_with_brainpy_array(jnp.cumprod)
cumproduct = cumprod
cumsum = _compatible_with_brainpy_array(jnp.cumsum)
nanprod = _compatible_with_brainpy_array(jnp.nanprod)
nansum = _compatible_with_brainpy_array(jnp.nansum)
ediff1d = _compatible_with_brainpy_array(jnp.ediff1d)
cross = _compatible_with_brainpy_array(jnp.cross)

# BUG FIX: the previous gate ``if jax.__version__ >= '0.4.18'`` compared
# version *strings* lexicographically (e.g. '0.4.9' sorts after '0.4.18'),
# and its fallback branch referenced ``jnp.trapezoid`` — a name that does
# not exist in the old JAX releases that branch was meant to support
# (``jnp.trapz`` is the pre-removal name), so importing this module on old
# JAX raised AttributeError.  Probe for the modern function instead of
# comparing versions.
try:
  from jax.scipy.integrate import trapezoid as _trapezoid
except ImportError:  # very old JAX: fall back to the numpy-compatible name
  _trapezoid = jnp.trapz
trapz = _compatible_with_brainpy_array(_trapezoid)

isfinite = _compatible_with_brainpy_array(jnp.isfinite)
isinf = _compatible_with_brainpy_array(jnp.isinf)
isnan = _compatible_with_brainpy_array(jnp.isnan)
signbit = _compatible_with_brainpy_array(jnp.signbit)
nextafter = _compatible_with_brainpy_array(jnp.nextafter)
copysign = _compatible_with_brainpy_array(jnp.copysign)
ldexp = _compatible_with_brainpy_array(jnp.ldexp)
frexp = _compatible_with_brainpy_array(jnp.frexp)
convolve = _compatible_with_brainpy_array(jnp.convolve)
sqrt = _compatible_with_brainpy_array(jnp.sqrt)
cbrt = _compatible_with_brainpy_array(jnp.cbrt)
square = _compatible_with_brainpy_array(jnp.square)
_compatible_with_brainpy_array(jnp.square) -fabs = _compatible_with_brainpy_array(jnp.fabs) -sign = _compatible_with_brainpy_array(jnp.sign) -heaviside = _compatible_with_brainpy_array(jnp.heaviside) -maximum = _compatible_with_brainpy_array(jnp.maximum) -minimum = _compatible_with_brainpy_array(jnp.minimum) -fmax = _compatible_with_brainpy_array(jnp.fmax) -fmin = _compatible_with_brainpy_array(jnp.fmin) -interp = _compatible_with_brainpy_array(jnp.interp) -clip = _compatible_with_brainpy_array(jnp.clip) -angle = _compatible_with_brainpy_array(jnp.angle) -bitwise_not = _compatible_with_brainpy_array(jnp.bitwise_not) -invert = _compatible_with_brainpy_array(jnp.invert) -bitwise_and = _compatible_with_brainpy_array(jnp.bitwise_and) -bitwise_or = _compatible_with_brainpy_array(jnp.bitwise_or) -bitwise_xor = _compatible_with_brainpy_array(jnp.bitwise_xor) -left_shift = _compatible_with_brainpy_array(jnp.left_shift) -right_shift = _compatible_with_brainpy_array(jnp.right_shift) -equal = _compatible_with_brainpy_array(jnp.equal) -not_equal = _compatible_with_brainpy_array(jnp.not_equal) -greater = _compatible_with_brainpy_array(jnp.greater) -greater_equal = _compatible_with_brainpy_array(jnp.greater_equal) -less = _compatible_with_brainpy_array(jnp.less) -less_equal = _compatible_with_brainpy_array(jnp.less_equal) -array_equal = _compatible_with_brainpy_array(jnp.array_equal) -isclose = _compatible_with_brainpy_array(jnp.isclose) -allclose = _compatible_with_brainpy_array(jnp.allclose) -logical_not = _compatible_with_brainpy_array(jnp.logical_not) -logical_and = _compatible_with_brainpy_array(jnp.logical_and) -logical_or = _compatible_with_brainpy_array(jnp.logical_or) -logical_xor = _compatible_with_brainpy_array(jnp.logical_xor) -all = _compatible_with_brainpy_array(jnp.all) -any = _compatible_with_brainpy_array(jnp.any) - -alltrue = all -sometrue = any - - -def shape(a): - """ - Return the shape of an array. - - Parameters:: - - a : array_like - Input array. 
- - Returns:: - - shape : tuple of ints - The elements of the shape tuple give the lengths of the - corresponding array dimensions. - - See Also:: - - len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with - ``N>=1``. - ndarray.shape : Equivalent array method. - - Examples:: - - >>> import brainpy - >>> brainpy.math.shape(brainpy.math.eye(3)) - (3, 3) - >>> brainpy.math.shape([[1, 3]]) - (1, 2) - >>> brainpy.math.shape([0]) - (1,) - >>> brainpy.math.shape(0) - () - - """ - if isinstance(a, (Array, jax.Array, np.ndarray)): - return a.shape - else: - return np.shape(a) - - -def size(a, axis=None): - """ - Return the number of elements along a given axis. - - Parameters:: - - a : array_like - Input data. - axis : int, optional - Axis along which the elements are counted. By default, give - the total number of elements. - - Returns:: - - element_count : int - Number of elements along the specified axis. - - See Also:: - - shape : dimensions of array - Array.shape : dimensions of array - Array.size : number of elements in array - - Examples:: - - >>> import brainpy - >>> a = brainpy.math.array([[1,2,3], [4,5,6]]) - >>> brainpy.math.size(a) - 6 - >>> brainpy.math.size(a, 1) - 3 - >>> brainpy.math.size(a, 0) - 2 - """ - if isinstance(a, (Array, jax.Array, np.ndarray)): - if axis is None: - return a.size - else: - return a.shape[axis] - else: - return np.size(a, axis=axis) - - -reshape = _compatible_with_brainpy_array(jnp.reshape) -ravel = _compatible_with_brainpy_array(jnp.ravel) -moveaxis = _compatible_with_brainpy_array(jnp.moveaxis) -transpose = _compatible_with_brainpy_array(jnp.transpose) -swapaxes = _compatible_with_brainpy_array(jnp.swapaxes) -concatenate = _compatible_with_brainpy_array(jnp.concatenate) -stack = _compatible_with_brainpy_array(jnp.stack) -vstack = _compatible_with_brainpy_array(jnp.vstack) -product = prod -row_stack = vstack -hstack = _compatible_with_brainpy_array(jnp.hstack) -dstack = _compatible_with_brainpy_array(jnp.dstack) 
-column_stack = _compatible_with_brainpy_array(jnp.column_stack) -split = _compatible_with_brainpy_array(jnp.split) -dsplit = _compatible_with_brainpy_array(jnp.dsplit) -hsplit = _compatible_with_brainpy_array(jnp.hsplit) -vsplit = _compatible_with_brainpy_array(jnp.vsplit) -tile = _compatible_with_brainpy_array(jnp.tile) -repeat = _compatible_with_brainpy_array(jnp.repeat) -unique = _compatible_with_brainpy_array(jnp.unique) -append = _compatible_with_brainpy_array(jnp.append) -flip = _compatible_with_brainpy_array(jnp.flip) -fliplr = _compatible_with_brainpy_array(jnp.fliplr) -flipud = _compatible_with_brainpy_array(jnp.flipud) -roll = _compatible_with_brainpy_array(jnp.roll) -atleast_1d = _compatible_with_brainpy_array(jnp.atleast_1d) -atleast_2d = _compatible_with_brainpy_array(jnp.atleast_2d) -atleast_3d = _compatible_with_brainpy_array(jnp.atleast_3d) -expand_dims = _compatible_with_brainpy_array(jnp.expand_dims) -squeeze = _compatible_with_brainpy_array(jnp.squeeze) -sort = _compatible_with_brainpy_array(jnp.sort) -argsort = _compatible_with_brainpy_array(jnp.argsort) -argmax = _compatible_with_brainpy_array(jnp.argmax) -argmin = _compatible_with_brainpy_array(jnp.argmin) -argwhere = _compatible_with_brainpy_array(jnp.argwhere) -nonzero = _compatible_with_brainpy_array(jnp.nonzero) -flatnonzero = _compatible_with_brainpy_array(jnp.flatnonzero) -where = _compatible_with_brainpy_array(jnp.where) -searchsorted = _compatible_with_brainpy_array(jnp.searchsorted) -extract = _compatible_with_brainpy_array(jnp.extract) -count_nonzero = _compatible_with_brainpy_array(jnp.count_nonzero) -max = _compatible_with_brainpy_array(jnp.max) - -min = _compatible_with_brainpy_array(jnp.min) - -amax = max -amin = min -apply_along_axis = _compatible_with_brainpy_array(jnp.apply_along_axis) -apply_over_axes = _compatible_with_brainpy_array(jnp.apply_over_axes) -array_equiv = _compatible_with_brainpy_array(jnp.array_equiv) -array_repr = 
array_repr = _compatible_with_brainpy_array(jnp.array_repr)
array_str = _compatible_with_brainpy_array(jnp.array_str)
array_split = _compatible_with_brainpy_array(jnp.array_split)

# indexing funcs
# --------------

# CONSISTENCY FIX: `tril_indices` / `triu_indices` were exposed as the bare
# jax functions while every sibling index helper in this module
# (`diag_indices`, `tril_indices_from`, `triu_indices_from`, ...) is wrapped
# with `_compatible_with_brainpy_array`, so these two alone returned raw jax
# arrays instead of brainpy ``Array`` objects.  Wrap them the same way.
tril_indices = _compatible_with_brainpy_array(jnp.tril_indices)
triu_indices = _compatible_with_brainpy_array(jnp.triu_indices)
tril_indices_from = _compatible_with_brainpy_array(jnp.tril_indices_from)
triu_indices_from = _compatible_with_brainpy_array(jnp.triu_indices_from)
take = _compatible_with_brainpy_array(jnp.take)
select = _compatible_with_brainpy_array(jnp.select)

# statistic funcs, each wrapping the matching `jax.numpy` reduction
nanmin = _compatible_with_brainpy_array(jnp.nanmin)
nanmax = _compatible_with_brainpy_array(jnp.nanmax)
ptp = _compatible_with_brainpy_array(jnp.ptp)
percentile = _compatible_with_brainpy_array(jnp.percentile)
nanpercentile = _compatible_with_brainpy_array(jnp.nanpercentile)
quantile = _compatible_with_brainpy_array(jnp.quantile)
nanquantile = _compatible_with_brainpy_array(jnp.nanquantile)
average = _compatible_with_brainpy_array(jnp.average)
mean = _compatible_with_brainpy_array(jnp.mean)
std = _compatible_with_brainpy_array(jnp.std)
var = _compatible_with_brainpy_array(jnp.var)
nanmedian = _compatible_with_brainpy_array(jnp.nanmedian)
nanmean = _compatible_with_brainpy_array(jnp.nanmean)
nanstd = _compatible_with_brainpy_array(jnp.nanstd)
nanvar = _compatible_with_brainpy_array(jnp.nanvar)
corrcoef = _compatible_with_brainpy_array(jnp.corrcoef)
correlate = _compatible_with_brainpy_array(jnp.correlate)
cov = _compatible_with_brainpy_array(jnp.cov)
histogram = _compatible_with_brainpy_array(jnp.histogram)
bincount = _compatible_with_brainpy_array(jnp.bincount)
digitize = _compatible_with_brainpy_array(jnp.digitize)

# window funcs
bartlett = _compatible_with_brainpy_array(jnp.bartlett)
blackman = _compatible_with_brainpy_array(jnp.blackman)
hamming = _compatible_with_brainpy_array(jnp.hamming)
hanning = _compatible_with_brainpy_array(jnp.hanning)
kaiser = _compatible_with_brainpy_array(jnp.kaiser)

# constants
# ---------

e = jnp.e
-pi = jnp.pi -inf = jnp.inf - -# linear algebra -# -------------- - -dot = _compatible_with_brainpy_array(jnp.dot) -vdot = _compatible_with_brainpy_array(jnp.vdot) -inner = _compatible_with_brainpy_array(jnp.inner) -outer = _compatible_with_brainpy_array(jnp.outer) -kron = _compatible_with_brainpy_array(jnp.kron) -matmul = _compatible_with_brainpy_array(jnp.matmul) -trace = _compatible_with_brainpy_array(jnp.trace) - -dtype = jnp.dtype -finfo = jnp.finfo -iinfo = jnp.iinfo - -can_cast = _compatible_with_brainpy_array(jnp.can_cast) -choose = _compatible_with_brainpy_array(jnp.choose) -copy = _compatible_with_brainpy_array(jnp.copy) -frombuffer = _compatible_with_brainpy_array(jnp.frombuffer) -fromfile = _compatible_with_brainpy_array(jnp.fromfile) -fromfunction = _compatible_with_brainpy_array(jnp.fromfunction) -fromiter = _compatible_with_brainpy_array(jnp.fromiter) -fromstring = _compatible_with_brainpy_array(jnp.fromstring) -get_printoptions = np.get_printoptions -iscomplexobj = _compatible_with_brainpy_array(jnp.iscomplexobj) -isneginf = _compatible_with_brainpy_array(jnp.isneginf) -isposinf = _compatible_with_brainpy_array(jnp.isposinf) -isrealobj = _compatible_with_brainpy_array(jnp.isrealobj) -issubdtype = jnp.issubdtype -issubsctype = jnp.issubdtype -iterable = _compatible_with_brainpy_array(jnp.iterable) -packbits = _compatible_with_brainpy_array(jnp.packbits) -piecewise = _compatible_with_brainpy_array(jnp.piecewise) -printoptions = np.printoptions -set_printoptions = np.set_printoptions -promote_types = _compatible_with_brainpy_array(jnp.promote_types) -ravel_multi_index = _compatible_with_brainpy_array(jnp.ravel_multi_index) -result_type = _compatible_with_brainpy_array(jnp.result_type) -sort_complex = _compatible_with_brainpy_array(jnp.sort_complex) -unpackbits = _compatible_with_brainpy_array(jnp.unpackbits) - - -# Unique APIs -# ----------- - - -def asscalar(a): - return a.item() - - -array_type = [[np.half, np.single, np.double, np.longdouble], - 
# Precision rank of each scalar type, used by `common_type` to index into
# `array_type` (row 0 = real types, row 1 = complex counterparts).
array_precision = {np.half: 0,
                   np.single: 1,
                   np.double: 2,
                   np.longdouble: 3,
                   np.csingle: 1,
                   np.cdouble: 2,
                   np.clongdouble: 3}


def common_type(*arrays):
  """Return the scalar type of highest precision among ``arrays``.

  Mirrors ``numpy.common_type``: integer inputs promote to double
  precision; a complex input anywhere makes the result complex.

  Raises:
    TypeError: if an array's dtype is neither integer nor one of the
      float/complex types in ``array_precision``.
  """
  is_complex = False
  precision = 0
  for a in arrays:
    t = a.dtype.type
    if iscomplexobj(a):
      is_complex = True
    if issubclass(t, jnp.integer):
      p = 2  # integers promote to double precision
    else:
      p = array_precision.get(t, None)
    if p is None:
      raise TypeError("can't get common type for non-numeric array")
    precision = _max(precision, p)
  if is_complex:
    return array_type[1][precision]
  else:
    return array_type[0][precision]


# Text loaders: delegate to numpy, then wrap the result as a brainpy Array.
genfromtxt = lambda *args, **kwargs: asarray(np.genfromtxt(*args, **kwargs))
loadtxt = lambda *args, **kwargs: asarray(np.loadtxt(*args, **kwargs))
info = np.info


def place(arr, mask, vals):
  """In-place ``numpy.place`` analogue; requires a brainpy ``Array``."""
  if not isinstance(arr, Array):
    raise ValueError(f'Must be an instance of brainpy Array, but we got {type(arr)}')
  arr[mask] = vals


polydiv = _compatible_with_brainpy_array(jnp.polydiv)


def put(a, ind, v):
  """In-place ``numpy.put`` analogue; requires a brainpy ``Array``."""
  if not isinstance(a, Array):
    raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}')
  a[ind] = v


def putmask(a, mask, values):
  """In-place ``numpy.putmask`` analogue; requires a brainpy ``Array``.

  NOTE(review): unlike numpy's ``putmask``, ``values`` must match the shape
  of ``a`` exactly — no broadcasting or cycling of ``values`` is supported.
  """
  if not isinstance(a, Array):
    raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}')
  if a.shape != values.shape:
    raise ValueError('Only support the shapes of "a" and "values" are consistent.')
  a[mask] = values


def safe_eval(source):
  """Safely evaluate a string containing a Python literal expression,
  wrapping the resulting leaves as brainpy ``Array`` objects.

  BUG FIX: ``np.safe_eval`` was removed in NumPy 2.0; it was a thin alias
  of ``ast.literal_eval``, which we now call directly so this works on
  both NumPy 1.x and 2.x.
  """
  import ast
  return tree_map(Array, ast.literal_eval(source))


def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='',
            footer='', comments='# ', encoding=None):
  """Save an array to a text file (converts ``X`` to numpy, then delegates
  to ``np.savetxt`` with identical keyword semantics)."""
  X = as_numpy(X)
  np.savetxt(fname, X, fmt=fmt, delimiter=delimiter, newline=newline, header=header,
             footer=footer, comments=comments, encoding=encoding)
(as_numpy(v) if isinstance(v, (jnp.ndarray, Array)) else v) - for k, v in kwds.items()} - np.savez_compressed(file, *args, **kwds) - - -show_config = np.show_config -typename = np.typename - - -def copyto(dst, src): - if not isinstance(dst, Array): - raise ValueError('dst must be an instance of ArrayType.') - dst[:] = src - - -def matrix(data, dtype=None): - data = array(data, copy=True, dtype=dtype) - if data.ndim > 2: - raise ValueError(f'shape too large {data.shape} to be a matrix.') - if data.ndim != 2: - for i in range(2 - data.ndim): - data = expand_dims(data, 0) - return data - - -def asmatrix(data, dtype=None): - data = array(data, dtype=dtype) - if data.ndim > 2: - raise ValueError(f'shape too large {data.shape} to be a matrix.') - if data.ndim != 2: - for i in range(2 - data.ndim): - data = expand_dims(data, 0) - return data - - -def mat(data, dtype=None): - return asmatrix(data, dtype=dtype) +# -*- coding: utf-8 -*- +# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import jax +import jax.numpy as jnp +import numpy as np +from jax.tree_util import tree_flatten, tree_unflatten, tree_map + +from ._utils import _compatible_with_brainpy_array, _as_jax_array_ +from .interoperability import * +from .ndarray import Array + +__all__ = [ + 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', + 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', + 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', + + # math funcs + 'real', 'imag', 'conj', 'conjugate', 'ndim', 'isreal', 'isscalar', + 'add', 'reciprocal', 'negative', 'positive', 'multiply', 'divide', + 'power', 'subtract', 'true_divide', 'floor_divide', 'float_power', + 'fmod', 'mod', 'modf', 'divmod', 'remainder', 'abs', 'exp', 'exp2', + 'expm1', 'log', 'log10', 'log1p', 'log2', 'logaddexp', 'logaddexp2', + 'lcm', 'gcd', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', + 'arctan2', 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', + 'tanh', 'deg2rad', 'hypot', 'rad2deg', 'degrees', 'radians', 'round', + 'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'prod', + 'sum', 'diff', 'median', 'nancumprod', 'nancumsum', 'nanprod', 'nansum', + 'cumprod', 'cumsum', 'ediff1d', 'cross', 'isfinite', 'isinf', + 'isnan', 'signbit', 'copysign', 'nextafter', 'ldexp', 'frexp', 'convolve', + 'sqrt', 'cbrt', 'square', 'absolute', 'fabs', 'sign', 'heaviside', + 'maximum', 'minimum', 'fmax', 'fmin', 'interp', 'clip', 'angle', + + # Elementwise bit operations + 'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor', + 'invert', 'left_shift', 'right_shift', + + # logic funcs + 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', + 'array_equal', 'isclose', 'allclose', 'logical_and', 'logical_not', + 'logical_or', 'logical_xor', 'all', 'any', "alltrue", 'sometrue', + + # array manipulation + 'shape', 'size', 'reshape', 'ravel', 'moveaxis', 
'transpose', 'swapaxes', + 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', + 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', + 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', + 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', + 'argwhere', 'nonzero', 'flatnonzero', 'where', 'searchsorted', 'extract', + 'count_nonzero', 'max', 'min', 'amax', 'amin', + + # array creation + 'array_split', 'meshgrid', 'vander', + + # indexing funcs + 'nonzero', 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', + 'triu_indices_from', 'take', 'select', + + # statistic funcs + 'nanmin', 'nanmax', 'ptp', 'percentile', 'nanpercentile', 'quantile', 'nanquantile', + 'median', 'average', 'mean', 'std', 'var', 'nanmedian', 'nanmean', 'nanstd', 'nanvar', + 'corrcoef', 'correlate', 'cov', 'histogram', 'bincount', 'digitize', + + # window funcs + 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', + + # constants + 'e', 'pi', 'inf', + + # linear algebra + 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', + + # data types + 'dtype', 'finfo', 'iinfo', + + # more + 'product', 'row_stack', 'apply_over_axes', 'apply_along_axis', 'array_equiv', + 'array_repr', 'array_str', 'block', 'broadcast_arrays', 'broadcast_shapes', + 'broadcast_to', 'compress', 'cumproduct', 'diag_indices', 'diag_indices_from', + 'diagflat', 'diagonal', 'einsum', 'einsum_path', 'geomspace', 'gradient', + 'histogram2d', 'histogram_bin_edges', 'histogramdd', 'i0', 'in1d', 'indices', + 'insert', 'intersect1d', 'iscomplex', 'isin', 'ix_', 'lexsort', 'load', + 'save', 'savez', 'mask_indices', 'msort', 'nan_to_num', 'nanargmax', 'setdiff1d', + 'nanargmin', 'pad', 'poly', 'polyadd', 'polyder', 'polyfit', 'polyint', + 'polymul', 'polysub', 'polyval', 'resize', 'rollaxis', 'roots', 'rot90', + 'setxor1d', 'tensordot', 'trim_zeros', 'union1d', 'unravel_index', 'unwrap', + 'take_along_axis', 'can_cast', 'choose', 'copy', 
'frombuffer', 'fromfile', + 'fromfunction', 'fromiter', 'fromstring', 'get_printoptions', 'iscomplexobj', + 'isneginf', 'isposinf', 'isrealobj', 'issubdtype', 'issubsctype', 'iterable', + 'packbits', 'piecewise', 'printoptions', 'set_printoptions', 'promote_types', + 'ravel_multi_index', 'result_type', 'sort_complex', 'unpackbits', 'delete', + + # unique + 'asanyarray', 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'genfromtxt', + 'loadtxt', 'info', 'place', 'polydiv', 'put', 'putmask', 'safe_eval', + 'savetxt', 'savez_compressed', 'show_config', 'typename', 'copyto', 'matrix', 'asmatrix', 'mat', + +] + +_min = min +_max = max + + +def _return(a): + return Array(a) + + +def fill_diagonal(a, val, inplace=True): + if a.ndim < 2: + raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') + if not isinstance(a, Array) and inplace: + raise ValueError('``fill_diagonal()`` is used in in-place updating, therefore ' + 'it requires a brainpy Array. If you want to disable ' + 'inplace updating, use ``fill_diagonal(inplace=False)``.') + val = val.value if isinstance(val, Array) else val + i, j = jnp.diag_indices(_min(a.shape[-2:])) + r = as_jax(a).at[..., i, j].set(val) + if inplace: + a.value = r + else: + return r + + +def zeros(shape, dtype=None): + return _return(jnp.zeros(shape, dtype=dtype)) + + +def ones(shape, dtype=None): + return _return(jnp.ones(shape, dtype=dtype)) + + +def empty(shape, dtype=None): + return _return(jnp.zeros(shape, dtype=dtype)) + + +def zeros_like(a, dtype=None, shape=None): + a = _as_jax_array_(a) + return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) + + +def ones_like(a, dtype=None, shape=None): + a = _as_jax_array_(a) + return _return(jnp.ones_like(a, dtype=dtype, shape=shape)) + + +def empty_like(a, dtype=None, shape=None): + a = _as_jax_array_(a) + return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) + + +def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: + a = 
def asarray(a, dtype=None, order=None):
  """Convert ``a`` to a brainpy ``Array``.

  Falls back to flattening ``a`` as a pytree of ``Array`` leaves when
  ``jnp.asarray`` cannot handle it directly (e.g. a nested list mixing
  ``Array`` objects with plain scalars).
  """
  a = _as_jax_array_(a)
  try:
    res = jnp.asarray(a=a, dtype=dtype, order=order)
  except TypeError:
    # `a` may be a nested structure containing Array leaves: unwrap each
    # leaf to a raw jax array, then retry.
    leaves, tree = tree_flatten(a, is_leaf=lambda x: isinstance(x, Array))
    leaves = [_as_jax_array_(leaf) for leaf in leaves]
    res = jnp.asarray(a=tree_unflatten(tree, leaves), dtype=dtype, order=order)
  return _return(res)


def arange(*args, **kwargs):
  """``numpy.arange`` analogue returning a brainpy ``Array``."""
  args = [_as_jax_array_(a) for a in args]
  kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
  return _return(jnp.arange(*args, **kwargs))


def linspace(*args, **kwargs):
  """``numpy.linspace`` analogue returning a brainpy ``Array``.

  With ``retstep=True`` jax returns ``(samples, step)``; only the samples
  are wrapped as an ``Array``, the step value is passed through as-is.
  """
  args = [_as_jax_array_(a) for a in args]
  kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
  res = jnp.linspace(*args, **kwargs)
  if isinstance(res, tuple):
    return _return(res[0]), res[1]
  return _return(res)


def logspace(*args, **kwargs):
  """``numpy.logspace`` analogue returning a brainpy ``Array``."""
  args = [_as_jax_array_(a) for a in args]
  kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
  return _return(jnp.logspace(*args, **kwargs))


def asanyarray(a, dtype=None, order=None):
  """Alias of :func:`asarray` (jax has no ndarray subclasses to pass through)."""
  return asarray(a, dtype=dtype, order=order)


def ascontiguousarray(a, dtype=None, order=None):
  """Alias of :func:`asarray`; jax arrays are always contiguous."""
  return asarray(a, dtype=dtype, order=order)


def asfarray(a, dtype=None):
  """Return ``a`` converted to an array of a floating-point dtype.

  Mirrors ``numpy.asfarray``: if ``dtype`` is omitted or is not an inexact
  (floating/complex) dtype, the result dtype is forced to ``float64``.

  BUG FIX: the previous check ``not np.issubdtype(dtype, np.inexact)``
  treated the default ``dtype=None`` as float64 (``np.dtype(None)`` is
  float64), so integer input kept its integer dtype instead of being
  converted to floating point.  Handle ``None`` explicitly.
  """
  if dtype is None or not np.issubdtype(dtype, np.inexact):
    dtype = np.float64
  return asarray(a, dtype=dtype)
relies on lax control flow. If XLA + # ever supports binary search natively, we should switch to this: + # ar2_flat = jnp.sort(ar2_flat) + # ind = jnp.searchsorted(ar2_flat, ar1_flat) + # if invert: + # return ar1_flat != ar2_flat[ind] + # else: + # return ar1_flat == ar2_flat[ind] + if invert: + return asarray((ar1_flat[:, None] != ar2_flat[None, :]).all(-1)) + else: + return asarray((ar1_flat[:, None] == ar2_flat[None, :]).any(-1)) + + +# Others +# ------ +meshgrid = _compatible_with_brainpy_array(jnp.meshgrid) +vander = _compatible_with_brainpy_array(jnp.vander) +full = _compatible_with_brainpy_array(jnp.full) +full_like = _compatible_with_brainpy_array(jnp.full_like) +eye = _compatible_with_brainpy_array(jnp.eye) +identity = _compatible_with_brainpy_array(jnp.identity) +diag = _compatible_with_brainpy_array(jnp.diag) +tri = _compatible_with_brainpy_array(jnp.tri) +tril = _compatible_with_brainpy_array(jnp.tril) +triu = _compatible_with_brainpy_array(jnp.triu) +delete = _compatible_with_brainpy_array(jnp.delete) +take_along_axis = _compatible_with_brainpy_array(jnp.take_along_axis) +block = _compatible_with_brainpy_array(jnp.block) +broadcast_arrays = _compatible_with_brainpy_array(jnp.broadcast_arrays) +broadcast_shapes = _compatible_with_brainpy_array(jnp.broadcast_shapes) +broadcast_to = _compatible_with_brainpy_array(jnp.broadcast_to) +compress = _compatible_with_brainpy_array(jnp.compress) +diag_indices = _compatible_with_brainpy_array(jnp.diag_indices) +diag_indices_from = _compatible_with_brainpy_array(jnp.diag_indices_from) +diagflat = _compatible_with_brainpy_array(jnp.diagflat) +diagonal = _compatible_with_brainpy_array(jnp.diagonal) +einsum = _compatible_with_brainpy_array(jnp.einsum) +einsum_path = _compatible_with_brainpy_array(jnp.einsum_path) +geomspace = _compatible_with_brainpy_array(jnp.geomspace) +gradient = _compatible_with_brainpy_array(jnp.gradient) +histogram2d = _compatible_with_brainpy_array(jnp.histogram2d) +histogram_bin_edges = 
_compatible_with_brainpy_array(jnp.histogram_bin_edges) +histogramdd = _compatible_with_brainpy_array(jnp.histogramdd) +i0 = _compatible_with_brainpy_array(jnp.i0) +indices = _compatible_with_brainpy_array(jnp.indices) +insert = _compatible_with_brainpy_array(jnp.insert) +intersect1d = _compatible_with_brainpy_array(jnp.intersect1d) +iscomplex = _compatible_with_brainpy_array(jnp.iscomplex) +isin = _compatible_with_brainpy_array(jnp.isin) +ix_ = _compatible_with_brainpy_array(jnp.ix_) +lexsort = _compatible_with_brainpy_array(jnp.lexsort) +load = _compatible_with_brainpy_array(jnp.load) +save = _compatible_with_brainpy_array(jnp.save) +savez = _compatible_with_brainpy_array(jnp.savez) +mask_indices = _compatible_with_brainpy_array(jnp.mask_indices) + + +def msort(a): + """ + Return a copy of an array sorted along the first axis. + + Parameters:: + + a : array_like + Array to be sorted. + + Returns:: + + sorted_array : ndarray + Array of the same type and shape as `a`. + + See Also:: + + sort + + Notes:: + + ``brainpy.math.msort(a)`` is equivalent to ``brainpy.math.sort(a, axis=0)``. 
+ + """ + return sort(a, axis=0) + + +nan_to_num = _compatible_with_brainpy_array(jnp.nan_to_num) +nanargmax = _compatible_with_brainpy_array(jnp.nanargmax) +nanargmin = _compatible_with_brainpy_array(jnp.nanargmin) +pad = _compatible_with_brainpy_array(jnp.pad) +poly = _compatible_with_brainpy_array(jnp.poly) +polyadd = _compatible_with_brainpy_array(jnp.polyadd) +polyder = _compatible_with_brainpy_array(jnp.polyder) +polyfit = _compatible_with_brainpy_array(jnp.polyfit) +polyint = _compatible_with_brainpy_array(jnp.polyint) +polymul = _compatible_with_brainpy_array(jnp.polymul) +polysub = _compatible_with_brainpy_array(jnp.polysub) +polyval = _compatible_with_brainpy_array(jnp.polyval) +resize = _compatible_with_brainpy_array(jnp.resize) +rollaxis = _compatible_with_brainpy_array(jnp.rollaxis) +roots = _compatible_with_brainpy_array(jnp.roots) +rot90 = _compatible_with_brainpy_array(jnp.rot90) +setdiff1d = _compatible_with_brainpy_array(jnp.setdiff1d) +setxor1d = _compatible_with_brainpy_array(jnp.setxor1d) +tensordot = _compatible_with_brainpy_array(jnp.tensordot) +trim_zeros = _compatible_with_brainpy_array(jnp.trim_zeros) +union1d = _compatible_with_brainpy_array(jnp.union1d) +unravel_index = _compatible_with_brainpy_array(jnp.unravel_index) +unwrap = _compatible_with_brainpy_array(jnp.unwrap) + +# math funcs +# ---------- +isreal = _compatible_with_brainpy_array(jnp.isreal) +isscalar = _compatible_with_brainpy_array(jnp.isscalar) +real = _compatible_with_brainpy_array(jnp.real) +imag = _compatible_with_brainpy_array(jnp.imag) +conj = _compatible_with_brainpy_array(jnp.conj) +conjugate = _compatible_with_brainpy_array(jnp.conjugate) +ndim = _compatible_with_brainpy_array(jnp.ndim) +add = _compatible_with_brainpy_array(jnp.add) +reciprocal = _compatible_with_brainpy_array(jnp.reciprocal) +negative = _compatible_with_brainpy_array(jnp.negative) +positive = _compatible_with_brainpy_array(jnp.positive) +multiply = _compatible_with_brainpy_array(jnp.multiply) 
+divide = _compatible_with_brainpy_array(jnp.divide) +power = _compatible_with_brainpy_array(jnp.power) +subtract = _compatible_with_brainpy_array(jnp.subtract) +true_divide = _compatible_with_brainpy_array(jnp.true_divide) +floor_divide = _compatible_with_brainpy_array(jnp.floor_divide) +float_power = _compatible_with_brainpy_array(jnp.float_power) +fmod = _compatible_with_brainpy_array(jnp.fmod) +mod = _compatible_with_brainpy_array(jnp.mod) +divmod = _compatible_with_brainpy_array(jnp.divmod) +remainder = _compatible_with_brainpy_array(jnp.remainder) +modf = _compatible_with_brainpy_array(jnp.modf) +abs = _compatible_with_brainpy_array(jnp.abs) +absolute = _compatible_with_brainpy_array(jnp.absolute) +exp = _compatible_with_brainpy_array(jnp.exp) +exp2 = _compatible_with_brainpy_array(jnp.exp2) +expm1 = _compatible_with_brainpy_array(jnp.expm1) +log = _compatible_with_brainpy_array(jnp.log) +log10 = _compatible_with_brainpy_array(jnp.log10) +log1p = _compatible_with_brainpy_array(jnp.log1p) +log2 = _compatible_with_brainpy_array(jnp.log2) +logaddexp = _compatible_with_brainpy_array(jnp.logaddexp) +logaddexp2 = _compatible_with_brainpy_array(jnp.logaddexp2) +lcm = _compatible_with_brainpy_array(jnp.lcm) +gcd = _compatible_with_brainpy_array(jnp.gcd) +arccos = _compatible_with_brainpy_array(jnp.arccos) +arccosh = _compatible_with_brainpy_array(jnp.arccosh) +arcsin = _compatible_with_brainpy_array(jnp.arcsin) +arcsinh = _compatible_with_brainpy_array(jnp.arcsinh) +arctan = _compatible_with_brainpy_array(jnp.arctan) +arctan2 = _compatible_with_brainpy_array(jnp.arctan2) +arctanh = _compatible_with_brainpy_array(jnp.arctanh) +cos = _compatible_with_brainpy_array(jnp.cos) +cosh = _compatible_with_brainpy_array(jnp.cosh) +sin = _compatible_with_brainpy_array(jnp.sin) +sinc = _compatible_with_brainpy_array(jnp.sinc) +sinh = _compatible_with_brainpy_array(jnp.sinh) +tan = _compatible_with_brainpy_array(jnp.tan) +tanh = _compatible_with_brainpy_array(jnp.tanh) +deg2rad = 
_compatible_with_brainpy_array(jnp.deg2rad) +rad2deg = _compatible_with_brainpy_array(jnp.rad2deg) +degrees = _compatible_with_brainpy_array(jnp.degrees) +radians = _compatible_with_brainpy_array(jnp.radians) +hypot = _compatible_with_brainpy_array(jnp.hypot) +round = _compatible_with_brainpy_array(jnp.round) +around = round +round_ = round +rint = _compatible_with_brainpy_array(jnp.rint) +floor = _compatible_with_brainpy_array(jnp.floor) +ceil = _compatible_with_brainpy_array(jnp.ceil) +trunc = _compatible_with_brainpy_array(jnp.trunc) +prod = _compatible_with_brainpy_array(jnp.prod) + +sum = _compatible_with_brainpy_array(jnp.sum) + +diff = _compatible_with_brainpy_array(jnp.diff) +median = _compatible_with_brainpy_array(jnp.median) +nancumprod = _compatible_with_brainpy_array(jnp.nancumprod) +nancumsum = _compatible_with_brainpy_array(jnp.nancumsum) +cumprod = _compatible_with_brainpy_array(jnp.cumprod) +cumproduct = cumprod +cumsum = _compatible_with_brainpy_array(jnp.cumsum) +nanprod = _compatible_with_brainpy_array(jnp.nanprod) +nansum = _compatible_with_brainpy_array(jnp.nansum) +ediff1d = _compatible_with_brainpy_array(jnp.ediff1d) +cross = _compatible_with_brainpy_array(jnp.cross) +trapz = _compatible_with_brainpy_array(jax.scipy.integrate.trapezoid) +isfinite = _compatible_with_brainpy_array(jnp.isfinite) +isinf = _compatible_with_brainpy_array(jnp.isinf) +isnan = _compatible_with_brainpy_array(jnp.isnan) +signbit = _compatible_with_brainpy_array(jnp.signbit) +nextafter = _compatible_with_brainpy_array(jnp.nextafter) +copysign = _compatible_with_brainpy_array(jnp.copysign) +ldexp = _compatible_with_brainpy_array(jnp.ldexp) +frexp = _compatible_with_brainpy_array(jnp.frexp) +convolve = _compatible_with_brainpy_array(jnp.convolve) +sqrt = _compatible_with_brainpy_array(jnp.sqrt) +cbrt = _compatible_with_brainpy_array(jnp.cbrt) +square = _compatible_with_brainpy_array(jnp.square) +fabs = _compatible_with_brainpy_array(jnp.fabs) +sign = 
_compatible_with_brainpy_array(jnp.sign) +heaviside = _compatible_with_brainpy_array(jnp.heaviside) +maximum = _compatible_with_brainpy_array(jnp.maximum) +minimum = _compatible_with_brainpy_array(jnp.minimum) +fmax = _compatible_with_brainpy_array(jnp.fmax) +fmin = _compatible_with_brainpy_array(jnp.fmin) +interp = _compatible_with_brainpy_array(jnp.interp) +clip = _compatible_with_brainpy_array(jnp.clip) +angle = _compatible_with_brainpy_array(jnp.angle) +bitwise_not = _compatible_with_brainpy_array(jnp.bitwise_not) +invert = _compatible_with_brainpy_array(jnp.invert) +bitwise_and = _compatible_with_brainpy_array(jnp.bitwise_and) +bitwise_or = _compatible_with_brainpy_array(jnp.bitwise_or) +bitwise_xor = _compatible_with_brainpy_array(jnp.bitwise_xor) +left_shift = _compatible_with_brainpy_array(jnp.left_shift) +right_shift = _compatible_with_brainpy_array(jnp.right_shift) +equal = _compatible_with_brainpy_array(jnp.equal) +not_equal = _compatible_with_brainpy_array(jnp.not_equal) +greater = _compatible_with_brainpy_array(jnp.greater) +greater_equal = _compatible_with_brainpy_array(jnp.greater_equal) +less = _compatible_with_brainpy_array(jnp.less) +less_equal = _compatible_with_brainpy_array(jnp.less_equal) +array_equal = _compatible_with_brainpy_array(jnp.array_equal) +isclose = _compatible_with_brainpy_array(jnp.isclose) +allclose = _compatible_with_brainpy_array(jnp.allclose) +logical_not = _compatible_with_brainpy_array(jnp.logical_not) +logical_and = _compatible_with_brainpy_array(jnp.logical_and) +logical_or = _compatible_with_brainpy_array(jnp.logical_or) +logical_xor = _compatible_with_brainpy_array(jnp.logical_xor) +all = _compatible_with_brainpy_array(jnp.all) +any = _compatible_with_brainpy_array(jnp.any) + +alltrue = all +sometrue = any + + +def shape(a): + """ + Return the shape of an array. + + Parameters:: + + a : array_like + Input array. 
+ + Returns:: + + shape : tuple of ints + The elements of the shape tuple give the lengths of the + corresponding array dimensions. + + See Also:: + + len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with + ``N>=1``. + ndarray.shape : Equivalent array method. + + Examples:: + + >>> import brainpy + >>> brainpy.math.shape(brainpy.math.eye(3)) + (3, 3) + >>> brainpy.math.shape([[1, 3]]) + (1, 2) + >>> brainpy.math.shape([0]) + (1,) + >>> brainpy.math.shape(0) + () + + """ + if isinstance(a, (Array, jax.Array, np.ndarray)): + return a.shape + else: + return np.shape(a) + + +def size(a, axis=None): + """ + Return the number of elements along a given axis. + + Parameters:: + + a : array_like + Input data. + axis : int, optional + Axis along which the elements are counted. By default, give + the total number of elements. + + Returns:: + + element_count : int + Number of elements along the specified axis. + + See Also:: + + shape : dimensions of array + Array.shape : dimensions of array + Array.size : number of elements in array + + Examples:: + + >>> import brainpy + >>> a = brainpy.math.array([[1,2,3], [4,5,6]]) + >>> brainpy.math.size(a) + 6 + >>> brainpy.math.size(a, 1) + 3 + >>> brainpy.math.size(a, 0) + 2 + """ + if isinstance(a, (Array, jax.Array, np.ndarray)): + if axis is None: + return a.size + else: + return a.shape[axis] + else: + return np.size(a, axis=axis) + + +reshape = _compatible_with_brainpy_array(jnp.reshape) +ravel = _compatible_with_brainpy_array(jnp.ravel) +moveaxis = _compatible_with_brainpy_array(jnp.moveaxis) +transpose = _compatible_with_brainpy_array(jnp.transpose) +swapaxes = _compatible_with_brainpy_array(jnp.swapaxes) +concatenate = _compatible_with_brainpy_array(jnp.concatenate) +stack = _compatible_with_brainpy_array(jnp.stack) +vstack = _compatible_with_brainpy_array(jnp.vstack) +product = prod +row_stack = vstack +hstack = _compatible_with_brainpy_array(jnp.hstack) +dstack = _compatible_with_brainpy_array(jnp.dstack) 
+column_stack = _compatible_with_brainpy_array(jnp.column_stack) +split = _compatible_with_brainpy_array(jnp.split) +dsplit = _compatible_with_brainpy_array(jnp.dsplit) +hsplit = _compatible_with_brainpy_array(jnp.hsplit) +vsplit = _compatible_with_brainpy_array(jnp.vsplit) +tile = _compatible_with_brainpy_array(jnp.tile) +repeat = _compatible_with_brainpy_array(jnp.repeat) +unique = _compatible_with_brainpy_array(jnp.unique) +append = _compatible_with_brainpy_array(jnp.append) +flip = _compatible_with_brainpy_array(jnp.flip) +fliplr = _compatible_with_brainpy_array(jnp.fliplr) +flipud = _compatible_with_brainpy_array(jnp.flipud) +roll = _compatible_with_brainpy_array(jnp.roll) +atleast_1d = _compatible_with_brainpy_array(jnp.atleast_1d) +atleast_2d = _compatible_with_brainpy_array(jnp.atleast_2d) +atleast_3d = _compatible_with_brainpy_array(jnp.atleast_3d) +expand_dims = _compatible_with_brainpy_array(jnp.expand_dims) +squeeze = _compatible_with_brainpy_array(jnp.squeeze) +sort = _compatible_with_brainpy_array(jnp.sort) +argsort = _compatible_with_brainpy_array(jnp.argsort) +argmax = _compatible_with_brainpy_array(jnp.argmax) +argmin = _compatible_with_brainpy_array(jnp.argmin) +argwhere = _compatible_with_brainpy_array(jnp.argwhere) +nonzero = _compatible_with_brainpy_array(jnp.nonzero) +flatnonzero = _compatible_with_brainpy_array(jnp.flatnonzero) +where = _compatible_with_brainpy_array(jnp.where) +searchsorted = _compatible_with_brainpy_array(jnp.searchsorted) +extract = _compatible_with_brainpy_array(jnp.extract) +count_nonzero = _compatible_with_brainpy_array(jnp.count_nonzero) +max = _compatible_with_brainpy_array(jnp.max) + +min = _compatible_with_brainpy_array(jnp.min) + +amax = max +amin = min +apply_along_axis = _compatible_with_brainpy_array(jnp.apply_along_axis) +apply_over_axes = _compatible_with_brainpy_array(jnp.apply_over_axes) +array_equiv = _compatible_with_brainpy_array(jnp.array_equiv) +array_repr = 
_compatible_with_brainpy_array(jnp.array_repr) +array_str = _compatible_with_brainpy_array(jnp.array_str) +array_split = _compatible_with_brainpy_array(jnp.array_split) + +# indexing funcs +# -------------- + +tril_indices = jnp.tril_indices +triu_indices = jnp.triu_indices +tril_indices_from = _compatible_with_brainpy_array(jnp.tril_indices_from) +triu_indices_from = _compatible_with_brainpy_array(jnp.triu_indices_from) +take = _compatible_with_brainpy_array(jnp.take) +select = _compatible_with_brainpy_array(jnp.select) +nanmin = _compatible_with_brainpy_array(jnp.nanmin) +nanmax = _compatible_with_brainpy_array(jnp.nanmax) +ptp = _compatible_with_brainpy_array(jnp.ptp) +percentile = _compatible_with_brainpy_array(jnp.percentile) +nanpercentile = _compatible_with_brainpy_array(jnp.nanpercentile) +quantile = _compatible_with_brainpy_array(jnp.quantile) +nanquantile = _compatible_with_brainpy_array(jnp.nanquantile) +average = _compatible_with_brainpy_array(jnp.average) +mean = _compatible_with_brainpy_array(jnp.mean) +std = _compatible_with_brainpy_array(jnp.std) +var = _compatible_with_brainpy_array(jnp.var) +nanmedian = _compatible_with_brainpy_array(jnp.nanmedian) +nanmean = _compatible_with_brainpy_array(jnp.nanmean) +nanstd = _compatible_with_brainpy_array(jnp.nanstd) +nanvar = _compatible_with_brainpy_array(jnp.nanvar) +corrcoef = _compatible_with_brainpy_array(jnp.corrcoef) +correlate = _compatible_with_brainpy_array(jnp.correlate) +cov = _compatible_with_brainpy_array(jnp.cov) +histogram = _compatible_with_brainpy_array(jnp.histogram) +bincount = _compatible_with_brainpy_array(jnp.bincount) +digitize = _compatible_with_brainpy_array(jnp.digitize) +bartlett = _compatible_with_brainpy_array(jnp.bartlett) +blackman = _compatible_with_brainpy_array(jnp.blackman) +hamming = _compatible_with_brainpy_array(jnp.hamming) +hanning = _compatible_with_brainpy_array(jnp.hanning) +kaiser = _compatible_with_brainpy_array(jnp.kaiser) + +# constants +# --------- + +e = jnp.e 
+pi = jnp.pi +inf = jnp.inf + +# linear algebra +# -------------- + +dot = _compatible_with_brainpy_array(jnp.dot) +vdot = _compatible_with_brainpy_array(jnp.vdot) +inner = _compatible_with_brainpy_array(jnp.inner) +outer = _compatible_with_brainpy_array(jnp.outer) +kron = _compatible_with_brainpy_array(jnp.kron) +matmul = _compatible_with_brainpy_array(jnp.matmul) +trace = _compatible_with_brainpy_array(jnp.trace) + +dtype = jnp.dtype +finfo = jnp.finfo +iinfo = jnp.iinfo + +can_cast = _compatible_with_brainpy_array(jnp.can_cast) +choose = _compatible_with_brainpy_array(jnp.choose) +copy = _compatible_with_brainpy_array(jnp.copy) +frombuffer = _compatible_with_brainpy_array(jnp.frombuffer) +fromfile = _compatible_with_brainpy_array(jnp.fromfile) +fromfunction = _compatible_with_brainpy_array(jnp.fromfunction) +fromiter = _compatible_with_brainpy_array(jnp.fromiter) +fromstring = _compatible_with_brainpy_array(jnp.fromstring) +get_printoptions = np.get_printoptions +iscomplexobj = _compatible_with_brainpy_array(jnp.iscomplexobj) +isneginf = _compatible_with_brainpy_array(jnp.isneginf) +isposinf = _compatible_with_brainpy_array(jnp.isposinf) +isrealobj = _compatible_with_brainpy_array(jnp.isrealobj) +issubdtype = jnp.issubdtype +issubsctype = jnp.issubdtype +iterable = _compatible_with_brainpy_array(jnp.iterable) +packbits = _compatible_with_brainpy_array(jnp.packbits) +piecewise = _compatible_with_brainpy_array(jnp.piecewise) +printoptions = np.printoptions +set_printoptions = np.set_printoptions +promote_types = _compatible_with_brainpy_array(jnp.promote_types) +ravel_multi_index = _compatible_with_brainpy_array(jnp.ravel_multi_index) +result_type = _compatible_with_brainpy_array(jnp.result_type) +sort_complex = _compatible_with_brainpy_array(jnp.sort_complex) +unpackbits = _compatible_with_brainpy_array(jnp.unpackbits) + + +# Unique APIs +# ----------- + + +def asscalar(a): + return a.item() + + +array_type = [[np.half, np.single, np.double, np.longdouble], + 
[None, np.csingle, np.cdouble, np.clongdouble]] +array_precision = {np.half: 0, + np.single: 1, + np.double: 2, + np.longdouble: 3, + np.csingle: 1, + np.cdouble: 2, + np.clongdouble: 3} + + +def common_type(*arrays): + is_complex = False + precision = 0 + for a in arrays: + t = a.dtype.type + if iscomplexobj(a): + is_complex = True + if issubclass(t, jnp.integer): + p = 2 # array_precision[_nx.double] + else: + p = array_precision.get(t, None) + if p is None: + raise TypeError("can't get common type for non-numeric array") + precision = _max(precision, p) + if is_complex: + return array_type[1][precision] + else: + return array_type[0][precision] + + +genfromtxt = lambda *args, **kwargs: asarray(np.genfromtxt(*args, **kwargs)) +loadtxt = lambda *args, **kwargs: asarray(np.loadtxt(*args, **kwargs)) +info = np.info + + +def place(arr, mask, vals): + if not isinstance(arr, Array): + raise ValueError(f'Must be an instance of brainpy Array, but we got {type(arr)}') + arr[mask] = vals + + +polydiv = _compatible_with_brainpy_array(jnp.polydiv) + + +def put(a, ind, v): + if not isinstance(a, Array): + raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}') + a[ind] = v + + +def putmask(a, mask, values): + if not isinstance(a, Array): + raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}') + if a.shape != values.shape: + raise ValueError('Only support the shapes of "a" and "values" are consistent.') + a[mask] = values + + +def safe_eval(source): + return tree_map(Array, np.safe_eval(source)) + + +def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', + footer='', comments='# ', encoding=None): + X = as_numpy(X) + np.savetxt(fname, X, fmt=fmt, delimiter=delimiter, newline=newline, header=header, + footer=footer, comments=comments, encoding=encoding) + + +def savez_compressed(file, *args, **kwds): + args = tuple([(as_numpy(a) if isinstance(a, (jnp.ndarray, Array)) else a) for a in args]) + kwds = {k: 
(as_numpy(v) if isinstance(v, (jnp.ndarray, Array)) else v) + for k, v in kwds.items()} + np.savez_compressed(file, *args, **kwds) + + +show_config = np.show_config +typename = np.typename + + +def copyto(dst, src): + if not isinstance(dst, Array): + raise ValueError('dst must be an instance of ArrayType.') + dst[:] = src + + +def matrix(data, dtype=None): + data = array(data, copy=True, dtype=dtype) + if data.ndim > 2: + raise ValueError(f'shape too large {data.shape} to be a matrix.') + if data.ndim != 2: + for i in range(2 - data.ndim): + data = expand_dims(data, 0) + return data + + +def asmatrix(data, dtype=None): + data = array(data, dtype=dtype) + if data.ndim > 2: + raise ValueError(f'shape too large {data.shape} to be a matrix.') + if data.ndim != 2: + for i in range(2 - data.ndim): + data = expand_dims(data, 0) + return data + + +def mat(data, dtype=None): + return asmatrix(data, dtype=dtype) diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py index ca258ce2..cc8b7e55 100644 --- a/brainpy/math/environment.py +++ b/brainpy/math/environment.py @@ -1,781 +1,778 @@ -# -*- coding: utf-8 -*- -# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -import functools -import gc -import inspect -import os -import re -import sys -import warnings -from typing import Any, Callable, TypeVar, cast - -import brainstate.environ -import jax -from jax import config, numpy as jnp, devices - -from . import modes -from . import scales -from .defaults import defaults -from .object_transform import naming - -__all__ = [ - # context manage for environment setting - 'environment', - 'batching_environment', - 'training_environment', - 'set_environment', - 'set', - - # default data types - 'set_float', 'get_float', - 'set_int', 'get_int', - 'set_bool', 'get_bool', - 'set_complex', 'get_complex', - - # default numerical integration step - 'set_dt', 'get_dt', - - # default computation modes - 'set_mode', 'get_mode', - - # default membrane_scaling - 'set_membrane_scaling', 'get_membrane_scaling', - - # set jax environments - 'enable_x64', 'disable_x64', - 'set_platform', 'get_platform', - 'set_host_device_count', - - # device memory - 'clear_buffer_memory', - 'enable_gpu_memory_preallocation', - 'disable_gpu_memory_preallocation', - - # deprecated - 'ditype', - 'dftype', - -] - -# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators -FuncType = Callable[..., Any] -F = TypeVar('F', bound=FuncType) - - -class _DecoratorContextManager: - """Allow a context manager to be used as a decorator""" - - def __call__(self, func: F) -> F: - if inspect.isgeneratorfunction(func): - return self._wrap_generator(func) - - @functools.wraps(func) - def decorate_context(*args, **kwargs): - with self.clone(): - return func(*args, **kwargs) - - return cast(F, decorate_context) - - def _wrap_generator(self, func): - """Wrap each generator invocation with the context manager""" - - @functools.wraps(func) - def generator_context(*args, **kwargs): - gen = func(*args, **kwargs) - - # Generators are suspended and unsuspended at `yield`, hence we - # make 
sure the grad modes is properly set every time the execution - # flow returns into the wrapped generator and restored when it - # returns through our `yield` to our caller (see PR #49017). - try: - # Issuing `None` to a generator fires it up - with self.clone(): - response = gen.send(None) - - while True: - try: - # Forward the response to our caller and get its next request - request = yield response - except GeneratorExit: - # Inform the still active generator about its imminent closure - with self.clone(): - gen.close() - raise - except BaseException: - # Propagate the exception thrown at us by the caller - with self.clone(): - response = gen.throw(*sys.exc_info()) - else: - # Pass the last request to the generator and get its response - with self.clone(): - response = gen.send(request) - - # We let the exceptions raised above by the generator's `.throw` or - # `.send` methods bubble up to our caller, except for StopIteration - except StopIteration as e: - # The generator informed us that it is done: take whatever its - # returned value (if any) was and indicate that we're done too - # by returning it (see docs for python's return-statement). - return e.value - - return generator_context - - def __enter__(self) -> None: - raise NotImplementedError - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - raise NotImplementedError - - def clone(self): - # override this method if your children class takes __init__ parameters - return self.__class__() - - -class environment(_DecoratorContextManager): - r"""Context-manager that sets a computing environment for brain dynamics computation. - - In BrainPy, there are several basic computation settings when constructing models, - including ``mode`` for controlling model computing behavior, ``dt`` for numerical - integration, ``int_`` for integer precision, and ``float_`` for floating precision. - :py:class:`~.environment`` provides a context for model construction and - computation. 
In this temporal environment, models are constructed with the given - ``mode``, ``dt``, ``int_``, etc., environment settings. - - For instance:: - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> - >>> with bm.environment(mode=bm.training_mode, dt=0.1): - >>> lif1 = bp.neurons.LIF(1) - >>> - >>> with bm.environment(mode=bm.nonbatching_mode, dt=0.05, float_=bm.float64): - >>> lif2 = bp.neurons.LIF(1) - - """ - - def __init__( - self, - mode: modes.Mode = None, - membrane_scaling: scales.Scaling = None, - dt: float = None, - x64: bool = None, - complex_: type = None, - float_: type = None, - int_: type = None, - bool_: type = None, - bp_object_as_pytree: bool = None, - numpy_func_return: str = None, - ) -> None: - super().__init__() - - if dt is not None: - assert isinstance(dt, float), '"dt" must a float.' - self.old_dt = get_dt() - - if mode is not None: - assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.' - self.old_mode = get_mode() - - if membrane_scaling is not None: - assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.' - self.old_membrane_scaling = get_membrane_scaling() - - if x64 is not None: - assert isinstance(x64, bool), f'"x64" must be a bool.' - self.old_x64 = config.read("jax_enable_x64") - - if float_ is not None: - assert isinstance(float_, type), '"float_" must a float.' - self.old_float = get_float() - - if int_ is not None: - assert isinstance(int_, type), '"int_" must a type.' - self.old_int = get_int() - - if bool_ is not None: - assert isinstance(bool_, type), '"bool_" must a type.' - self.old_bool = get_bool() - - if complex_ is not None: - assert isinstance(complex_, type), '"complex_" must a type.' - self.old_complex = get_complex() - - if bp_object_as_pytree is not None: - assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.' 
- self.old_bp_object_as_pytree = defaults.bp_object_as_pytree - - if numpy_func_return is not None: - assert isinstance(numpy_func_return, str), '"numpy_func_return" must be a string.' - assert numpy_func_return in ['bp_array', 'jax_array'], \ - f'"numpy_func_return" must be "bp_array" or "jax_array". Got {numpy_func_return}.' - self.old_numpy_func_return = defaults.numpy_func_return - - self.dt = dt - self.mode = mode - self.membrane_scaling = membrane_scaling - self.x64 = x64 - self.complex_ = complex_ - self.float_ = float_ - self.int_ = int_ - self.bool_ = bool_ - self.bp_object_as_pytree = bp_object_as_pytree - self.numpy_func_return = numpy_func_return - - def __enter__(self) -> 'environment': - if self.dt is not None: set_dt(self.dt) - if self.mode is not None: set_mode(self.mode) - if self.membrane_scaling is not None: set_membrane_scaling(self.membrane_scaling) - if self.x64 is not None: set_x64(self.x64) - if self.float_ is not None: set_float(self.float_) - if self.int_ is not None: set_int(self.int_) - if self.complex_ is not None: set_complex(self.complex_) - if self.bool_ is not None: set_bool(self.bool_) - if self.bp_object_as_pytree is not None: defaults.bp_object_as_pytree = self.bp_object_as_pytree - if self.numpy_func_return is not None: defaults.numpy_func_return = self.numpy_func_return - return self - - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - if self.dt is not None: set_dt(self.old_dt) - if self.mode is not None: set_mode(self.old_mode) - if self.membrane_scaling is not None: set_membrane_scaling(self.old_membrane_scaling) - if self.x64 is not None: set_x64(self.old_x64) - if self.int_ is not None: set_int(self.old_int) - if self.float_ is not None: set_float(self.old_float) - if self.complex_ is not None: set_complex(self.old_complex) - if self.bool_ is not None: set_bool(self.old_bool) - if self.bp_object_as_pytree is not None: defaults.bp_object_as_pytree = self.old_bp_object_as_pytree - if 
self.numpy_func_return is not None: defaults.numpy_func_return = self.old_numpy_func_return - - def clone(self): - return self.__class__(dt=self.dt, - mode=self.mode, - membrane_scaling=self.membrane_scaling, - x64=self.x64, - bool_=self.bool_, - complex_=self.complex_, - float_=self.float_, - int_=self.int_, - bp_object_as_pytree=self.bp_object_as_pytree, - numpy_func_return=self.numpy_func_return) - - def __eq__(self, other): - return id(self) == id(other) - - -class training_environment(environment): - """Environment with the training mode. - - This is a short-cut context setting for an environment with the training mode. - It is equivalent to:: - - >>> import brainpy.math as bm - >>> with bm.environment(mode=bm.training_mode): - >>> pass - - """ - - def __init__( - self, - dt: float = None, - x64: bool = None, - complex_: type = None, - float_: type = None, - int_: type = None, - bool_: type = None, - batch_size: int = 1, - membrane_scaling: scales.Scaling = None, - bp_object_as_pytree: bool = None, - numpy_func_return: str = None, - ): - super().__init__(dt=dt, - x64=x64, - complex_=complex_, - float_=float_, - int_=int_, - bool_=bool_, - membrane_scaling=membrane_scaling, - mode=modes.TrainingMode(batch_size), - bp_object_as_pytree=bp_object_as_pytree, - numpy_func_return=numpy_func_return) - - -class batching_environment(environment): - """Environment with the batching mode. - - This is a short-cut context setting for an environment with the batching mode. 
- It is equivalent to:: - - >>> import brainpy.math as bm - >>> with bm.environment(mode=bm.batching_mode): - >>> pass - - - """ - - def __init__( - self, - dt: float = None, - x64: bool = None, - complex_: type = None, - float_: type = None, - int_: type = None, - bool_: type = None, - batch_size: int = 1, - membrane_scaling: scales.Scaling = None, - bp_object_as_pytree: bool = None, - numpy_func_return: str = None, - ): - super().__init__(dt=dt, - x64=x64, - complex_=complex_, - float_=float_, - int_=int_, - bool_=bool_, - mode=modes.BatchingMode(batch_size), - membrane_scaling=membrane_scaling, - bp_object_as_pytree=bp_object_as_pytree, - numpy_func_return=numpy_func_return) - - -def set( - mode: modes.Mode = None, - membrane_scaling: scales.Scaling = None, - dt: float = None, - x64: bool = None, - complex_: type = None, - float_: type = None, - int_: type = None, - bool_: type = None, - bp_object_as_pytree: bool = None, - numpy_func_return: str = None, -): - """Set the default computation environment. - - Parameters:: - - mode: Mode - The computing mode. - membrane_scaling: Scaling - The numerical membrane_scaling. - dt: float - The numerical integration precision. - x64: bool - Enable x64 computation. - complex_: type - The complex data type. - float_ - The floating data type. - int_ - The integer data type. - bool_ - The bool data type. - bp_object_as_pytree: bool - Whether to register brainpy object as pytree. - numpy_func_return: str - The array to return in all numpy functions. Support 'bp_array' and 'jax_array'. - """ - if dt is not None: - assert isinstance(dt, float), '"dt" must a float.' - set_dt(dt) - - if mode is not None: - assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.' - set_mode(mode) - - if membrane_scaling is not None: - assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.' 
- set_membrane_scaling(membrane_scaling) - - if x64 is not None: - assert isinstance(x64, bool), f'"x64" must be a bool.' - set_x64(x64) - - if float_ is not None: - assert isinstance(float_, type), '"float_" must a float.' - set_float(float_) - - if int_ is not None: - assert isinstance(int_, type), '"int_" must a type.' - set_int(int_) - - if bool_ is not None: - assert isinstance(bool_, type), '"bool_" must a type.' - set_bool(bool_) - - if complex_ is not None: - assert isinstance(complex_, type), '"complex_" must a type.' - set_complex(complex_) - - if bp_object_as_pytree is not None: - defaults.bp_object_as_pytree = bp_object_as_pytree - - if numpy_func_return is not None: - assert numpy_func_return in ['bp_array', 'jax_array'], f'"numpy_func_return" must be "bp_array" or "jax_array".' - defaults.numpy_func_return = numpy_func_return - - -set_environment = set - - -# default dtype -# -------------------------- - - -def ditype(): - """Default int type. - - .. deprecated:: 2.3.1 - Use `brainpy.math.int_` instead. - """ - # raise errors.NoLongerSupportError('\nGet default integer data type through `ditype()` has been deprecated. \n' - # 'Use `brainpy.math.int_` instead.') - return defaults.int_ - - -def dftype(): - """Default float type. - - .. deprecated:: 2.3.1 - Use `brainpy.math.float_` instead. - """ - - # raise errors.NoLongerSupportError('\nGet default floating data type through `dftype()` has been deprecated. \n' - # 'Use `brainpy.math.float_` instead.') - return defaults.float_ - - -def set_float(dtype: type): - """Set global default float type. - - Parameters:: - - dtype: type - The float type. - """ - defaults.float_ = dtype - - -def get_float(): - """Get the default float data type. - - Returns:: - - dftype: type - The default float data type. - """ - return defaults.float_ - - -def set_int(dtype: type): - """Set global default integer type. - - Parameters:: - - dtype: type - The integer type. 
- """ - defaults.int_ = dtype - - -def get_int(): - """Get the default int data type. - - Returns:: - - dftype: type - The default int data type. - """ - return defaults.int_ - - -def set_bool(dtype: type): - """Set global default boolean type. - - Parameters:: - - dtype: type - The bool type. - """ - defaults.bool_ = dtype - - -def get_bool(): - """Get the default boolean data type. - - Returns:: - - dftype: type - The default bool data type. - """ - return defaults.bool_ - - -def set_complex(dtype: type): - """Set global default complex type. - - Parameters:: - - dtype: type - The complex type. - """ - defaults.complex_ = dtype - - -def get_complex(): - """Get the default complex data type. - - Returns:: - - dftype: type - The default complex data type. - """ - return defaults.complex_ - - -# numerical precision -# -------------------------- - -def set_dt(dt): - """Set the default numerical integrator precision. - - Parameters:: - - dt : float - Numerical integration precision. - """ - assert isinstance(dt, float), f'"dt" must a float, but we got {dt}' - defaults.dt = dt - - -def get_dt(): - """Get the numerical integrator precision. - - Returns:: - - dt : float - Numerical integration precision. - """ - return defaults.dt - - -def set_mode(mode: modes.Mode): - """Set the default computing mode. - - Parameters:: - - mode: Mode - The instance of :py:class:`~.Mode`. - """ - if not isinstance(mode, modes.Mode): - raise TypeError(f'Must be instance of brainpy.math.Mode. ' - f'But we got {type(mode)}: {mode}') - defaults.mode = mode - - -def get_mode() -> modes.Mode: - """Get the default computing mode. - - References:: - - mode: Mode - The default computing mode. - """ - return defaults.mode - - -def set_membrane_scaling(membrane_scaling: scales.Scaling): - """Set the default computing membrane_scaling. - - Parameters:: - - scaling: Scaling - The instance of :py:class:`~.Scaling`. 
- """ - if not isinstance(membrane_scaling, scales.Scaling): - raise TypeError(f'Must be instance of brainpy.math.Scaling. ' - f'But we got {type(membrane_scaling)}: {membrane_scaling}') - defaults.membrane_scaling = membrane_scaling - - -def get_membrane_scaling() -> scales.Scaling: - """Get the default computing membrane_scaling. - - Returns:: - - membrane_scaling: Scaling - The default computing membrane_scaling. - """ - return defaults.membrane_scaling - - -def enable_x64(x64=None): - if x64 is None: - x64 = True - else: - warnings.warn( - '\n' - 'Instead of "brainpy.math.enable_x64(True)", use "brainpy.math.enable_x64()". \n' - 'Instead of "brainpy.math.enable_x64(False)", use "brainpy.math.disable_x64()". \n', - DeprecationWarning - ) - if x64: - brainstate.environ.set(precision=64) - set_int(jnp.int64) - set_float(jnp.float64) - set_complex(jnp.complex128) - else: - brainstate.environ.set(precision=32) - disable_x64() - - -def disable_x64(): - config.update("jax_enable_x64", False) - set_int(jnp.int32) - set_float(jnp.float32) - set_complex(jnp.complex64) - - -def set_x64(enable: bool): - assert isinstance(enable, bool) - if enable: - enable_x64() - else: - disable_x64() - - -def set_platform(platform: str): - """ - Changes platform to CPU, GPU, or TPU. This utility only takes - effect at the beginning of your program. - """ - assert platform in ['cpu', 'gpu', 'tpu'] - config.update("jax_platform_name", platform) - - -def get_platform() -> str: - """Get the computing platform. - - Returns:: - - platform: str - Either 'cpu', 'gpu' or 'tpu'. - """ - return devices()[0].platform - - -def set_host_device_count(n): - """ - By default, XLA considers all CPU cores as one device. This utility tells XLA - that there are `n` host (CPU) devices available to use. As a consequence, this - allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform. - - .. note:: This utility only takes effect at the beginning of your program. 
- Under the hood, this sets the environment variable - `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where - `[num_device]` is the desired number of CPU devices `n`. - - .. warning:: Our understanding of the side effects of using the - `xla_force_host_platform_device_count` flag in XLA is incomplete. If you - observe some strange phenomenon when using this utility, please let us - know through our issue or forum page. More information is available in this - `JAX issue `_. - - :param int n: number of devices to use. - """ - xla_flags = os.getenv("XLA_FLAGS", "") - xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split() - os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags) - - -def clear_buffer_memory( - platform: str = None, - array: bool = True, - transform: bool = True, - compilation: bool = False, - object_name: bool = False, -): - """Clear all on-device buffers. - - This function will be very useful when you call models in a Python loop, - because it can clear all cached arrays, and clear device memory. - - .. warning:: - - This operation may cause errors when you use a deleted buffer. - Therefore, regenerate data always. - - Parameters:: - - platform: str - The device to clear its memory. - array: bool - Clear all buffer array. Default is True. - compilation: bool - Clear compilation cache. Default is False. - transform: bool - Clear transform cache. Default is True. - object_name: bool - Clear name cache. Default is True. 
- - """ - if jax.__version_info__ < (0, 8, 0): - from jax.lib.xla_bridge import get_backend - else: - from jax.extend.backend import get_backend - - if array: - for buf in get_backend(platform).live_buffers(): - buf.delete() - if compilation: - jax.clear_caches() - if transform: - naming.clear_stack_cache() - if object_name: - naming.clear_name_cache() - gc.collect() - - -def disable_gpu_memory_preallocation(release_memory: bool = True): - """Disable pre-allocating the GPU memory. - - This disables the preallocation behavior. JAX will instead allocate GPU memory as needed, - potentially decreasing the overall memory usage. However, this behavior is more prone to - GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory - may OOM with preallocation disabled. - - Args: - release_memory: bool. Whether we release memory during the computation. - """ - os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' - if release_memory: - os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform' - - -def enable_gpu_memory_preallocation(): - """Disable pre-allocating the GPU memory.""" - os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true' - os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR', None) - - -def gpu_memory_preallocation(percent: float): - """GPU memory allocation. - - If preallocation is enabled, this makes JAX preallocate ``percent`` of the total GPU memory, - instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts. - """ - assert 0. <= percent < 1., f'GPU memory preallocation must be in [0., 1.]. But we got {percent}.' - os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(percent) +# -*- coding: utf-8 -*- +# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import functools +import gc +import inspect +import os +import re +import sys +import warnings +from typing import Any, Callable, TypeVar, cast + +import brainstate.environ +import jax +from jax import config, numpy as jnp, devices + +from . import modes +from . import scales +from .defaults import defaults +from .object_transform import naming + +__all__ = [ + # context manage for environment setting + 'environment', + 'batching_environment', + 'training_environment', + 'set_environment', + 'set', + + # default data types + 'set_float', 'get_float', + 'set_int', 'get_int', + 'set_bool', 'get_bool', + 'set_complex', 'get_complex', + + # default numerical integration step + 'set_dt', 'get_dt', + + # default computation modes + 'set_mode', 'get_mode', + + # default membrane_scaling + 'set_membrane_scaling', 'get_membrane_scaling', + + # set jax environments + 'enable_x64', 'disable_x64', + 'set_platform', 'get_platform', + 'set_host_device_count', + + # device memory + 'clear_buffer_memory', + 'enable_gpu_memory_preallocation', + 'disable_gpu_memory_preallocation', + + # deprecated + 'ditype', + 'dftype', + +] + +# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators +FuncType = Callable[..., Any] +F = TypeVar('F', bound=FuncType) + + +class _DecoratorContextManager: + """Allow a context manager to be used as a decorator""" + + def __call__(self, func: F) -> F: + if inspect.isgeneratorfunction(func): + return self._wrap_generator(func) + + 
@functools.wraps(func) + def decorate_context(*args, **kwargs): + with self.clone(): + return func(*args, **kwargs) + + return cast(F, decorate_context) + + def _wrap_generator(self, func): + """Wrap each generator invocation with the context manager""" + + @functools.wraps(func) + def generator_context(*args, **kwargs): + gen = func(*args, **kwargs) + + # Generators are suspended and unsuspended at `yield`, hence we + # make sure the grad modes is properly set every time the execution + # flow returns into the wrapped generator and restored when it + # returns through our `yield` to our caller (see PR #49017). + try: + # Issuing `None` to a generator fires it up + with self.clone(): + response = gen.send(None) + + while True: + try: + # Forward the response to our caller and get its next request + request = yield response + except GeneratorExit: + # Inform the still active generator about its imminent closure + with self.clone(): + gen.close() + raise + except BaseException: + # Propagate the exception thrown at us by the caller + with self.clone(): + response = gen.throw(*sys.exc_info()) + else: + # Pass the last request to the generator and get its response + with self.clone(): + response = gen.send(request) + + # We let the exceptions raised above by the generator's `.throw` or + # `.send` methods bubble up to our caller, except for StopIteration + except StopIteration as e: + # The generator informed us that it is done: take whatever its + # returned value (if any) was and indicate that we're done too + # by returning it (see docs for python's return-statement). 
+ return e.value + + return generator_context + + def __enter__(self) -> None: + raise NotImplementedError + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + raise NotImplementedError + + def clone(self): + # override this method if your children class takes __init__ parameters + return self.__class__() + + +class environment(_DecoratorContextManager): + r"""Context-manager that sets a computing environment for brain dynamics computation. + + In BrainPy, there are several basic computation settings when constructing models, + including ``mode`` for controlling model computing behavior, ``dt`` for numerical + integration, ``int_`` for integer precision, and ``float_`` for floating precision. + :py:class:`~.environment`` provides a context for model construction and + computation. In this temporal environment, models are constructed with the given + ``mode``, ``dt``, ``int_``, etc., environment settings. + + For instance:: + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> + >>> with bm.environment(mode=bm.training_mode, dt=0.1): + >>> lif1 = bp.neurons.LIF(1) + >>> + >>> with bm.environment(mode=bm.nonbatching_mode, dt=0.05, float_=bm.float64): + >>> lif2 = bp.neurons.LIF(1) + + """ + + def __init__( + self, + mode: modes.Mode = None, + membrane_scaling: scales.Scaling = None, + dt: float = None, + x64: bool = None, + complex_: type = None, + float_: type = None, + int_: type = None, + bool_: type = None, + bp_object_as_pytree: bool = None, + numpy_func_return: str = None, + ) -> None: + super().__init__() + + if dt is not None: + assert isinstance(dt, float), '"dt" must a float.' + self.old_dt = get_dt() + + if mode is not None: + assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.' + self.old_mode = get_mode() + + if membrane_scaling is not None: + assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.' 
+ self.old_membrane_scaling = get_membrane_scaling() + + if x64 is not None: + assert isinstance(x64, bool), f'"x64" must be a bool.' + self.old_x64 = config.read("jax_enable_x64") + + if float_ is not None: + assert isinstance(float_, type), '"float_" must a float.' + self.old_float = get_float() + + if int_ is not None: + assert isinstance(int_, type), '"int_" must a type.' + self.old_int = get_int() + + if bool_ is not None: + assert isinstance(bool_, type), '"bool_" must a type.' + self.old_bool = get_bool() + + if complex_ is not None: + assert isinstance(complex_, type), '"complex_" must a type.' + self.old_complex = get_complex() + + if bp_object_as_pytree is not None: + assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.' + self.old_bp_object_as_pytree = defaults.bp_object_as_pytree + + if numpy_func_return is not None: + assert isinstance(numpy_func_return, str), '"numpy_func_return" must be a string.' + assert numpy_func_return in ['bp_array', 'jax_array'], \ + f'"numpy_func_return" must be "bp_array" or "jax_array". Got {numpy_func_return}.' 
+ self.old_numpy_func_return = defaults.numpy_func_return + + self.dt = dt + self.mode = mode + self.membrane_scaling = membrane_scaling + self.x64 = x64 + self.complex_ = complex_ + self.float_ = float_ + self.int_ = int_ + self.bool_ = bool_ + self.bp_object_as_pytree = bp_object_as_pytree + self.numpy_func_return = numpy_func_return + + def __enter__(self) -> 'environment': + if self.dt is not None: set_dt(self.dt) + if self.mode is not None: set_mode(self.mode) + if self.membrane_scaling is not None: set_membrane_scaling(self.membrane_scaling) + if self.x64 is not None: set_x64(self.x64) + if self.float_ is not None: set_float(self.float_) + if self.int_ is not None: set_int(self.int_) + if self.complex_ is not None: set_complex(self.complex_) + if self.bool_ is not None: set_bool(self.bool_) + if self.bp_object_as_pytree is not None: defaults.bp_object_as_pytree = self.bp_object_as_pytree + if self.numpy_func_return is not None: defaults.numpy_func_return = self.numpy_func_return + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + if self.dt is not None: set_dt(self.old_dt) + if self.mode is not None: set_mode(self.old_mode) + if self.membrane_scaling is not None: set_membrane_scaling(self.old_membrane_scaling) + if self.x64 is not None: set_x64(self.old_x64) + if self.int_ is not None: set_int(self.old_int) + if self.float_ is not None: set_float(self.old_float) + if self.complex_ is not None: set_complex(self.old_complex) + if self.bool_ is not None: set_bool(self.old_bool) + if self.bp_object_as_pytree is not None: defaults.bp_object_as_pytree = self.old_bp_object_as_pytree + if self.numpy_func_return is not None: defaults.numpy_func_return = self.old_numpy_func_return + + def clone(self): + return self.__class__(dt=self.dt, + mode=self.mode, + membrane_scaling=self.membrane_scaling, + x64=self.x64, + bool_=self.bool_, + complex_=self.complex_, + float_=self.float_, + int_=self.int_, + 
bp_object_as_pytree=self.bp_object_as_pytree, + numpy_func_return=self.numpy_func_return) + + def __eq__(self, other): + return id(self) == id(other) + + +class training_environment(environment): + """Environment with the training mode. + + This is a short-cut context setting for an environment with the training mode. + It is equivalent to:: + + >>> import brainpy.math as bm + >>> with bm.environment(mode=bm.training_mode): + >>> pass + + """ + + def __init__( + self, + dt: float = None, + x64: bool = None, + complex_: type = None, + float_: type = None, + int_: type = None, + bool_: type = None, + batch_size: int = 1, + membrane_scaling: scales.Scaling = None, + bp_object_as_pytree: bool = None, + numpy_func_return: str = None, + ): + super().__init__(dt=dt, + x64=x64, + complex_=complex_, + float_=float_, + int_=int_, + bool_=bool_, + membrane_scaling=membrane_scaling, + mode=modes.TrainingMode(batch_size), + bp_object_as_pytree=bp_object_as_pytree, + numpy_func_return=numpy_func_return) + + +class batching_environment(environment): + """Environment with the batching mode. + + This is a short-cut context setting for an environment with the batching mode. 
+ It is equivalent to:: + + >>> import brainpy.math as bm + >>> with bm.environment(mode=bm.batching_mode): + >>> pass + + + """ + + def __init__( + self, + dt: float = None, + x64: bool = None, + complex_: type = None, + float_: type = None, + int_: type = None, + bool_: type = None, + batch_size: int = 1, + membrane_scaling: scales.Scaling = None, + bp_object_as_pytree: bool = None, + numpy_func_return: str = None, + ): + super().__init__(dt=dt, + x64=x64, + complex_=complex_, + float_=float_, + int_=int_, + bool_=bool_, + mode=modes.BatchingMode(batch_size), + membrane_scaling=membrane_scaling, + bp_object_as_pytree=bp_object_as_pytree, + numpy_func_return=numpy_func_return) + + +def set( + mode: modes.Mode = None, + membrane_scaling: scales.Scaling = None, + dt: float = None, + x64: bool = None, + complex_: type = None, + float_: type = None, + int_: type = None, + bool_: type = None, + bp_object_as_pytree: bool = None, + numpy_func_return: str = None, +): + """Set the default computation environment. + + Parameters:: + + mode: Mode + The computing mode. + membrane_scaling: Scaling + The numerical membrane_scaling. + dt: float + The numerical integration precision. + x64: bool + Enable x64 computation. + complex_: type + The complex data type. + float_ + The floating data type. + int_ + The integer data type. + bool_ + The bool data type. + bp_object_as_pytree: bool + Whether to register brainpy object as pytree. + numpy_func_return: str + The array to return in all numpy functions. Support 'bp_array' and 'jax_array'. + """ + if dt is not None: + assert isinstance(dt, float), '"dt" must a float.' + set_dt(dt) + + if mode is not None: + assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.' + set_mode(mode) + + if membrane_scaling is not None: + assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.' 
+ set_membrane_scaling(membrane_scaling) + + if x64 is not None: + assert isinstance(x64, bool), f'"x64" must be a bool.' + set_x64(x64) + + if float_ is not None: + assert isinstance(float_, type), '"float_" must a float.' + set_float(float_) + + if int_ is not None: + assert isinstance(int_, type), '"int_" must a type.' + set_int(int_) + + if bool_ is not None: + assert isinstance(bool_, type), '"bool_" must a type.' + set_bool(bool_) + + if complex_ is not None: + assert isinstance(complex_, type), '"complex_" must a type.' + set_complex(complex_) + + if bp_object_as_pytree is not None: + defaults.bp_object_as_pytree = bp_object_as_pytree + + if numpy_func_return is not None: + assert numpy_func_return in ['bp_array', 'jax_array'], f'"numpy_func_return" must be "bp_array" or "jax_array".' + defaults.numpy_func_return = numpy_func_return + + +set_environment = set + + +# default dtype +# -------------------------- + + +def ditype(): + """Default int type. + + .. deprecated:: 2.3.1 + Use `brainpy.math.int_` instead. + """ + # raise errors.NoLongerSupportError('\nGet default integer data type through `ditype()` has been deprecated. \n' + # 'Use `brainpy.math.int_` instead.') + return defaults.int_ + + +def dftype(): + """Default float type. + + .. deprecated:: 2.3.1 + Use `brainpy.math.float_` instead. + """ + + # raise errors.NoLongerSupportError('\nGet default floating data type through `dftype()` has been deprecated. \n' + # 'Use `brainpy.math.float_` instead.') + return defaults.float_ + + +def set_float(dtype: type): + """Set global default float type. + + Parameters:: + + dtype: type + The float type. + """ + defaults.float_ = dtype + + +def get_float(): + """Get the default float data type. + + Returns:: + + dftype: type + The default float data type. + """ + return defaults.float_ + + +def set_int(dtype: type): + """Set global default integer type. + + Parameters:: + + dtype: type + The integer type. 
+ """ + defaults.int_ = dtype + + +def get_int(): + """Get the default int data type. + + Returns:: + + dftype: type + The default int data type. + """ + return defaults.int_ + + +def set_bool(dtype: type): + """Set global default boolean type. + + Parameters:: + + dtype: type + The bool type. + """ + defaults.bool_ = dtype + + +def get_bool(): + """Get the default boolean data type. + + Returns:: + + dftype: type + The default bool data type. + """ + return defaults.bool_ + + +def set_complex(dtype: type): + """Set global default complex type. + + Parameters:: + + dtype: type + The complex type. + """ + defaults.complex_ = dtype + + +def get_complex(): + """Get the default complex data type. + + Returns:: + + dftype: type + The default complex data type. + """ + return defaults.complex_ + + +# numerical precision +# -------------------------- + +def set_dt(dt): + """Set the default numerical integrator precision. + + Parameters:: + + dt : float + Numerical integration precision. + """ + assert isinstance(dt, float), f'"dt" must a float, but we got {dt}' + defaults.dt = dt + + +def get_dt(): + """Get the numerical integrator precision. + + Returns:: + + dt : float + Numerical integration precision. + """ + return defaults.dt + + +def set_mode(mode: modes.Mode): + """Set the default computing mode. + + Parameters:: + + mode: Mode + The instance of :py:class:`~.Mode`. + """ + if not isinstance(mode, modes.Mode): + raise TypeError(f'Must be instance of brainpy.math.Mode. ' + f'But we got {type(mode)}: {mode}') + defaults.mode = mode + + +def get_mode() -> modes.Mode: + """Get the default computing mode. + + References:: + + mode: Mode + The default computing mode. + """ + return defaults.mode + + +def set_membrane_scaling(membrane_scaling: scales.Scaling): + """Set the default computing membrane_scaling. + + Parameters:: + + scaling: Scaling + The instance of :py:class:`~.Scaling`. 
+ """ + if not isinstance(membrane_scaling, scales.Scaling): + raise TypeError(f'Must be instance of brainpy.math.Scaling. ' + f'But we got {type(membrane_scaling)}: {membrane_scaling}') + defaults.membrane_scaling = membrane_scaling + + +def get_membrane_scaling() -> scales.Scaling: + """Get the default computing membrane_scaling. + + Returns:: + + membrane_scaling: Scaling + The default computing membrane_scaling. + """ + return defaults.membrane_scaling + + +def enable_x64(x64=None): + if x64 is None: + x64 = True + else: + warnings.warn( + '\n' + 'Instead of "brainpy.math.enable_x64(True)", use "brainpy.math.enable_x64()". \n' + 'Instead of "brainpy.math.enable_x64(False)", use "brainpy.math.disable_x64()". \n', + DeprecationWarning + ) + if x64: + brainstate.environ.set(precision=64) + set_int(jnp.int64) + set_float(jnp.float64) + set_complex(jnp.complex128) + else: + brainstate.environ.set(precision=32) + disable_x64() + + +def disable_x64(): + config.update("jax_enable_x64", False) + set_int(jnp.int32) + set_float(jnp.float32) + set_complex(jnp.complex64) + + +def set_x64(enable: bool): + assert isinstance(enable, bool) + if enable: + enable_x64() + else: + disable_x64() + + +def set_platform(platform: str): + """ + Changes platform to CPU, GPU, or TPU. This utility only takes + effect at the beginning of your program. + """ + assert platform in ['cpu', 'gpu', 'tpu'] + config.update("jax_platform_name", platform) + + +def get_platform() -> str: + """Get the computing platform. + + Returns:: + + platform: str + Either 'cpu', 'gpu' or 'tpu'. + """ + return devices()[0].platform + + +def set_host_device_count(n): + """ + By default, XLA considers all CPU cores as one device. This utility tells XLA + that there are `n` host (CPU) devices available to use. As a consequence, this + allows parallel mapping in JAX :func:`jax.pmap` to work in CPU platform. + + .. note:: This utility only takes effect at the beginning of your program. 
+ Under the hood, this sets the environment variable + `XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices]`, where + `[num_device]` is the desired number of CPU devices `n`. + + .. warning:: Our understanding of the side effects of using the + `xla_force_host_platform_device_count` flag in XLA is incomplete. If you + observe some strange phenomenon when using this utility, please let us + know through our issue or forum page. More information is available in this + `JAX issue `_. + + :param int n: number of devices to use. + """ + xla_flags = os.getenv("XLA_FLAGS", "") + xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split() + os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags) + + +def clear_buffer_memory( + platform: str = None, + array: bool = True, + transform: bool = True, + compilation: bool = False, + object_name: bool = False, +): + """Clear all on-device buffers. + + This function will be very useful when you call models in a Python loop, + because it can clear all cached arrays, and clear device memory. + + .. warning:: + + This operation may cause errors when you use a deleted buffer. + Therefore, regenerate data always. + + Parameters:: + + platform: str + The device to clear its memory. + array: bool + Clear all buffer array. Default is True. + compilation: bool + Clear compilation cache. Default is False. + transform: bool + Clear transform cache. Default is True. + object_name: bool + Clear name cache. Default is True. + + """ + from brainstate._compatible_import import get_backend + + if array: + for buf in get_backend(platform).live_buffers(): + buf.delete() + if compilation: + jax.clear_caches() + if transform: + naming.clear_stack_cache() + if object_name: + naming.clear_name_cache() + gc.collect() + + +def disable_gpu_memory_preallocation(release_memory: bool = True): + """Disable pre-allocating the GPU memory. 
+ + This disables the preallocation behavior. JAX will instead allocate GPU memory as needed, + potentially decreasing the overall memory usage. However, this behavior is more prone to + GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory + may OOM with preallocation disabled. + + Args: + release_memory: bool. Whether we release memory during the computation. + """ + os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' + if release_memory: + os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform' + + +def enable_gpu_memory_preallocation(): + """Disable pre-allocating the GPU memory.""" + os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true' + os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR', None) + + +def gpu_memory_preallocation(percent: float): + """GPU memory allocation. + + If preallocation is enabled, this makes JAX preallocate ``percent`` of the total GPU memory, + instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts. + """ + assert 0. <= percent < 1., f'GPU memory preallocation must be in [0., 1.]. But we got {percent}.' + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(percent) diff --git a/brainpy/math/remove_vmap.py b/brainpy/math/remove_vmap.py index 40bea94b..cf38da22 100644 --- a/brainpy/math/remove_vmap.py +++ b/brainpy/math/remove_vmap.py @@ -1,97 +1,94 @@ -# -*- coding: utf-8 -*- -# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -import jax -import jax.numpy as jnp - -if jax.__version__ >= '0.5.0': - from jax.extend.core import Primitive -else: - from jax.core import Primitive -from jax.core import ShapedArray -from jax.interpreters import batching, mlir, xla -from .ndarray import Array - -__all__ = [ - 'remove_vmap' -] - - -def remove_vmap(x, op='any'): - if isinstance(x, Array): - x = x.value - if op == 'any': - return _any_without_vmap(x) - elif op == 'all': - return _all_without_vmap(x) - else: - raise ValueError(f'Do not support type: {op}') - - -_any_no_vmap_prim = Primitive('any_no_vmap') - - -def _any_without_vmap(x): - return _any_no_vmap_prim.bind(x) - - -def _any_without_vmap_imp(x): - return jnp.any(x) - - -def _any_without_vmap_abs(x): - return ShapedArray(shape=(), dtype=jnp.bool_) - - -def _any_without_vmap_batch(x, batch_axes): - (x,) = x - return _any_without_vmap(x), batching.not_mapped - - -_any_no_vmap_prim.def_impl(_any_without_vmap_imp) -_any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs) -batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch -if hasattr(xla, "lower_fun"): - xla.register_translation(_any_no_vmap_prim, - xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True)) -mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False)) - -_all_no_vmap_prim = Primitive('all_no_vmap') - - -def _all_without_vmap(x): - return _all_no_vmap_prim.bind(x) - - -def _all_without_vmap_imp(x): - return jnp.all(x) - - -def _all_without_vmap_abs(x): - return ShapedArray(shape=(), dtype=jnp.bool_) - - -def _all_without_vmap_batch(x, batch_axes): - (x,) = x - return _all_without_vmap(x), batching.not_mapped - - -_all_no_vmap_prim.def_impl(_all_without_vmap_imp) 
-_all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs) -batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch -if hasattr(xla, "lower_fun"): - xla.register_translation(_all_no_vmap_prim, - xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True)) -mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False)) +# -*- coding: utf-8 -*- +# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import jax +import jax.numpy as jnp + +from brainstate._compatible_import import Primitive +from jax.core import ShapedArray +from jax.interpreters import batching, mlir, xla +from .ndarray import Array + +__all__ = [ + 'remove_vmap' +] + + +def remove_vmap(x, op='any'): + if isinstance(x, Array): + x = x.value + if op == 'any': + return _any_without_vmap(x) + elif op == 'all': + return _all_without_vmap(x) + else: + raise ValueError(f'Do not support type: {op}') + + +_any_no_vmap_prim = Primitive('any_no_vmap') + + +def _any_without_vmap(x): + return _any_no_vmap_prim.bind(x) + + +def _any_without_vmap_imp(x): + return jnp.any(x) + + +def _any_without_vmap_abs(x): + return ShapedArray(shape=(), dtype=jnp.bool_) + + +def _any_without_vmap_batch(x, batch_axes): + (x,) = x + return _any_without_vmap(x), batching.not_mapped + + +_any_no_vmap_prim.def_impl(_any_without_vmap_imp) +_any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs) +batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch +if hasattr(xla, "lower_fun"): + xla.register_translation(_any_no_vmap_prim, + xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True)) +mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False)) + +_all_no_vmap_prim = Primitive('all_no_vmap') + + +def _all_without_vmap(x): + return _all_no_vmap_prim.bind(x) + + +def _all_without_vmap_imp(x): + return jnp.all(x) + + +def _all_without_vmap_abs(x): + return ShapedArray(shape=(), dtype=jnp.bool_) + + +def _all_without_vmap_batch(x, batch_axes): + (x,) = x + return _all_without_vmap(x), batching.not_mapped + + +_all_no_vmap_prim.def_impl(_all_without_vmap_imp) +_all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs) +batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch +if hasattr(xla, "lower_fun"): + xla.register_translation(_all_no_vmap_prim, + 
xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True)) +mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False)) diff --git a/brainpy/math/surrogate/_one_input_new.py b/brainpy/math/surrogate/_one_input_new.py index b3680bdb..7e016150 100644 --- a/brainpy/math/surrogate/_one_input_new.py +++ b/brainpy/math/surrogate/_one_input_new.py @@ -1,1797 +1,1793 @@ -# -*- coding: utf-8 -*- -# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -from typing import Union - -import jax -import jax.numpy as jnp -import jax.scipy as sci - -if jax.__version__ >= '0.5.0': - from jax.extend.core import Primitive -else: - from jax.core import Primitive -from jax.interpreters import batching, ad, mlir - -from brainpy.math.interoperability import as_jax -from brainpy.math.ndarray import Array as Array - -__all__ = [ - 'Surrogate', - 'Sigmoid', - 'sigmoid', - 'PiecewiseQuadratic', - 'piecewise_quadratic', - 'PiecewiseExp', - 'piecewise_exp', - 'SoftSign', - 'soft_sign', - 'Arctan', - 'arctan', - 'NonzeroSignLog', - 'nonzero_sign_log', - 'ERF', - 'erf', - 'PiecewiseLeakyRelu', - 'piecewise_leaky_relu', - 'SquarewaveFourierSeries', - 'squarewave_fourier_series', - 'S2NN', - 's2nn', - 'QPseudoSpike', - 'q_pseudo_spike', - 'LeakyRelu', - 'leaky_relu', - 'LogTailedRelu', - 'log_tailed_relu', - 'ReluGrad', - 'relu_grad', - 'GaussianGrad', - 'gaussian_grad', - 'InvSquareGrad', - 'inv_square_grad', - 'MultiGaussianGrad', - 'multi_gaussian_grad', - 'SlayerGrad', - 'slayer_grad', -] - - -def _heaviside_abstract(x, dx): - return [x] - - -def _heaviside_imp(x, dx): - z = jnp.asarray(x >= 0, dtype=x.dtype) - return [z] - - -def _heaviside_batching(args, axes): - return heaviside_p.bind(*args), [axes[0]] - - -def _heaviside_jvp(primals, tangents): - x, dx = primals - tx, tdx = tangents - primal_outs = heaviside_p.bind(x, dx) - tangent_outs = [dx * tx, ] - return primal_outs, tangent_outs - - -heaviside_p = Primitive('heaviside_p') -heaviside_p.multiple_results = True -heaviside_p.def_abstract_eval(_heaviside_abstract) -heaviside_p.def_impl(_heaviside_imp) -batching.primitive_batchers[heaviside_p] = _heaviside_batching -ad.primitive_jvps[heaviside_p] = _heaviside_jvp -mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_results=True)) - - -def _is_bp_array(x): - return isinstance(x, Array) - - -def _as_jax(x): - return x.value 
if _is_bp_array(x) else x - - -class Surrogate(object): - """The base surrograte gradient function. - - To customize a surrogate gradient function, you can inherit this class and - implement the `surrogate_fun` and `surrogate_grad` methods. - - Examples:: - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import jax.numpy as jnp - - >>> class MySurrogate(bm.Surrogate): - ... def __init__(self, alpha=1.): - ... super().__init__() - ... self.alpha = alpha - ... - ... def surrogate_fun(self, x): - ... return jnp.sin(x) * self.alpha - ... - ... def surrogate_grad(self, x): - ... return jnp.cos(x) * self.alpha - - """ - - def __call__(self, x): - x = _as_jax(x) - dx = self.surrogate_grad(x) - return heaviside_p.bind(x, dx)[0] - - def __repr__(self): - return f'{self.__class__.__name__}()' - - def surrogate_fun(self, x) -> jax.Array: - """The surrogate function.""" - raise NotImplementedError - - def surrogate_grad(self, x) -> jax.Array: - """The gradient function of the surrogate function.""" - raise NotImplementedError - - -class Sigmoid(Surrogate): - """Spike function with the sigmoid-shaped surrogate gradient. - - See Also:: - - sigmoid - - """ - - def __init__(self, alpha: float = 4.): - super().__init__() - self.alpha = alpha - - def surrogate_fun(self, x): - return sci.special.expit(self.alpha * x) - - def surrogate_grad(self, x): - sgax = sci.special.expit(x * self.alpha) - dx = (1. - sgax) * sgax * self.alpha - return dx - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' - - -def sigmoid( - x: Union[jax.Array, Array], - alpha: float = 4., -): - r"""Spike function with the sigmoid-shaped surrogate gradient. - - If `origin=False`, return the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}} - - Backward function: - - .. 
math:: - - g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x) - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-2, 2, 1000) - >>> for alpha in [1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.sigmoid)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - - - Returns:: - - out: jax.Array - The spiking state. - """ - return Sigmoid(alpha=alpha)(x) - - -class PiecewiseQuadratic(Surrogate): - """Judge spiking state with a piecewise quadratic function. - - See Also:: - - piecewise_quadratic - - """ - - def __init__(self, alpha: float = 1.): - super().__init__() - self.alpha = alpha - - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < -1 / self.alpha, - 0., - jnp.where(x > 1 / self.alpha, - 1., - (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5)) - return z - - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-(self.alpha * x) ** 2 + self.alpha)) - return dx - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' - - -def piecewise_quadratic( - x: Union[jax.Array, Array], - alpha: float = 1., -): - r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - g(x) = - \begin{cases} - 0, & x < -\frac{1}{\alpha} \\ - -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\ - 1, & x > \frac{1}{\alpha} \\ - \end{cases} - - Backward function: - - .. 
math:: - - g'(x) = - \begin{cases} - 0, & |x| > \frac{1}{\alpha} \\ - -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha} - \end{cases} - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.piecewise_quadratic)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446. - .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. - .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805. - .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. - .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14. - """ - return PiecewiseQuadratic(alpha=alpha)(x) - - -class PiecewiseExp(Surrogate): - """Judge spiking state with a piecewise exponential function. 
- - See Also:: - - piecewise_exp - """ - - def __init__(self, alpha: float = 1.): - super().__init__() - self.alpha = alpha - - def surrogate_grad(self, x): - x = as_jax(x) - dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x)) - return dx - - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2) - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' - - -def piecewise_exp( - x: Union[jax.Array, Array], - alpha: float = 1., - -): - r"""Judge spiking state with a piecewise exponential function [1]_. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - g(x) = \begin{cases} - \frac{1}{2}e^{\alpha x}, & x < 0 \\ - 1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0 - \end{cases} - - Backward function: - - .. math:: - - g'(x) = \frac{\alpha}{2}e^{-\alpha |x|} - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.piecewise_exp)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. - """ - return PiecewiseExp(alpha=alpha)(x) - - -class SoftSign(Surrogate): - """Judge spiking state with a soft sign function. 
- - See Also:: - - soft_sign - """ - - def __init__(self, alpha=1.): - super().__init__() - self.alpha = alpha - - def surrogate_grad(self, x): - x = as_jax(x) - dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2 - return dx - - def surrogate_fun(self, x): - x = as_jax(x) - return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5 - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' - - -def soft_sign( - x: Union[jax.Array, Array], - alpha: float = 1., - -): - r"""Judge spiking state with a soft sign function. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1) - = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1) - - Backward function: - - .. math:: - - g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}} - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.soft_sign)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - - - Returns:: - - out: jax.Array - The spiking state. - - """ - return SoftSign(alpha=alpha)(x) - - -class Arctan(Surrogate): - """Judge spiking state with an arctan function. 
- - See Also:: - - arctan - """ - - def __init__(self, alpha=1.): - super().__init__() - self.alpha = alpha - - def surrogate_grad(self, x): - x = as_jax(x) - dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2) - return dx - - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5 - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' - - -def arctan( - x: Union[jax.Array, Array], - alpha: float = 1., - -): - r"""Judge spiking state with an arctan function. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2} - - Backward function: - - .. math:: - - g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)} - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.arctan)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - - - Returns:: - - out: jax.Array - The spiking state. - - """ - return Arctan(alpha=alpha)(x) - - -class NonzeroSignLog(Surrogate): - """Judge spiking state with a nonzero sign log function. - - See Also:: - - nonzero_sign_log - """ - - def __init__(self, alpha=1.): - super().__init__() - self.alpha = alpha - - def surrogate_grad(self, x): - x = as_jax(x) - dx = 1. / (1 / self.alpha + jnp.abs(x)) - return dx - - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.where(x < 0, -1., 1.) 
* jnp.log(jnp.abs(self.alpha * x) + 1) - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' - - -def nonzero_sign_log( - x: Union[jax.Array, Array], - alpha: float = 1., - -): - r"""Judge spiking state with a nonzero sign log function. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1) - - where - - .. math:: - - \begin{split}\mathrm{NonzeroSign}(x) = - \begin{cases} - 1, & x \geq 0 \\ - -1, & x < 0 \\ - \end{cases}\end{split} - - Backward function: - - .. math:: - - g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|} - - This surrogate function has the advantage of low computation cost during the backward. - - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - - - Returns:: - - out: jax.Array - The spiking state. - - """ - return NonzeroSignLog(alpha=alpha)(x) - - -class ERF(Surrogate): - """Judge spiking state with an erf function. 
- - See Also:: - - erf - """ - - def __init__(self, alpha=1.): - super().__init__() - self.alpha = alpha - - def surrogate_grad(self, x): - x = as_jax(x) - dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x) - return dx - - def surrogate_fun(self, x): - x = as_jax(x) - return sci.special.erf(-self.alpha * x) * 0.5 - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' - - -def erf( - x: Union[jax.Array, Array], - alpha: float = 1., - -): - r"""Judge spiking state with an erf function [1]_ [2]_ [3]_. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - \begin{split} - g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\ - &= \frac{1}{2} \text{erfc}(-\alpha x) \\ - &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt - \end{split} - - Backward function: - - .. math:: - - g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2} - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125. - .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. 
Frontiers in neuroscience, 2018, 12: 331. - .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8. - - """ - return ERF(alpha=alpha)(x) - - -class PiecewiseLeakyRelu(Surrogate): - """Judge spiking state with a piecewise leaky relu function. - - See Also:: - - piecewise_leaky_relu - """ - - def __init__(self, c=0.01, w=1.): - super().__init__() - self.c = c - self.w = w - - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < -self.w, - self.c * x + self.c * self.w, - jnp.where(x > self.w, - self.c * x - self.c * self.w + 1, - 0.5 * x / self.w + 0.5)) - return z - - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w) - return dx - - def __repr__(self): - return f'{self.__class__.__name__}(c={self.c}, w={self.w})' - - -def piecewise_leaky_relu( - x: Union[jax.Array, Array], - c: float = 0.01, - w: float = 1., - -): - r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - \begin{split}g(x) = - \begin{cases} - cx + cw, & x < -w \\ - \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\ - cx - cw + 1, & x > w \\ - \end{cases}\end{split} - - Backward function: - - .. math:: - - \begin{split}g'(x) = - \begin{cases} - \frac{1}{w}, & |x| \leq w \\ - c, & |x| > w - \end{cases}\end{split} - - .. 
plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for c in [0.01, 0.05, 0.1]: - >>> for w in [1., 2.]: - >>> grads1 = bm.vector_grad(bm.surrogate.piecewise_leaky_relu)(xs, c=c, w=w) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'x={c}, w={w}') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - c: float - When :math:`|x| > w` the gradient is `c`. - w: float - When :math:`|x| <= w` the gradient is `1 / w`. - - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5. - .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. - .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450. - .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318. - .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372. - .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58. - .. [7] Cheng X, Hao Y, Xu J, et al. 
LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525. - .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424. - - """ - return PiecewiseLeakyRelu(c=c, w=w)(x) - - -class SquarewaveFourierSeries(Surrogate): - """Judge spiking state with a squarewave fourier series. - - See Also:: - - squarewave_fourier_series - """ - - def __init__(self, n=2, t_period=8.): - super().__init__() - self.n = n - self.t_period = t_period - - def surrogate_grad(self, x): - x = as_jax(x) - w = jnp.pi * 2. / self.t_period - dx = jnp.cos(w * x) - for i in range(2, self.n): - dx += jnp.cos((2 * i - 1.) * w * x) - dx *= 4. / self.t_period - return dx - - def surrogate_fun(self, x): - x = as_jax(x) - w = jnp.pi * 2. / self.t_period - ret = jnp.sin(w * x) - for i in range(2, self.n): - c = (2 * i - 1.) - ret += jnp.sin(c * w * x) / c - z = 0.5 + 2. / jnp.pi * ret - return z - - def __repr__(self): - return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})' - - -def squarewave_fourier_series( - x: Union[jax.Array, Array], - n: int = 2, - t_period: float = 8., - -): - r"""Judge spiking state with a squarewave fourier series. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - g(x) = 0.5 + \frac{1}{\pi}*\sum_{i=1}^n {\sin\left({(2i-1)*2\pi}*x/T\right) \over 2i-1 } - - Backward function: - - .. math:: - - g'(x) = \sum_{i=1}^n\frac{4\cos\left((2 * i - 1.) * 2\pi * x / T\right)}{T} - - .. 
plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for n in [2, 4, 8]: - >>> f = bm.surrogate.SquarewaveFourierSeries(n=n) - >>> grads1 = bm.vector_grad(f)(xs) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'n={n}') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - n: int - t_period: float - - - Returns:: - - out: jax.Array - The spiking state. - - """ - - return SquarewaveFourierSeries(n=n, t_period=t_period)(x) - - -class S2NN(Surrogate): - """Judge spiking state with the S2NN surrogate spiking function. - - See Also:: - - s2nn - """ - - def __init__(self, alpha=4., beta=1., epsilon=1e-8): - super().__init__() - self.alpha = alpha - self.beta = beta - self.epsilon = epsilon - - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < 0., - sci.special.expit(x * self.alpha), - self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5) - return z - - def surrogate_grad(self, x): - x = as_jax(x) - sg = sci.special.expit(self.alpha * x) - dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.)) - return dx - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})' - - -def s2nn( - x: Union[jax.Array, Array], - alpha: float = 4., - beta: float = 1., - epsilon: float = 1e-8, - -): - r"""Judge spiking state with the S2NN surrogate spiking function [1]_. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - \begin{split}g(x) = \begin{cases} - \mathrm{sigmoid} (\alpha x), x < 0 \\ - \beta \ln(|x + 1|) + 0.5, x \ge 0 - \end{cases}\end{split} - - Backward function: - - .. 
math:: - - \begin{split}g'(x) = \begin{cases} - \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), x < 0 \\ - \frac{\beta}{(x + 1)}, x \ge 0 - \end{cases}\end{split} - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 4., 1.) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=4, \beta=1$') - >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 8., 2.) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=8, \beta=2$') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - The param that controls the gradient when ``x < 0``. - beta: float - The param that controls the gradient when ``x >= 0`` - epsilon: float - Avoid nan - - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag. - - """ - return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x) - - -class QPseudoSpike(Surrogate): - """Judge spiking state with the q-PseudoSpike surrogate function. - - See Also:: - - q_pseudo_spike - """ - - def __init__(self, alpha=2.): - super().__init__() - self.alpha = alpha - - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha) - return dx - - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x < 0., - 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha), - 1. 
- 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha)) - return z - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' - - -def q_pseudo_spike( - x: Union[jax.Array, Array], - alpha: float = 2., - -): - r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - \begin{split}g(x) = - \begin{cases} - \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\ - 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0. - \end{cases}\end{split} - - Backward function: - - .. math:: - - g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha} - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.q_pseudo_spike)(xs, alpha) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - The parameter to control tail fatness of gradient. - - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag. - """ - return QPseudoSpike(alpha=alpha)(x) - - -class LeakyRelu(Surrogate): - """Judge spiking state with the Leaky ReLU function. 
- - See Also:: - - leaky_relu - """ - - def __init__(self, alpha=0.1, beta=1.): - super().__init__() - self.alpha = alpha - self.beta = beta - - def surrogate_fun(self, x): - x = as_jax(x) - return jnp.where(x < 0., self.alpha * x, self.beta * x) - - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.where(x < 0., self.alpha, self.beta) - return dx - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})' - - -def leaky_relu( - x: Union[jax.Array, Array], - alpha: float = 0.1, - beta: float = 1., - -): - r"""Judge spiking state with the Leaky ReLU function. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - \begin{split}g(x) = - \begin{cases} - \beta \cdot x, & x \geq 0 \\ - \alpha \cdot x, & x < 0 \\ - \end{cases}\end{split} - - Backward function: - - .. math:: - - \begin{split}g'(x) = - \begin{cases} - \beta, & x \geq 0 \\ - \alpha, & x < 0 \\ - \end{cases}\end{split} - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - The parameter to control the gradient when :math:`x < 0`. - beta: float - The parameter to control the gradient when :math:`x >= 0`. - - - Returns:: - - out: jax.Array - The spiking state. - """ - return LeakyRelu(alpha=alpha, beta=beta)(x) - - -class LogTailedRelu(Surrogate): - """Judge spiking state with the Log-tailed ReLU function. 
- - See Also:: - - log_tailed_relu - """ - - def __init__(self, alpha=0.): - super().__init__() - self.alpha = alpha - - def surrogate_fun(self, x): - x = as_jax(x) - z = jnp.where(x > 1, - jnp.log(x), - jnp.where(x > 0, - x, - self.alpha * x)) - return z - - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.where(x > 1, - 1 / x, - jnp.where(x > 0, - 1., - self.alpha)) - return dx - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' - - -def log_tailed_relu( - x: Union[jax.Array, Array], - alpha: float = 0., - -): - r"""Judge spiking state with the Log-tailed ReLU function [1]_. - - If `origin=False`, computes the forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - If `origin=True`, computes the original function: - - .. math:: - - \begin{split}g(x) = - \begin{cases} - \alpha x, & x \leq 0 \\ - x, & 0 < x \leq 0 \\ - log(x), x > 1 \\ - \end{cases}\end{split} - - Backward function: - - .. math:: - - \begin{split}g'(x) = - \begin{cases} - \alpha, & x \leq 0 \\ - 1, & 0 < x \leq 0 \\ - \frac{1}{x}, x > 1 \\ - \end{cases}\end{split} - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - The parameter to control the gradient. - - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414. 
- """ - return LogTailedRelu(alpha=alpha)(x) - - -class ReluGrad(Surrogate): - """Judge spiking state with the ReLU gradient function. - - See Also:: - - relu_grad - """ - - def __init__(self, alpha=0.3, width=1.): - super().__init__() - self.alpha = alpha - self.width = width - - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0) - return dx - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})' - - -def relu_grad( - x: Union[jax.Array, Array], - alpha: float = 0.3, - width: float = 1., -): - r"""Spike function with the ReLU gradient function [1]_. - - The forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - Backward function: - - .. math:: - - g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|)) - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> for s in [0.5, 1.]: - >>> for w in [1, 2.]: - >>> grads = bm.vector_grad(bm.surrogate.relu_grad)(xs, s, w) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + f'{s}, width={w}') - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - The parameter to control the gradient. - width: float - The parameter to control the width of the gradient. - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019). - """ - return ReluGrad(alpha=alpha, width=width)(x) - - -class GaussianGrad(Surrogate): - """Judge spiking state with the Gaussian gradient function. 
- - See Also:: - - gaussian_grad - """ - - def __init__(self, sigma=0.5, alpha=0.5): - super().__init__() - self.sigma = sigma - self.alpha = alpha - - def surrogate_grad(self, x): - x = as_jax(x) - dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma) - return self.alpha * dx - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})' - - -def gaussian_grad( - x: Union[jax.Array, Array], - sigma: float = 0.5, - alpha: float = 0.5, -): - r"""Spike function with the Gaussian gradient function [1]_. - - The forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - Backward function: - - .. math:: - - g'(x) = \alpha * \text{gaussian}(x, 0., \sigma) - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> for s in [0.5, 1., 2.]: - >>> grads = bm.vector_grad(bm.surrogate.gaussian_grad)(xs, s, 0.5) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0.5, \sigma=$' + str(s)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - sigma: float - The parameter to control the variance of gaussian distribution. - alpha: float - The parameter to control the scale of the gradient. - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). - """ - return GaussianGrad(sigma=sigma, alpha=alpha)(x) - - -class MultiGaussianGrad(Surrogate): - """Judge spiking state with the multi-Gaussian gradient function. 
- - See Also:: - - multi_gaussian_grad - """ - - def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5): - super().__init__() - self.h = h - self.s = s - self.sigma = sigma - self.scale = scale - - def surrogate_grad(self, x): - x = as_jax(x) - g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma) - g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) - ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) - g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) - ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) - dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h - return self.scale * dx - - def __repr__(self): - return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})' - - -def multi_gaussian_grad( - x: Union[jax.Array, Array], - h: float = 0.15, - s: float = 6.0, - sigma: float = 0.5, - scale: float = 0.5, -): - r"""Spike function with the multi-Gaussian gradient function [1]_. - - The forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - Backward function: - - .. math:: - - \begin{array}{l} - g'(x)=(1+h){{{\mathcal{N}}}}(x, 0, {\sigma }^{2}) - -h{{{\mathcal{N}}}}(x, \sigma,{(s\sigma )}^{2})- - h{{{\mathcal{N}}}}(x, -\sigma ,{(s\sigma )}^{2}) - \end{array} - - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-3, 3, 1000) - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> grads = bm.vector_grad(bm.surrogate.multi_gaussian_grad)(xs) - >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads)) - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - h: float - The hyper-parameters of approximate function - s: float - The hyper-parameters of approximate function - sigma: float - The gaussian sigma. - scale: float - The gradient scale. 
- - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). - """ - return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x) - - -class InvSquareGrad(Surrogate): - """Judge spiking state with the inverse-square surrogate gradient function. - - See Also:: - - inv_square_grad - """ - - def __init__(self, alpha=100.): - super().__init__() - self.alpha = alpha - - def surrogate_grad(self, x): - dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2 - return dx - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' - - -def inv_square_grad( - x: Union[jax.Array, Array], - alpha: float = 100. -): - r"""Spike function with the inverse-square surrogate gradient. - - Forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - Backward function: - - .. math:: - - g'(x) = \frac{1}{(\alpha * |x| + 1.) ^ 2} - - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> xs = bm.linspace(-1, 1, 1000) - >>> for alpha in [1., 10., 100.]: - >>> grads = bm.vector_grad(bm.surrogate.inv_square_grad)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - - Returns:: - - out: jax.Array - The spiking state. - """ - return InvSquareGrad(alpha=alpha)(x) - - -class SlayerGrad(Surrogate): - """Judge spiking state with the slayer surrogate gradient function. 
- - See Also:: - - slayer_grad - """ - - def __init__(self, alpha=1.): - super().__init__() - self.alpha = alpha - - def surrogate_grad(self, x): - dx = jnp.exp(-self.alpha * jnp.abs(x)) - return dx - - def __repr__(self): - return f'{self.__class__.__name__}(alpha={self.alpha})' - - -def slayer_grad( - x: Union[jax.Array, Array], - alpha: float = 1. -): - r"""Spike function with the slayer surrogate gradient function. - - Forward function: - - .. math:: - - g(x) = \begin{cases} - 1, & x \geq 0 \\ - 0, & x < 0 \\ - \end{cases} - - Backward function: - - .. math:: - - g'(x) = \exp(-\alpha |x|) - - - .. plot:: - :include-source: True - - >>> import brainpy as bp - >>> import brainpy.math as bm - >>> import matplotlib.pyplot as plt - >>> bp.visualize.get_figure(1, 1, 4, 6) - >>> xs = bm.linspace(-3, 3, 1000) - >>> for alpha in [0.5, 1., 2., 4.]: - >>> grads = bm.vector_grad(bm.surrogate.slayer_grad)(xs, alpha) - >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) - >>> plt.legend() - >>> plt.show() - - Parameters:: - - x: jax.Array, Array - The input data. - alpha: float - Parameter to control smoothness of gradient - - Returns:: - - out: jax.Array - The spiking state. - - References:: - - .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018). - """ - return SlayerGrad(alpha=alpha)(x) +# -*- coding: utf-8 -*- +# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import Union + +import jax +import jax.numpy as jnp +import jax.scipy as sci +from brainstate._compatible_import import Primitive +from jax.interpreters import batching, ad, mlir + +from brainpy.math.interoperability import as_jax +from brainpy.math.ndarray import Array as Array + +__all__ = [ + 'Surrogate', + 'Sigmoid', + 'sigmoid', + 'PiecewiseQuadratic', + 'piecewise_quadratic', + 'PiecewiseExp', + 'piecewise_exp', + 'SoftSign', + 'soft_sign', + 'Arctan', + 'arctan', + 'NonzeroSignLog', + 'nonzero_sign_log', + 'ERF', + 'erf', + 'PiecewiseLeakyRelu', + 'piecewise_leaky_relu', + 'SquarewaveFourierSeries', + 'squarewave_fourier_series', + 'S2NN', + 's2nn', + 'QPseudoSpike', + 'q_pseudo_spike', + 'LeakyRelu', + 'leaky_relu', + 'LogTailedRelu', + 'log_tailed_relu', + 'ReluGrad', + 'relu_grad', + 'GaussianGrad', + 'gaussian_grad', + 'InvSquareGrad', + 'inv_square_grad', + 'MultiGaussianGrad', + 'multi_gaussian_grad', + 'SlayerGrad', + 'slayer_grad', +] + + +def _heaviside_abstract(x, dx): + return [x] + + +def _heaviside_imp(x, dx): + z = jnp.asarray(x >= 0, dtype=x.dtype) + return [z] + + +def _heaviside_batching(args, axes): + return heaviside_p.bind(*args), [axes[0]] + + +def _heaviside_jvp(primals, tangents): + x, dx = primals + tx, tdx = tangents + primal_outs = heaviside_p.bind(x, dx) + tangent_outs = [dx * tx, ] + return primal_outs, tangent_outs + + +heaviside_p = Primitive('heaviside_p') +heaviside_p.multiple_results = True +heaviside_p.def_abstract_eval(_heaviside_abstract) +heaviside_p.def_impl(_heaviside_imp) +batching.primitive_batchers[heaviside_p] = _heaviside_batching +ad.primitive_jvps[heaviside_p] = _heaviside_jvp +mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_results=True)) + + +def _is_bp_array(x): + return isinstance(x, 
Array) + + +def _as_jax(x): + return x.value if _is_bp_array(x) else x + + +class Surrogate(object): + """The base surrograte gradient function. + + To customize a surrogate gradient function, you can inherit this class and + implement the `surrogate_fun` and `surrogate_grad` methods. + + Examples:: + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import jax.numpy as jnp + + >>> class MySurrogate(bm.Surrogate): + ... def __init__(self, alpha=1.): + ... super().__init__() + ... self.alpha = alpha + ... + ... def surrogate_fun(self, x): + ... return jnp.sin(x) * self.alpha + ... + ... def surrogate_grad(self, x): + ... return jnp.cos(x) * self.alpha + + """ + + def __call__(self, x): + x = _as_jax(x) + dx = self.surrogate_grad(x) + return heaviside_p.bind(x, dx)[0] + + def __repr__(self): + return f'{self.__class__.__name__}()' + + def surrogate_fun(self, x) -> jax.Array: + """The surrogate function.""" + raise NotImplementedError + + def surrogate_grad(self, x) -> jax.Array: + """The gradient function of the surrogate function.""" + raise NotImplementedError + + +class Sigmoid(Surrogate): + """Spike function with the sigmoid-shaped surrogate gradient. + + See Also:: + + sigmoid + + """ + + def __init__(self, alpha: float = 4.): + super().__init__() + self.alpha = alpha + + def surrogate_fun(self, x): + return sci.special.expit(self.alpha * x) + + def surrogate_grad(self, x): + sgax = sci.special.expit(x * self.alpha) + dx = (1. - sgax) * sgax * self.alpha + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + + +def sigmoid( + x: Union[jax.Array, Array], + alpha: float = 4., +): + r"""Spike function with the sigmoid-shaped surrogate gradient. + + If `origin=False`, return the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. 
math:: + + g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}} + + Backward function: + + .. math:: + + g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x) + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-2, 2, 1000) + >>> for alpha in [1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.sigmoid)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + + + Returns:: + + out: jax.Array + The spiking state. + """ + return Sigmoid(alpha=alpha)(x) + + +class PiecewiseQuadratic(Surrogate): + """Judge spiking state with a piecewise quadratic function. + + See Also:: + + piecewise_quadratic + + """ + + def __init__(self, alpha: float = 1.): + super().__init__() + self.alpha = alpha + + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < -1 / self.alpha, + 0., + jnp.where(x > 1 / self.alpha, + 1., + (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5)) + return z + + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-(self.alpha * x) ** 2 + self.alpha)) + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + + +def piecewise_quadratic( + x: Union[jax.Array, Array], + alpha: float = 1., +): + r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. 
math:: + + g(x) = + \begin{cases} + 0, & x < -\frac{1}{\alpha} \\ + -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\ + 1, & x > \frac{1}{\alpha} \\ + \end{cases} + + Backward function: + + .. math:: + + g'(x) = + \begin{cases} + 0, & |x| > \frac{1}{\alpha} \\ + -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha} + \end{cases} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.piecewise_quadratic)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446. + .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. + .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805. + .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. + .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. 
Frontiers in Neuroscience, 2020, 14. + """ + return PiecewiseQuadratic(alpha=alpha)(x) + + +class PiecewiseExp(Surrogate): + """Judge spiking state with a piecewise exponential function. + + See Also:: + + piecewise_exp + """ + + def __init__(self, alpha: float = 1.): + super().__init__() + self.alpha = alpha + + def surrogate_grad(self, x): + x = as_jax(x) + dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x)) + return dx + + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2) + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + + +def piecewise_exp( + x: Union[jax.Array, Array], + alpha: float = 1., + +): + r"""Judge spiking state with a piecewise exponential function [1]_. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + g(x) = \begin{cases} + \frac{1}{2}e^{\alpha x}, & x < 0 \\ + 1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0 + \end{cases} + + Backward function: + + .. math:: + + g'(x) = \frac{\alpha}{2}e^{-\alpha |x|} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.piecewise_exp)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Neftci E O, Mostafa H, Zenke F. 
Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63. + """ + return PiecewiseExp(alpha=alpha)(x) + + +class SoftSign(Surrogate): + """Judge spiking state with a soft sign function. + + See Also:: + + soft_sign + """ + + def __init__(self, alpha=1.): + super().__init__() + self.alpha = alpha + + def surrogate_grad(self, x): + x = as_jax(x) + dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2 + return dx + + def surrogate_fun(self, x): + x = as_jax(x) + return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5 + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + + +def soft_sign( + x: Union[jax.Array, Array], + alpha: float = 1., + +): + r"""Judge spiking state with a soft sign function. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1) + = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1) + + Backward function: + + .. math:: + + g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.soft_sign)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + + + Returns:: + + out: jax.Array + The spiking state. 
+ + """ + return SoftSign(alpha=alpha)(x) + + +class Arctan(Surrogate): + """Judge spiking state with an arctan function. + + See Also:: + + arctan + """ + + def __init__(self, alpha=1.): + super().__init__() + self.alpha = alpha + + def surrogate_grad(self, x): + x = as_jax(x) + dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2) + return dx + + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5 + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + + +def arctan( + x: Union[jax.Array, Array], + alpha: float = 1., + +): + r"""Judge spiking state with an arctan function. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2} + + Backward function: + + .. math:: + + g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.arctan)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + + + Returns:: + + out: jax.Array + The spiking state. + + """ + return Arctan(alpha=alpha)(x) + + +class NonzeroSignLog(Surrogate): + """Judge spiking state with a nonzero sign log function. + + See Also:: + + nonzero_sign_log + """ + + def __init__(self, alpha=1.): + super().__init__() + self.alpha = alpha + + def surrogate_grad(self, x): + x = as_jax(x) + dx = 1. 
/ (1 / self.alpha + jnp.abs(x)) + return dx + + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1) + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + + +def nonzero_sign_log( + x: Union[jax.Array, Array], + alpha: float = 1., + +): + r"""Judge spiking state with a nonzero sign log function. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1) + + where + + .. math:: + + \begin{split}\mathrm{NonzeroSign}(x) = + \begin{cases} + 1, & x \geq 0 \\ + -1, & x < 0 \\ + \end{cases}\end{split} + + Backward function: + + .. math:: + + g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|} + + This surrogate function has the advantage of low computation cost during the backward. + + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.nonzero_sign_log)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + + + Returns:: + + out: jax.Array + The spiking state. + + """ + return NonzeroSignLog(alpha=alpha)(x) + + +class ERF(Surrogate): + """Judge spiking state with an erf function. 
+ + See Also:: + + erf + """ + + def __init__(self, alpha=1.): + super().__init__() + self.alpha = alpha + + def surrogate_grad(self, x): + x = as_jax(x) + dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x) + return dx + + def surrogate_fun(self, x): + x = as_jax(x) + return sci.special.erfc(-self.alpha * x) * 0.5 + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + + +def erf( + x: Union[jax.Array, Array], + alpha: float = 1., + +): + r"""Judge spiking state with an erf function [1]_ [2]_ [3]_. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + \begin{split} + g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\ + &= \frac{1}{2} \text{erfc}(-\alpha x) \\ + &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt + \end{split} + + Backward function: + + .. math:: + + g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.erf)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125. + .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J].
Frontiers in neuroscience, 2018, 12: 331. + .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8. + + """ + return ERF(alpha=alpha)(x) + + +class PiecewiseLeakyRelu(Surrogate): + """Judge spiking state with a piecewise leaky relu function. + + See Also:: + + piecewise_leaky_relu + """ + + def __init__(self, c=0.01, w=1.): + super().__init__() + self.c = c + self.w = w + + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < -self.w, + self.c * x + self.c * self.w, + jnp.where(x > self.w, + self.c * x - self.c * self.w + 1, + 0.5 * x / self.w + 0.5)) + return z + + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w) + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(c={self.c}, w={self.w})' + + +def piecewise_leaky_relu( + x: Union[jax.Array, Array], + c: float = 0.01, + w: float = 1., + +): + r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + \begin{split}g(x) = + \begin{cases} + cx + cw, & x < -w \\ + \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\ + cx - cw + 1, & x > w \\ + \end{cases}\end{split} + + Backward function: + + .. math:: + + \begin{split}g'(x) = + \begin{cases} + \frac{1}{w}, & |x| \leq w \\ + c, & |x| > w + \end{cases}\end{split} + + .. 
plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for c in [0.01, 0.05, 0.1]: + >>> for w in [1., 2.]: + >>> grads1 = bm.vector_grad(bm.surrogate.piecewise_leaky_relu)(xs, c=c, w=w) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'x={c}, w={w}') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + c: float + When :math:`|x| > w` the gradient is `c`. + w: float + When :math:`|x| <= w` the gradient is `1 / w`. + + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5. + .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331. + .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450. + .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318. + .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372. + .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58. + .. [7] Cheng X, Hao Y, Xu J, et al. 
LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525. + .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424. + + """ + return PiecewiseLeakyRelu(c=c, w=w)(x) + + +class SquarewaveFourierSeries(Surrogate): + """Judge spiking state with a squarewave fourier series. + + See Also:: + + squarewave_fourier_series + """ + + def __init__(self, n=2, t_period=8.): + super().__init__() + self.n = n + self.t_period = t_period + + def surrogate_grad(self, x): + x = as_jax(x) + w = jnp.pi * 2. / self.t_period + dx = jnp.cos(w * x) + for i in range(2, self.n): + dx += jnp.cos((2 * i - 1.) * w * x) + dx *= 4. / self.t_period + return dx + + def surrogate_fun(self, x): + x = as_jax(x) + w = jnp.pi * 2. / self.t_period + ret = jnp.sin(w * x) + for i in range(2, self.n): + c = (2 * i - 1.) + ret += jnp.sin(c * w * x) / c + z = 0.5 + 2. / jnp.pi * ret + return z + + def __repr__(self): + return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})' + + +def squarewave_fourier_series( + x: Union[jax.Array, Array], + n: int = 2, + t_period: float = 8., + +): + r"""Judge spiking state with a squarewave fourier series. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + g(x) = 0.5 + \frac{1}{\pi}*\sum_{i=1}^n {\sin\left({(2i-1)*2\pi}*x/T\right) \over 2i-1 } + + Backward function: + + .. math:: + + g'(x) = \sum_{i=1}^n\frac{4\cos\left((2 * i - 1.) * 2\pi * x / T\right)}{T} + + .. 
plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for n in [2, 4, 8]: + >>> f = bm.surrogate.SquarewaveFourierSeries(n=n) + >>> grads1 = bm.vector_grad(f)(xs) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads1), label=f'n={n}') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + n: int + t_period: float + + + Returns:: + + out: jax.Array + The spiking state. + + """ + + return SquarewaveFourierSeries(n=n, t_period=t_period)(x) + + +class S2NN(Surrogate): + """Judge spiking state with the S2NN surrogate spiking function. + + See Also:: + + s2nn + """ + + def __init__(self, alpha=4., beta=1., epsilon=1e-8): + super().__init__() + self.alpha = alpha + self.beta = beta + self.epsilon = epsilon + + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < 0., + sci.special.expit(x * self.alpha), + self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5) + return z + + def surrogate_grad(self, x): + x = as_jax(x) + sg = sci.special.expit(self.alpha * x) + dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.)) + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})' + + +def s2nn( + x: Union[jax.Array, Array], + alpha: float = 4., + beta: float = 1., + epsilon: float = 1e-8, + +): + r"""Judge spiking state with the S2NN surrogate spiking function [1]_. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + \begin{split}g(x) = \begin{cases} + \mathrm{sigmoid} (\alpha x), x < 0 \\ + \beta \ln(|x + 1|) + 0.5, x \ge 0 + \end{cases}\end{split} + + Backward function: + + .. 
math:: + + \begin{split}g'(x) = \begin{cases} + \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), x < 0 \\ + \frac{\beta}{(x + 1)}, x \ge 0 + \end{cases}\end{split} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 4., 1.) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=4, \beta=1$') + >>> grads = bm.vector_grad(bm.surrogate.s2nn)(xs, 8., 2.) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=8, \beta=2$') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + The param that controls the gradient when ``x < 0``. + beta: float + The param that controls the gradient when ``x >= 0`` + epsilon: float + Avoid nan + + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag. + + """ + return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x) + + +class QPseudoSpike(Surrogate): + """Judge spiking state with the q-PseudoSpike surrogate function. + + See Also:: + + q_pseudo_spike + """ + + def __init__(self, alpha=2.): + super().__init__() + self.alpha = alpha + + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha) + return dx + + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x < 0., + 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha), + 1. 
- 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha)) + return z + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + + +def q_pseudo_spike( + x: Union[jax.Array, Array], + alpha: float = 2., + +): + r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + \begin{split}g(x) = + \begin{cases} + \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\ + 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0. + \end{cases}\end{split} + + Backward function: + + .. math:: + + g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.q_pseudo_spike)(xs, alpha) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + The parameter to control tail fatness of gradient. + + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag. + """ + return QPseudoSpike(alpha=alpha)(x) + + +class LeakyRelu(Surrogate): + """Judge spiking state with the Leaky ReLU function. 
+ + See Also:: + + leaky_relu + """ + + def __init__(self, alpha=0.1, beta=1.): + super().__init__() + self.alpha = alpha + self.beta = beta + + def surrogate_fun(self, x): + x = as_jax(x) + return jnp.where(x < 0., self.alpha * x, self.beta * x) + + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.where(x < 0., self.alpha, self.beta) + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})' + + +def leaky_relu( + x: Union[jax.Array, Array], + alpha: float = 0.1, + beta: float = 1., + +): + r"""Judge spiking state with the Leaky ReLU function. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + \begin{split}g(x) = + \begin{cases} + \beta \cdot x, & x \geq 0 \\ + \alpha \cdot x, & x < 0 \\ + \end{cases}\end{split} + + Backward function: + + .. math:: + + \begin{split}g'(x) = + \begin{cases} + \beta, & x \geq 0 \\ + \alpha, & x < 0 \\ + \end{cases}\end{split} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> grads = bm.vector_grad(bm.surrogate.leaky_relu)(xs, 0., 1.) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0., \beta=1.$') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + The parameter to control the gradient when :math:`x < 0`. + beta: float + The parameter to control the gradient when :math:`x >= 0`. + + + Returns:: + + out: jax.Array + The spiking state. + """ + return LeakyRelu(alpha=alpha, beta=beta)(x) + + +class LogTailedRelu(Surrogate): + """Judge spiking state with the Log-tailed ReLU function. 
+ + See Also:: + + log_tailed_relu + """ + + def __init__(self, alpha=0.): + super().__init__() + self.alpha = alpha + + def surrogate_fun(self, x): + x = as_jax(x) + z = jnp.where(x > 1, + jnp.log(x), + jnp.where(x > 0, + x, + self.alpha * x)) + return z + + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.where(x > 1, + 1 / x, + jnp.where(x > 0, + 1., + self.alpha)) + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + + +def log_tailed_relu( + x: Union[jax.Array, Array], + alpha: float = 0., + +): + r"""Judge spiking state with the Log-tailed ReLU function [1]_. + + If `origin=False`, computes the forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + If `origin=True`, computes the original function: + + .. math:: + + \begin{split}g(x) = + \begin{cases} + \alpha x, & x \leq 0 \\ + x, & 0 < x \leq 1 \\ + \log(x), & x > 1 \\ + \end{cases}\end{split} + + Backward function: + + .. math:: + + \begin{split}g'(x) = + \begin{cases} + \alpha, & x \leq 0 \\ + 1, & 0 < x \leq 1 \\ + \frac{1}{x}, & x > 1 \\ + \end{cases}\end{split} + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> grads = bm.vector_grad(bm.surrogate.log_tailed_relu)(xs, 0.) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0.$') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + The parameter to control the gradient. + + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414. 
+ """ + return LogTailedRelu(alpha=alpha)(x) + + +class ReluGrad(Surrogate): + """Judge spiking state with the ReLU gradient function. + + See Also:: + + relu_grad + """ + + def __init__(self, alpha=0.3, width=1.): + super().__init__() + self.alpha = alpha + self.width = width + + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0) + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})' + + +def relu_grad( + x: Union[jax.Array, Array], + alpha: float = 0.3, + width: float = 1., +): + r"""Spike function with the ReLU gradient function [1]_. + + The forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + Backward function: + + .. math:: + + g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|)) + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> for s in [0.5, 1.]: + >>> for w in [1, 2.]: + >>> grads = bm.vector_grad(bm.surrogate.relu_grad)(xs, s, w) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=$' + f'{s}, width={w}') + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + The parameter to control the gradient. + width: float + The parameter to control the width of the gradient. + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019). + """ + return ReluGrad(alpha=alpha, width=width)(x) + + +class GaussianGrad(Surrogate): + """Judge spiking state with the Gaussian gradient function. 
+ + See Also:: + + gaussian_grad + """ + + def __init__(self, sigma=0.5, alpha=0.5): + super().__init__() + self.sigma = sigma + self.alpha = alpha + + def surrogate_grad(self, x): + x = as_jax(x) + dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma) + return self.alpha * dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})' + + +def gaussian_grad( + x: Union[jax.Array, Array], + sigma: float = 0.5, + alpha: float = 0.5, +): + r"""Spike function with the Gaussian gradient function [1]_. + + The forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + Backward function: + + .. math:: + + g'(x) = \alpha * \text{gaussian}(x, 0., \sigma) + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> for s in [0.5, 1., 2.]: + >>> grads = bm.vector_grad(bm.surrogate.gaussian_grad)(xs, s, 0.5) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads), label=r'$\alpha=0.5, \sigma=$' + str(s)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + sigma: float + The parameter to control the variance of gaussian distribution. + alpha: float + The parameter to control the scale of the gradient. + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). + """ + return GaussianGrad(sigma=sigma, alpha=alpha)(x) + + +class MultiGaussianGrad(Surrogate): + """Judge spiking state with the multi-Gaussian gradient function. 
+ + See Also:: + + multi_gaussian_grad + """ + + def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5): + super().__init__() + self.h = h + self.s = s + self.sigma = sigma + self.scale = scale + + def surrogate_grad(self, x): + x = as_jax(x) + g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma) + g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) + ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) + g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2)) + ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma) + dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h + return self.scale * dx + + def __repr__(self): + return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})' + + +def multi_gaussian_grad( + x: Union[jax.Array, Array], + h: float = 0.15, + s: float = 6.0, + sigma: float = 0.5, + scale: float = 0.5, +): + r"""Spike function with the multi-Gaussian gradient function [1]_. + + The forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + Backward function: + + .. math:: + + \begin{array}{l} + g'(x)=(1+h){{{\mathcal{N}}}}(x, 0, {\sigma }^{2}) + -h{{{\mathcal{N}}}}(x, \sigma,{(s\sigma )}^{2})- + h{{{\mathcal{N}}}}(x, -\sigma ,{(s\sigma )}^{2}) + \end{array} + + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-3, 3, 1000) + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> grads = bm.vector_grad(bm.surrogate.multi_gaussian_grad)(xs) + >>> plt.plot(bm.as_numpy(xs), bm.as_numpy(grads)) + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + h: float + The hyper-parameters of approximate function + s: float + The hyper-parameters of approximate function + sigma: float + The gaussian sigma. + scale: float + The gradient scale. 
+ + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021). + """ + return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x) + + +class InvSquareGrad(Surrogate): + """Judge spiking state with the inverse-square surrogate gradient function. + + See Also:: + + inv_square_grad + """ + + def __init__(self, alpha=100.): + super().__init__() + self.alpha = alpha + + def surrogate_grad(self, x): + dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2 + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + + +def inv_square_grad( + x: Union[jax.Array, Array], + alpha: float = 100. +): + r"""Spike function with the inverse-square surrogate gradient. + + Forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + Backward function: + + .. math:: + + g'(x) = \frac{1}{(\alpha * |x| + 1.) ^ 2} + + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> xs = bm.linspace(-1, 1, 1000) + >>> for alpha in [1., 10., 100.]: + >>> grads = bm.vector_grad(bm.surrogate.inv_square_grad)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + + Returns:: + + out: jax.Array + The spiking state. + """ + return InvSquareGrad(alpha=alpha)(x) + + +class SlayerGrad(Surrogate): + """Judge spiking state with the slayer surrogate gradient function. 
+ + See Also:: + + slayer_grad + """ + + def __init__(self, alpha=1.): + super().__init__() + self.alpha = alpha + + def surrogate_grad(self, x): + dx = jnp.exp(-self.alpha * jnp.abs(x)) + return dx + + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + + +def slayer_grad( + x: Union[jax.Array, Array], + alpha: float = 1. +): + r"""Spike function with the slayer surrogate gradient function. + + Forward function: + + .. math:: + + g(x) = \begin{cases} + 1, & x \geq 0 \\ + 0, & x < 0 \\ + \end{cases} + + Backward function: + + .. math:: + + g'(x) = \exp(-\alpha |x|) + + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import matplotlib.pyplot as plt + >>> bp.visualize.get_figure(1, 1, 4, 6) + >>> xs = bm.linspace(-3, 3, 1000) + >>> for alpha in [0.5, 1., 2., 4.]: + >>> grads = bm.vector_grad(bm.surrogate.slayer_grad)(xs, alpha) + >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha)) + >>> plt.legend() + >>> plt.show() + + Parameters:: + + x: jax.Array, Array + The input data. + alpha: float + Parameter to control smoothness of gradient + + Returns:: + + out: jax.Array + The spiking state. + + References:: + + .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018). + """ + return SlayerGrad(alpha=alpha)(x) diff --git a/changelog.md b/changelog.md index eb075544..d4650d7f 100644 --- a/changelog.md +++ b/changelog.md @@ -1,566 +1,605 @@ # Changelog -## Version 2.7.7 - -**Release Date:** March 12, 2026 - -This release migrates to the new `brainevent` binary backend API, fixes numerous bugs across neuron models, synapse dynamics, ODE integrators, and object transforms, and raises the minimum Python version to 3.11. 
- -### Breaking Changes - -#### Python Version Requirement -- **Raised**: Minimum Python version from 3.10 to **3.11** - - Dropped Python 3.10 support from classifiers and build metadata - -#### brainevent API Migration (#817) -- **Updated**: Event-driven operations to use the new `brainevent` binary backend - - `EventArray` replaced by `BinaryArray` in all event CSR and jitconn operations - - `JITCHomoR` replaced by `JITCScalarR` for homogeneous JIT connectivity - - STDP weight update functions renamed: - - `dense_on_pre` / `dense_on_post` → `update_dense_on_binary_pre` / `update_dense_on_binary_post` - - `csr_on_pre` / `csr2csc_on_post` → `update_csr_on_binary_pre` / `update_csr_on_binary_post` - - Affected layers: `Dense`, `AllToAll`, `MaskedLinear`, and all CSR-based layers - -### Bug Fixes - -#### Neuron Models -- **Fixed**: All LIF-family neuron models (`LifLTC`, `LifRefLTC`, `ExpIFLTC`, `ExpIFRefLTC`, `AdExIFLTC`, `AdExIFRefLTC`, `QuaIFLTC`, `QuaIFRefLTC`, `AdQuaIFLTC`, `AdQuaIFRefLTC`, `GifLTC`, `GifRefLTC`) now raise descriptive `ValueError` messages for invalid `spk_reset` modes instead of bare `ValueError` - -#### Synapse Models -- **Fixed**: Synaptic variable updates in `DualExpon`, `Alpha`, and `NMDA` models - - Replaced in-place `+=` operator with explicit `.value` assignment to prevent tracing issues under JAX transformations - -#### ODE Integrators -- **Fixed**: `BogackiShampine` adaptive Runge-Kutta coefficient typo — corrected `'4/0'` to `'4/9'` in the Butcher tableau -- **Fixed**: `set_default_odeint` was incorrectly assigning to `_DEFAULT_ODE_METHOD` instead of `_DEFAULT_DDE_METHOD` - -#### Object Transforms -- **Fixed**: `Variable.varshape` used `self.batch_size` instead of `self.batch_axis` when computing non-batch shape dimensions -- **Fixed**: `NodeDict.update()` and `VarDict.update()` referenced `args[1]` instead of `arg[1]` when updating from tuple arguments -- **Fixed**: `NodeDict.__setitem__` duplicate key error message incorrectly 
displayed the new value instead of the existing one - -#### Training -- **Fixed**: `BPTrainer` did not raise `NoLongerSupportError` when `seed` parameter was passed (missing `raise` keyword) -- **Fixed**: `BPTrainer` metric aggregation switched from `jnp.mean(bm.as_jax(...))` to `np.mean(np.asarray(...))` to avoid unnecessary JAX tracing overhead - -#### Core -- **Fixed**: `ShardedArray.value` setter used `_check_tracer()` instead of direct `_value` access -- **Fixed**: `_slice_to_num` did not handle negative step values and raised no error on zero step -- **Fixed**: `load_state` crashed with `KeyError` when state dict contained missing keys; now correctly reports them as missing - -#### Code Quality -- **Improved**: Narrowed bare `except` clauses to specific exception types (`ValueError`, `TypeError`, `ImportError`, `ModuleNotFoundError`) in convolution layers and Flax interoperation module - -### Dependencies - -- **Updated**: `brainevent` from `>=0.0.4` to `>=0.0.7` -- **Updated**: `braintools` version spec corrected to `>=0.0.9` -- **Updated**: `numpy` minimum version set to `>=1.15` -- **Updated**: `brainpy_state` from `>=0.0.2` to `>=0.0.3` - -### CI/CD - -- **Updated**: `actions/upload-artifact` from v6 to v7 (#815) -- **Updated**: `actions/download-artifact` from v7 to v8 (#816) - - ---- - -## Version 2.7.6 - -**Release Date:** January 21, 2026 - -This is a maintenance release that enhances JAX compatibility and improves CI stability across platforms. 
- -### Bug Fixes - -#### JAX 0.9.0 Compatibility (#813) -- **Fixed**: ODE integrator default time parameter handling - - Ensured `t` keyword argument defaults to 0 in `_call_integral` method - - Prevents errors when time parameter is not explicitly provided -- **Updated**: Backend imports for compatibility with JAX >= 0.8.0 - - Updated `brainpy.math.environment` module to handle JAX backend changes - - Improved compatibility layer for future JAX versions - -#### ODE Integrators -- **Fixed**: Explicit Runge-Kutta methods formatting in build method -- **Impact**: Ensures stable numerical integration across different JAX versions - -### Dependencies - -#### Updated Dependencies -- **Updated**: `brainpy_state` from 0.0.1 to 0.0.3 -- **Enhanced**: README documentation with comprehensive module overview and installation instructions - - ---- - -## Version 2.7.5 - -**Release Date:** December 25, 2025 - -This release focuses on improving JAX compatibility and code quality through comprehensive refactoring. - -### Major Changes - -#### JAX Compatibility Enhancement (#809) -- **Updated**: Refined JIT wrappers for compatibility with JAX >= 0.8.2 - - Refactored JIT handling across 85 files - - Updated object transformation modules for new JAX behavior - - Improved JIT compilation stability and performance -- **Added**: Initial `brainpy_state` module infrastructure - - Created new state management module with README documentation - - Set up module structure for future state-based functionality -- **Updated**: JAX backend integration - - Modernized backend import patterns - - Enhanced compatibility with JAX's evolving API +## Version 2.7.8 +**Release Date:** April 18, 2026 ---- - -## Version 2.7.4 - -**Release Date:** December 2025 - -This release focuses on simplifying the project structure by removing the experimental `brainpy.state` module and consolidating documentation. 
- -### Major Changes - -#### Removed `brainpy.state` Module (#806) -- **Removed**: The entire `brainpy.state` module has been deleted - - This includes all state-based neuron models (LIF variants, Izhikevich, HH) - - Removed synapse models, projections, readouts, and STP implementations - - Removed all associated test files -- **Recommendation**: Users should use the [brainpy.state](https://github.com/chaobrain/brainpy.state) library directly for state-based neural network simulations - - -#### Decouple the ``brainpy`` context with ``brainstate`` context - -- **Updated**: `brainpy.math.defaults` is totally decoupled with `brainstate` context management - - - ---- - -## Version 2.7.3 - -**Release Date:** December 2025 - -This is a bug fix release that resolves critical issues with `bm.for_loop` and improves CI stability. - -### Bug Fixes - -#### `bm.for_loop` jit Parameter Fix -- **Fixed**: The `jit` parameter in `bm.for_loop` was accepted but never used - passing `jit=False` had no effect -- **Implementation**: When `jit=False`, the call is now properly wrapped in `jax.disable_jit()` context manager -- **Impact**: Users can now debug code with `jit=False` to see actual values instead of JIT-compiled traces - -#### Zero-Length Scan Fix -- **Fixed**: `ValueError: zero-length scan is not supported in disable_jit() mode` when using `jit=False` with zero-length inputs -- **Implementation**: Automatically falls back to JIT mode for zero-length inputs with a warning -- **Impact**: Prevents crashes when `DSRunner.run(duration)` results in 0 time steps (e.g., `duration=0.5, dt=1.0`) - -#### Progress Bar Enhancement -- **Enhanced**: `progress_bar` parameter in `bm.for_loop()` and `bm.scan()` now supports advanced customization -- **New Features**: - - Accept `ProgressBar` instances for fine-grained control (freq, desc, count parameters) - - Accept integers as shorthand for frequency (e.g., `progress_bar=10` means update every 10 iterations) - - Full backward compatibility 
with existing `progress_bar=True/False` usage -- **Export**: Added `bm.ProgressBar` for easy access (`from brainpy.math import ProgressBar`) -- **Impact**: Aligns with brainstate API and enables better progress tracking customization - -#### Parameter Cleanup -- **Removed**: Unused parameters `remat` and `unroll_kwargs` from `bm.for_loop()` -- **Backward Compatibility**: `remat` parameter kept in `LoopOverTime.__init__()` with deprecation warning -- **Fixes**: Resolved TypeErrors in `DSRunner` and `LoopOverTime` that used these parameters - ---- - -## Version 2.7.2 - -**Release Date:** October 16, 2025 - -This is a maintenance release that improves JAX compatibility and documentation. - -### Bug Fixes - -#### JAX Compatibility -- **Updated**: Made compatible with JAX >= 0.8.0 -- **Fixed**: Updated imports and API usage for latest JAX versions -- **Impact**: Ensures BrainPy works correctly with the latest JAX releases +This maintenance release focuses on JAX compatibility cleanup, packaging improvements for `brainevent`, and documentation streamlining. It also removes duplicated internal JAXPR source-conversion code by relying on the shared `brainstate` implementation. ### Improvements -#### Documentation -- Updated documentation and CI configuration for better clarity -- Standardized test paths across the project -- Improved core concepts documentation -- Enhanced LIF neuron dynamics documentation (#800) -- Fixed documentation bugs - -#### Neural Network Classes -- Refactored neural network classes for better maintainability -- Updated progress bar parameters for simulations -- Improved code organization and structure - - ---- - -## Version 2.7.1 - -**Release Date:** October 2025 - -This is a feature release that introduces new neuron and synapse models in the state-based API (`brainpy.state`) and enhances the Dynamics base class with improved input handling. 
- -### Major Changes - -#### New Neuron Models (brainpy.state) -- **LIF (Leaky Integrate-and-Fire) Variants**: Added comprehensive set of LIF neuron models - - `LIF`: Basic LIF neuron with exponential synaptic input - - `LifRef`: LIF with refractory period - - `ExpIF`: Exponential Integrate-and-Fire neuron - - `ExpIFRef`: ExpIF with refractory period - - `AdExIF`: Adaptive Exponential Integrate-and-Fire neuron - - `AdExIFRef`: AdExIF with refractory period - - `QuaIF`: Quadratic Integrate-and-Fire neuron - - `QuaIFRef`: QuaIF with refractory period - - `AdQuaIF`: Adaptive Quadratic Integrate-and-Fire neuron - - `AdQuaIFRef`: AdQuaIF with refractory period - - `GifRef`: Generalized Integrate-and-Fire with refractory period - -- **Izhikevich Neuron Models**: Added new Izhikevich neuron implementations - - `Izhikevich`: Basic Izhikevich neuron model - - `IzhikevichRef`: Izhikevich with refractory period - -- **Hodgkin-Huxley Model**: Added classic biophysical neuron model - - `HH`: Classic Hodgkin-Huxley model with Na+ and K+ channels - -#### New Synapse Models (brainpy.state) -- **BioNMDA**: Biological NMDA receptor with second-order kinetics - - Implements two-state cascade dynamics (x and g variables) - - Slower rise time compared to AMPA (biologically realistic) - - Comprehensive documentation with mathematical formulation +#### JAX Compatibility and Internal Refactoring +- **Updated**: Replaced version-specific JAX imports with `brainstate._compatible_import` across math compatibility, backend environment utilities, surrogate operators, and vmap-removal helpers + - Centralizes compatibility handling for `Primitive` and backend imports across newer JAX releases +- **Updated**: `brainpy.integrators` now imports `jaxpr_to_python_code` from `brainstate.transform` + - Removed the duplicated local `brainpy.integrators._jaxpr_to_source_code` implementation + - Keeps integrator math-expression generation aligned with the shared `brainstate` implementation +- **Updated**: 
`bm.trapz` now consistently maps to `jax.scipy.integrate.trapezoid` -### Features +### Packaging -#### Model Implementation -- All new models use the brainstate ecosystem (HiddenState, ShortTermState, LongTermState) -- Proper unit support with brainunit integration -- Exponential Euler integration for numerical stability -- Batch processing support across all models -- Consistent API design following BrainPy v2.7+ architecture - -#### Dynamics Class Enhancements -- Enhanced input handling capabilities in the Dynamics base class -- Added new properties for better state management -- Improved integration with brainstate framework -- Refactored to use public methods instead of private counterparts for clarity - -#### Documentation -- Added comprehensive Examples sections to all neuron classes in `_lif.py` -- Each example includes: - - Import statements for required modules - - Basic usage with parameter specifications - - State initialization examples - - Update and spike generation examples - - Network integration with `brainstate.nn.Sequential` - - Notes highlighting key features -- All 13 neuron classes in `_lif.py` now have complete documentation -- Simplified documentation paths by removing 'core-concepts' and 'quickstart' prefixes in index.rst - -### Bug Fixes -- Fixed import paths in `_base.py`: changed references from brainstate to brainpy for consistency (057b872d) -- Fixed test suite issues (95ec2037) -- Fixed test suite for proper unit handling in synapse models - - -### Notes -- This release significantly expands the `brainpy.state` module with biologically realistic neuron and synapse models -- All new models are fully compatible with the brainstate ecosystem -- Enhanced documentation provides clear usage examples for all models -- The Dynamics class refactoring improves the foundation for future state-based model development - - - - -## Version 3.0.1 - -**Release Date:** October 2025 - -This is a patch release focusing on documentation improvements and 
module structure cleanup following the 3.0.0 release. - -### Major Changes - -#### Module Renaming -- **BREAKING CHANGE**: Renamed `brainpy.state_based` module to `brainpy.state` - - All functionality previously in `brainpy.state_based` is now accessible via `brainpy.state` - - Users should update imports from `brainpy.state_based` to `brainpy.state` - - This change provides a cleaner, more intuitive API structure - -#### Code Structure Cleanup -- **Removed `brainpy.version2` module**: All BrainPy 2.x functionality has been consolidated - - The `version2` namespace has been removed from the codebase - - All version2 functionality is now directly accessible through the main `brainpy` module - - Version-specific imports are no longer needed - -### Documentation - -#### Documentation Reorganization -- Renamed `docs_version2` to `docs_classic` for BrainPy 2.x documentation -- Renamed `docs_state_based` to `docs_state` for BrainPy 3.x documentation -- Renamed `examples_version2` to `examples_classic` for consistency -- Renamed `examples_state_based` to `examples_state` for clarity - -#### Documentation Updates -- Updated all documentation references to use `brainpy.state` instead of `brainpy.state_based` (#791, #790) -- Updated API documentation structure for improved clarity -- Simplified API reference pages by removing redundant content -- Updated card links and descriptions for `brainpy.state` APIs -- Improved quickstart tutorial (5min-tutorial.ipynb) with clearer examples -- Updated core concepts documentation to reflect new module structure -- Enhanced tutorials with corrected module references -- Updated all example files to use new module structure - -#### Examples Updates -- Updated simulation examples (EI networks, COBA, CUBA models) to use new API -- Updated training examples (surrogate gradient training, MNIST models) with correct imports -- Updated gamma oscillation examples with proper module references - -### Bug Fixes - -#### Testing -- Removed redundant 
test for abstract Neuron class that was causing conflicts (d06bb47f) - -### Migration Guide - -For users upgrading from BrainPy 3.0.0: - -1. **Update module imports**: Replace `brainpy.state_based` with `brainpy.state` - ```python - - # New code (BrainPy 3.0.1) - from brainpy.state import LIF, Expon - ``` - -2. **Remove version2 references**: If you were using `brainpy.version2`, migrate to the main `brainpy` module - ```python - # Old code (not recommended) - import brainpy.version2 as bp - - # New code - import brainpy as bp - ``` - -3. **Update documentation references**: If you're linking to documentation, use the new paths: - - Classic docs: `docs_classic/` (formerly `docs_version2/`) - - State-based docs: `docs_state/` (formerly `docs_state_based/`) - -### Notes -- This release maintains full backward compatibility with BrainPy 3.0.0 except for the module naming changes -- The `brainpy.state_based` to `brainpy.state` rename provides a cleaner API and better reflects the module's purpose -- Documentation is now better organized with clear separation between classic (2.x) and state-based (3.x) APIs - - - - -## Version 3.0.0 - -**Release Date:** October 2025 - -This is a major release with significant architectural changes and improvements. BrainPy 3.0.0 introduces a new API design while maintaining backward compatibility through the `brainpy` module. - -### Major Changes - -#### Architecture Reorganization -- **BREAKING CHANGE**: All existing BrainPy 2.x functionality has been moved to `brainpy` module - - Users can migrate existing code by replacing `import brainpy` with `import brainpy as brainpy` - - The old `brainpy._src` module structure has been completely reorganized into `brainpy` - - All submodules (math, dyn, dnn, etc.) 
are now under `brainpy.*` - -#### New Core API (brainpy.*) -- Introduced simplified, streamlined API in the main `brainpy` namespace -- New core modules include: - - Base classes for neurons and synapses - - LIF (Leaky Integrate-and-Fire) neuron models - - Exponential synapse models - - Synaptic projection modules - - Short-term plasticity (STP) models - - Input current generators - - Readout layers - - Error handling utilities - -### Dependencies -- **Updated**: `brainstate>=0.2.0` (was `>=0.1.0`) -- **Updated**: `brainevent>=0.0.4` (new requirement) -- **Updated**: `braintools>=0.0.9` (integrated into brainpy) -- **Removed**: Hard dependency on `taichi` and `numba` - now optional -- **Updated**: JAX compatibility improvements for version 0.5.0+ - -### Features - -#### Integration of Brain Ecosystem Libraries -- Integrated `brainstate` for state management (#763) -- Integrated `brainevent` for event-driven computations (#771) -- Integrated `braintools` utilities and formatting (#769) - -#### Math Module Enhancements (version2.math) -- Added event-driven sparse matrix @ matrix operators (#613) -- Added `ein_rearrange`, `ein_reduce`, and `ein_repeat` functions (#590) -- Added `unflatten` function and `Unflatten` layer (#588) -- Added JIT weight matrix methods (Uniform & Normal) for `dnn.linear` (#673) -- Added JIT connect matrix method for `dnn.linear` (#672) -- Replaced math operators with `braintaichi` for better performance (#698) -- Support for custom operators using CuPy (#653) -- Taichi operators as default customized operators (#598) -- Enhanced taichi custom operator support with GPU backend (#655) -- Support for more than 8 parameters in taichi GPU operator customization (#642) -- Rebased operator customization using MLIR registration interface (#618) -- Added transparent taichi caches with clean caches function (#596) -- Support for taichi customized op with metal CPU backend (#579) -- Improved variable retrieval system (#589) - -#### Deep Learning 
(version2.dnn) -- Improved error handling in `dnn/linear` module (#704) -- Enhanced activation functions and layers - -#### Dynamics (version2.dyn) -- Refactored STDP weight update logic requiring `brainevent>=0.0.4` (#771) -- Fixed STDP and training workflows for JAX compatibility (#772) -- Enhanced dual exponential synapse model with `normalize` parameter -- Improved alpha synapse implementation -- Added `clear_input` in the `step_run` function (#601) - -#### Integrators (version2.integrators) -- Support for `Integrator.to_math_expr()` (#674) -- Fixed dtype checking during exponential Euler method -- Added `disable_jit` support in `brainpy.math.scan` (#606) -- Fixed `brainpy.math.scan` implementation (#604) - -#### Optimizers (version2.optim) -- Fixed AdamW optimizer initialization where "amsgrad" was used before being defined (#660) - -#### Tools & Utilities (version2.tools) -- Added `brainpy.tools.compose` and `brainpy.tools.pipe` functions (#624) - -### Bug Fixes - -#### JAX Compatibility -- Updated JAX import paths for compatibility with version 0.5.0+ (#722) -- Fixed compatibility issues with latest JAX versions (#691, #708, #716) -- Replaced `jax.experimental.host_callback` with `jax.pure_callback` (#670) -- Fixed `test_ndarray.py` for latest JAX version (#708) - -#### Math & Operations -- Fixed `CustomOpByNumba` with `multiple_results=True` (#671) -- Updated `CustomOpByNumba` to support JAX version >= 0.4.24 (#669) -- Fixed `brainpy.math.softplus` and `brainpy.dnn.SoftPlus` (#581) -- Fixed bugs in `truncated_normal` and added `TruncatedNormal` initialization (#583, #584, #585, #574, #575) -- Fixed autograd functionality (#687) -- Fixed order of return values in `__load_state__` (#749) - -#### Delay & Timing -- Fixed delay bugs including DelayVar in concat mode (#632, #650) -- Fixed wrong randomness in OU process input (#715) - -#### UI & Progress -- Fixed progress bar display and update issues (#683) -- Fixed incorrect verbose of `clear_name_cache()` 
(#681) - -#### Python Compatibility -- Replaced `collections.Iterable` with `collections.abc.Iterable` for Python 3.10+ (#677) -- Fixed surrogate gradient function for numpy 2.0 compatibility (#679) - -#### Interoperability -- Fixed Flax RNN interoperation (#665) -- Fixed issue with external library integration (#661, #662) - -#### Exception Handling -- Fixed exception handling for missing braintaichi module in dependency check (#746) - -### Testing & CI - -#### Python Support -- Added CI support for Python 3.12 (#705) -- Added CI support for Python 3.13 -- Updated supported Python versions: 3.10, 3.11, 3.12, 3.13 - -#### CI Improvements -- Updated GitHub Actions: - - `actions/setup-python` from 5 to 6 (#783) - - `actions/checkout` from 4 to 5 (#773) - - `actions/first-interaction` from 1 to 3 (#782) - - `actions/labeler` from 5 to 6 (#781) - - `actions/download-artifact` from 4 to 5 (#780) - - `actions/stale` from 9 to 10 (#779) - - `docker/build-push-action` from 5 to 6 (#678) -- Added greetings workflow and labeler configuration -- Enhanced issue templates and CI configurations +- **Updated**: Optional install extras now include matching `brainevent` extras + - `brainpy[cpu]` installs `brainevent[cpu]` + - `brainpy[cuda12]` installs `brainevent[cuda12]` + - `brainpy[cuda13]` installs `brainevent[cuda13]` + - `brainpy[tpu]` installs `brainevent[tpu]` +- **Impact**: Simplifies accelerator-specific installation and keeps event-backend dependencies aligned with the selected JAX platform ### Documentation -#### Major Documentation Overhaul -- Introduced new BrainPy 3.0 documentation and tutorials (#787) -- Added comprehensive documentation and examples for BrainPy 3.x (#785) -- Updated documentation links for BrainPy 3.0 and 2.0 (#786) -- Implemented dynamic configuration loading for Read the Docs (#784) -- Added Colab and Kaggle links for documentation notebooks (#614, #619) -- Added Chinese version of `operator_custom_with_cupy.ipynb` (#659) -- Fixed various 
documentation build issues and path references - -#### Citation & Acknowledgments -- Added BrainPy citation information (#770) -- Updated ACKNOWLEDGMENTS.md - -#### Installation -- Refined installation instructions (#767) -- Updated docstring and parameter formatting (#766) -- Updated README with ecosystem information - -### Performance & Memory Management -- Enabled `clear_buffer_memory()` to support clearing `array`, `compilation`, and `names` (#639) -- Cleaned taichi AOT caches and enabled `numpy_func_return` setting (#643) -- Made taichi caches more transparent (#596) -- Enabled BrainPy objects as pytree for direct use with `jax.jit` (#625) - -### Object-Oriented Transformations -- Standardized and generalized object-oriented transformations (#628) - -### Development & Contributing -- Updated CONTRIBUTING.md with new guidelines -- Added CODEOWNERS file -- Updated SECURITY.md -- License updated to Apache License 2.0 +- **Added**: Direct link to the external `brainpy.state` API reference in the docs API index (#822) +- **Removed**: Legacy `brainpy` and `brainpylib` historical changelog pages from the main documentation navigation +- **Removed**: Advanced operator-customization tutorial notebooks for CuPy, Numba, and Taichi from the documentation tree +- **Updated**: Refreshed quickstart simulation notebook content and documentation structure around `brainpy.state` -### Removed -- Removed Docker workflow -- Removed hard dependencies on `taichi` and `numba` (#635) -- Removed op register functionality (#700) -- Removed deprecated deprecation files and old module structure -- Removed unnecessary dependencies (#703) +### Testing and CI/CD -### Migration Guide - -For users upgrading from BrainPy 2.6.x: - -1. **Keep using BrainPy 2.x API**: Replace imports with `brainpy` - ```python - # Old code (BrainPy 2.x) - import brainpy as bp - - # New code (BrainPy 3.0 with backward compatibility) - import brainpy as bp - ``` - -2. 
**Adopt new BrainPy 3.0 API**: Explore the simplified API in the main `brainpy` namespace for new projects - -3. **Update dependencies**: Ensure `brainstate>=0.2.0`, `brainevent>=0.0.4`, and `braintools>=0.0.9` are installed - -4. **Review breaking changes**: Check if your code uses any of the reorganized internal modules - -### Notes -- This release maintains backward compatibility through `brainpy` -- The new API in the main `brainpy` namespace represents the future direction of the library -- Documentation for both versions is available on Read the Docs +- **Improved**: Disabled JAX traceback filtering in `brainpy.dnn` linear tests to make failures easier to diagnose +- **Updated**: Raised the Sphinx upper bound for documentation builds from `<8.2.0` to `<9.1.0` (#820) +- **Updated**: Bumped `styfle/cancel-workflow-action` from `0.13.0` to `0.13.1` in CI (#819) +## Version 2.7.7 +**Release Date:** March 12, 2026 + +This release migrates to the new `brainevent` binary backend API, fixes numerous bugs across neuron models, synapse dynamics, ODE integrators, and object transforms, and raises the minimum Python version to 3.11. 
+ +### Breaking Changes + +#### Python Version Requirement +- **Raised**: Minimum Python version from 3.10 to **3.11** + - Dropped Python 3.10 support from classifiers and build metadata + +#### brainevent API Migration (#817) +- **Updated**: Event-driven operations to use the new `brainevent` binary backend + - `EventArray` replaced by `BinaryArray` in all event CSR and jitconn operations + - `JITCHomoR` replaced by `JITCScalarR` for homogeneous JIT connectivity + - STDP weight update functions renamed: + - `dense_on_pre` / `dense_on_post` → `update_dense_on_binary_pre` / `update_dense_on_binary_post` + - `csr_on_pre` / `csr2csc_on_post` → `update_csr_on_binary_pre` / `update_csr_on_binary_post` + - Affected layers: `Dense`, `AllToAll`, `MaskedLinear`, and all CSR-based layers + +### Bug Fixes + +#### Neuron Models +- **Fixed**: All LIF-family neuron models (`LifLTC`, `LifRefLTC`, `ExpIFLTC`, `ExpIFRefLTC`, `AdExIFLTC`, `AdExIFRefLTC`, `QuaIFLTC`, `QuaIFRefLTC`, `AdQuaIFLTC`, `AdQuaIFRefLTC`, `GifLTC`, `GifRefLTC`) now raise descriptive `ValueError` messages for invalid `spk_reset` modes instead of bare `ValueError` + +#### Synapse Models +- **Fixed**: Synaptic variable updates in `DualExpon`, `Alpha`, and `NMDA` models + - Replaced in-place `+=` operator with explicit `.value` assignment to prevent tracing issues under JAX transformations + +#### ODE Integrators +- **Fixed**: `BogackiShampine` adaptive Runge-Kutta coefficient typo — corrected `'4/0'` to `'4/9'` in the Butcher tableau +- **Fixed**: `set_default_odeint` was incorrectly assigning to `_DEFAULT_ODE_METHOD` instead of `_DEFAULT_DDE_METHOD` + +#### Object Transforms +- **Fixed**: `Variable.varshape` used `self.batch_size` instead of `self.batch_axis` when computing non-batch shape dimensions +- **Fixed**: `NodeDict.update()` and `VarDict.update()` referenced `args[1]` instead of `arg[1]` when updating from tuple arguments +- **Fixed**: `NodeDict.__setitem__` duplicate key error message incorrectly 
displayed the new value instead of the existing one + +#### Training +- **Fixed**: `BPTrainer` did not raise `NoLongerSupportError` when `seed` parameter was passed (missing `raise` keyword) +- **Fixed**: `BPTrainer` metric aggregation switched from `jnp.mean(bm.as_jax(...))` to `np.mean(np.asarray(...))` to avoid unnecessary JAX tracing overhead + +#### Core +- **Fixed**: `ShardedArray.value` setter used `_check_tracer()` instead of direct `_value` access +- **Fixed**: `_slice_to_num` did not handle negative step values and raised no error on zero step +- **Fixed**: `load_state` crashed with `KeyError` when state dict contained missing keys; now correctly reports them as missing + +#### Code Quality +- **Improved**: Narrowed bare `except` clauses to specific exception types (`ValueError`, `TypeError`, `ImportError`, `ModuleNotFoundError`) in convolution layers and Flax interoperation module + +### Dependencies + +- **Updated**: `brainevent` from `>=0.0.4` to `>=0.0.7` +- **Updated**: `braintools` version spec corrected to `>=0.0.9` +- **Updated**: `numpy` minimum version set to `>=1.15` +- **Updated**: `brainpy_state` from `>=0.0.2` to `>=0.0.3` + +### CI/CD + +- **Updated**: `actions/upload-artifact` from v6 to v7 (#815) +- **Updated**: `actions/download-artifact` from v7 to v8 (#816) + + +--- + +## Version 2.7.6 + +**Release Date:** January 21, 2026 + +This is a maintenance release that enhances JAX compatibility and improves CI stability across platforms. 
+ +### Bug Fixes + +#### JAX 0.9.0 Compatibility (#813) +- **Fixed**: ODE integrator default time parameter handling + - Ensured `t` keyword argument defaults to 0 in `_call_integral` method + - Prevents errors when time parameter is not explicitly provided +- **Updated**: Backend imports for compatibility with JAX >= 0.8.0 + - Updated `brainpy.math.environment` module to handle JAX backend changes + - Improved compatibility layer for future JAX versions + +#### ODE Integrators +- **Fixed**: Explicit Runge-Kutta methods formatting in build method +- **Impact**: Ensures stable numerical integration across different JAX versions + +### Dependencies + +#### Updated Dependencies +- **Updated**: `brainpy_state` from 0.0.1 to 0.0.3 +- **Enhanced**: README documentation with comprehensive module overview and installation instructions + + +--- + +## Version 2.7.5 + +**Release Date:** December 25, 2025 + +This release focuses on improving JAX compatibility and code quality through comprehensive refactoring. + +### Major Changes + +#### JAX Compatibility Enhancement (#809) +- **Updated**: Refined JIT wrappers for compatibility with JAX >= 0.8.2 + - Refactored JIT handling across 85 files + - Updated object transformation modules for new JAX behavior + - Improved JIT compilation stability and performance +- **Added**: Initial `brainpy_state` module infrastructure + - Created new state management module with README documentation + - Set up module structure for future state-based functionality +- **Updated**: JAX backend integration + - Modernized backend import patterns + - Enhanced compatibility with JAX's evolving API + + +--- + +## Version 2.7.4 + +**Release Date:** December 2025 + +This release focuses on simplifying the project structure by removing the experimental `brainpy.state` module and consolidating documentation. 
+ +### Major Changes + +#### Removed `brainpy.state` Module (#806) +- **Removed**: The entire `brainpy.state` module has been deleted + - This includes all state-based neuron models (LIF variants, Izhikevich, HH) + - Removed synapse models, projections, readouts, and STP implementations + - Removed all associated test files +- **Recommendation**: Users should use the [brainpy.state](https://github.com/chaobrain/brainpy.state) library directly for state-based neural network simulations + + +#### Decouple the ``brainpy`` context with ``brainstate`` context + +- **Updated**: `brainpy.math.defaults` is totally decoupled with `brainstate` context management + + + +--- + +## Version 2.7.3 + +**Release Date:** December 2025 + +This is a bug fix release that resolves critical issues with `bm.for_loop` and improves CI stability. + +### Bug Fixes + +#### `bm.for_loop` jit Parameter Fix +- **Fixed**: The `jit` parameter in `bm.for_loop` was accepted but never used - passing `jit=False` had no effect +- **Implementation**: When `jit=False`, the call is now properly wrapped in `jax.disable_jit()` context manager +- **Impact**: Users can now debug code with `jit=False` to see actual values instead of JIT-compiled traces + +#### Zero-Length Scan Fix +- **Fixed**: `ValueError: zero-length scan is not supported in disable_jit() mode` when using `jit=False` with zero-length inputs +- **Implementation**: Automatically falls back to JIT mode for zero-length inputs with a warning +- **Impact**: Prevents crashes when `DSRunner.run(duration)` results in 0 time steps (e.g., `duration=0.5, dt=1.0`) + +#### Progress Bar Enhancement +- **Enhanced**: `progress_bar` parameter in `bm.for_loop()` and `bm.scan()` now supports advanced customization +- **New Features**: + - Accept `ProgressBar` instances for fine-grained control (freq, desc, count parameters) + - Accept integers as shorthand for frequency (e.g., `progress_bar=10` means update every 10 iterations) + - Full backward compatibility 
with existing `progress_bar=True/False` usage +- **Export**: Added `bm.ProgressBar` for easy access (`from brainpy.math import ProgressBar`) +- **Impact**: Aligns with brainstate API and enables better progress tracking customization + +#### Parameter Cleanup +- **Removed**: Unused parameters `remat` and `unroll_kwargs` from `bm.for_loop()` +- **Backward Compatibility**: `remat` parameter kept in `LoopOverTime.__init__()` with deprecation warning +- **Fixes**: Resolved TypeErrors in `DSRunner` and `LoopOverTime` that used these parameters + +--- + +## Version 2.7.2 + +**Release Date:** October 16, 2025 + +This is a maintenance release that improves JAX compatibility and documentation. + +### Bug Fixes + +#### JAX Compatibility +- **Updated**: Made compatible with JAX >= 0.8.0 +- **Fixed**: Updated imports and API usage for latest JAX versions +- **Impact**: Ensures BrainPy works correctly with the latest JAX releases + +### Improvements + +#### Documentation +- Updated documentation and CI configuration for better clarity +- Standardized test paths across the project +- Improved core concepts documentation +- Enhanced LIF neuron dynamics documentation (#800) +- Fixed documentation bugs + +#### Neural Network Classes +- Refactored neural network classes for better maintainability +- Updated progress bar parameters for simulations +- Improved code organization and structure + + +--- + +## Version 2.7.1 + +**Release Date:** October 2025 + +This is a feature release that introduces new neuron and synapse models in the state-based API (`brainpy.state`) and enhances the Dynamics base class with improved input handling. 
+ +### Major Changes + +#### New Neuron Models (brainpy.state) +- **LIF (Leaky Integrate-and-Fire) Variants**: Added comprehensive set of LIF neuron models + - `LIF`: Basic LIF neuron with exponential synaptic input + - `LifRef`: LIF with refractory period + - `ExpIF`: Exponential Integrate-and-Fire neuron + - `ExpIFRef`: ExpIF with refractory period + - `AdExIF`: Adaptive Exponential Integrate-and-Fire neuron + - `AdExIFRef`: AdExIF with refractory period + - `QuaIF`: Quadratic Integrate-and-Fire neuron + - `QuaIFRef`: QuaIF with refractory period + - `AdQuaIF`: Adaptive Quadratic Integrate-and-Fire neuron + - `AdQuaIFRef`: AdQuaIF with refractory period + - `GifRef`: Generalized Integrate-and-Fire with refractory period + +- **Izhikevich Neuron Models**: Added new Izhikevich neuron implementations + - `Izhikevich`: Basic Izhikevich neuron model + - `IzhikevichRef`: Izhikevich with refractory period + +- **Hodgkin-Huxley Model**: Added classic biophysical neuron model + - `HH`: Classic Hodgkin-Huxley model with Na+ and K+ channels + +#### New Synapse Models (brainpy.state) +- **BioNMDA**: Biological NMDA receptor with second-order kinetics + - Implements two-state cascade dynamics (x and g variables) + - Slower rise time compared to AMPA (biologically realistic) + - Comprehensive documentation with mathematical formulation + +### Features + +#### Model Implementation +- All new models use the brainstate ecosystem (HiddenState, ShortTermState, LongTermState) +- Proper unit support with brainunit integration +- Exponential Euler integration for numerical stability +- Batch processing support across all models +- Consistent API design following BrainPy v2.7+ architecture + +#### Dynamics Class Enhancements +- Enhanced input handling capabilities in the Dynamics base class +- Added new properties for better state management +- Improved integration with brainstate framework +- Refactored to use public methods instead of private counterparts for clarity + +#### 
Documentation +- Added comprehensive Examples sections to all neuron classes in `_lif.py` +- Each example includes: + - Import statements for required modules + - Basic usage with parameter specifications + - State initialization examples + - Update and spike generation examples + - Network integration with `brainstate.nn.Sequential` + - Notes highlighting key features +- All 13 neuron classes in `_lif.py` now have complete documentation +- Simplified documentation paths by removing 'core-concepts' and 'quickstart' prefixes in index.rst + +### Bug Fixes +- Fixed import paths in `_base.py`: changed references from brainstate to brainpy for consistency (057b872d) +- Fixed test suite issues (95ec2037) +- Fixed test suite for proper unit handling in synapse models + + +### Notes +- This release significantly expands the `brainpy.state` module with biologically realistic neuron and synapse models +- All new models are fully compatible with the brainstate ecosystem +- Enhanced documentation provides clear usage examples for all models +- The Dynamics class refactoring improves the foundation for future state-based model development + + + + +## Version 3.0.1 + +**Release Date:** October 2025 + +This is a patch release focusing on documentation improvements and module structure cleanup following the 3.0.0 release. 
+ +### Major Changes + +#### Module Renaming +- **BREAKING CHANGE**: Renamed `brainpy.state_based` module to `brainpy.state` + - All functionality previously in `brainpy.state_based` is now accessible via `brainpy.state` + - Users should update imports from `brainpy.state_based` to `brainpy.state` + - This change provides a cleaner, more intuitive API structure + +#### Code Structure Cleanup +- **Removed `brainpy.version2` module**: All BrainPy 2.x functionality has been consolidated + - The `version2` namespace has been removed from the codebase + - All version2 functionality is now directly accessible through the main `brainpy` module + - Version-specific imports are no longer needed + +### Documentation + +#### Documentation Reorganization +- Renamed `docs_version2` to `docs_classic` for BrainPy 2.x documentation +- Renamed `docs_state_based` to `docs_state` for BrainPy 3.x documentation +- Renamed `examples_version2` to `examples_classic` for consistency +- Renamed `examples_state_based` to `examples_state` for clarity + +#### Documentation Updates +- Updated all documentation references to use `brainpy.state` instead of `brainpy.state_based` (#791, #790) +- Updated API documentation structure for improved clarity +- Simplified API reference pages by removing redundant content +- Updated card links and descriptions for `brainpy.state` APIs +- Improved quickstart tutorial (5min-tutorial.ipynb) with clearer examples +- Updated core concepts documentation to reflect new module structure +- Enhanced tutorials with corrected module references +- Updated all example files to use new module structure + +#### Examples Updates +- Updated simulation examples (EI networks, COBA, CUBA models) to use new API +- Updated training examples (surrogate gradient training, MNIST models) with correct imports +- Updated gamma oscillation examples with proper module references + +### Bug Fixes + +#### Testing +- Removed redundant test for abstract Neuron class that was causing 
conflicts (d06bb47f) + +### Migration Guide + +For users upgrading from BrainPy 3.0.0: + +1. **Update module imports**: Replace `brainpy.state_based` with `brainpy.state` + ```python + # Old code (BrainPy 3.0.0) + from brainpy.state_based import LIF, Expon + + # New code (BrainPy 3.0.1) + from brainpy.state import LIF, Expon + ``` + +2. **Remove version2 references**: If you were using `brainpy.version2`, migrate to the main `brainpy` module + ```python + # Old code (not recommended) + import brainpy.version2 as bp + + # New code + import brainpy as bp + ``` + +3. **Update documentation references**: If you're linking to documentation, use the new paths: + - Classic docs: `docs_classic/` (formerly `docs_version2/`) + - State-based docs: `docs_state/` (formerly `docs_state_based/`) + +### Notes +- This release maintains full backward compatibility with BrainPy 3.0.0 except for the module naming changes +- The `brainpy.state_based` to `brainpy.state` rename provides a cleaner API and better reflects the module's purpose +- Documentation is now better organized with clear separation between classic (2.x) and state-based (3.x) APIs + + + + +## Version 3.0.0 + +**Release Date:** October 2025 + +This is a major release with significant architectural changes and improvements. BrainPy 3.0.0 introduces a new API design while maintaining backward compatibility through the `brainpy` module. + +### Major Changes + +#### Architecture Reorganization +- **BREAKING CHANGE**: All existing BrainPy 2.x functionality has been moved to `brainpy` module + - Existing `import brainpy` code continues to work, since the BrainPy 2.x functionality remains accessible through the main `brainpy` module + - The old `brainpy._src` module structure has been completely reorganized into `brainpy` + - All submodules (math, dyn, dnn, etc.) 
are now under `brainpy.*` + +#### New Core API (brainpy.*) +- Introduced simplified, streamlined API in the main `brainpy` namespace +- New core modules include: + - Base classes for neurons and synapses + - LIF (Leaky Integrate-and-Fire) neuron models + - Exponential synapse models + - Synaptic projection modules + - Short-term plasticity (STP) models + - Input current generators + - Readout layers + - Error handling utilities + +### Dependencies +- **Updated**: `brainstate>=0.2.0` (was `>=0.1.0`) +- **Updated**: `brainevent>=0.0.4` (new requirement) +- **Updated**: `braintools>=0.0.9` (integrated into brainpy) +- **Removed**: Hard dependency on `taichi` and `numba` - now optional +- **Updated**: JAX compatibility improvements for version 0.5.0+ + +### Features + +#### Integration of Brain Ecosystem Libraries +- Integrated `brainstate` for state management (#763) +- Integrated `brainevent` for event-driven computations (#771) +- Integrated `braintools` utilities and formatting (#769) + +#### Math Module Enhancements (version2.math) +- Added event-driven sparse matrix @ matrix operators (#613) +- Added `ein_rearrange`, `ein_reduce`, and `ein_repeat` functions (#590) +- Added `unflatten` function and `Unflatten` layer (#588) +- Added JIT weight matrix methods (Uniform & Normal) for `dnn.linear` (#673) +- Added JIT connect matrix method for `dnn.linear` (#672) +- Replaced math operators with `braintaichi` for better performance (#698) +- Support for custom operators using CuPy (#653) +- Taichi operators as default customized operators (#598) +- Enhanced taichi custom operator support with GPU backend (#655) +- Support for more than 8 parameters in taichi GPU operator customization (#642) +- Rebased operator customization using MLIR registration interface (#618) +- Added transparent taichi caches with clean caches function (#596) +- Support for taichi customized op with metal CPU backend (#579) +- Improved variable retrieval system (#589) + +#### Deep Learning 
(version2.dnn) +- Improved error handling in `dnn/linear` module (#704) +- Enhanced activation functions and layers + +#### Dynamics (version2.dyn) +- Refactored STDP weight update logic requiring `brainevent>=0.0.4` (#771) +- Fixed STDP and training workflows for JAX compatibility (#772) +- Enhanced dual exponential synapse model with `normalize` parameter +- Improved alpha synapse implementation +- Added `clear_input` in the `step_run` function (#601) + +#### Integrators (version2.integrators) +- Support for `Integrator.to_math_expr()` (#674) +- Fixed dtype checking during exponential Euler method +- Added `disable_jit` support in `brainpy.math.scan` (#606) +- Fixed `brainpy.math.scan` implementation (#604) + +#### Optimizers (version2.optim) +- Fixed AdamW optimizer initialization where "amsgrad" was used before being defined (#660) + +#### Tools & Utilities (version2.tools) +- Added `brainpy.tools.compose` and `brainpy.tools.pipe` functions (#624) + +### Bug Fixes + +#### JAX Compatibility +- Updated JAX import paths for compatibility with version 0.5.0+ (#722) +- Fixed compatibility issues with latest JAX versions (#691, #708, #716) +- Replaced `jax.experimental.host_callback` with `jax.pure_callback` (#670) +- Fixed `test_ndarray.py` for latest JAX version (#708) + +#### Math & Operations +- Fixed `CustomOpByNumba` with `multiple_results=True` (#671) +- Updated `CustomOpByNumba` to support JAX version >= 0.4.24 (#669) +- Fixed `brainpy.math.softplus` and `brainpy.dnn.SoftPlus` (#581) +- Fixed bugs in `truncated_normal` and added `TruncatedNormal` initialization (#583, #584, #585, #574, #575) +- Fixed autograd functionality (#687) +- Fixed order of return values in `__load_state__` (#749) + +#### Delay & Timing +- Fixed delay bugs including DelayVar in concat mode (#632, #650) +- Fixed wrong randomness in OU process input (#715) + +#### UI & Progress +- Fixed progress bar display and update issues (#683) +- Fixed incorrect verbose of `clear_name_cache()` 
(#681) + +#### Python Compatibility +- Replaced `collections.Iterable` with `collections.abc.Iterable` for Python 3.10+ (#677) +- Fixed surrogate gradient function for numpy 2.0 compatibility (#679) + +#### Interoperability +- Fixed Flax RNN interoperation (#665) +- Fixed issue with external library integration (#661, #662) + +#### Exception Handling +- Fixed exception handling for missing braintaichi module in dependency check (#746) + +### Testing & CI + +#### Python Support +- Added CI support for Python 3.12 (#705) +- Added CI support for Python 3.13 +- Updated supported Python versions: 3.10, 3.11, 3.12, 3.13 + +#### CI Improvements +- Updated GitHub Actions: + - `actions/setup-python` from 5 to 6 (#783) + - `actions/checkout` from 4 to 5 (#773) + - `actions/first-interaction` from 1 to 3 (#782) + - `actions/labeler` from 5 to 6 (#781) + - `actions/download-artifact` from 4 to 5 (#780) + - `actions/stale` from 9 to 10 (#779) + - `docker/build-push-action` from 5 to 6 (#678) +- Added greetings workflow and labeler configuration +- Enhanced issue templates and CI configurations + +### Documentation + +#### Major Documentation Overhaul +- Introduced new BrainPy 3.0 documentation and tutorials (#787) +- Added comprehensive documentation and examples for BrainPy 3.x (#785) +- Updated documentation links for BrainPy 3.0 and 2.0 (#786) +- Implemented dynamic configuration loading for Read the Docs (#784) +- Added Colab and Kaggle links for documentation notebooks (#614, #619) +- Added Chinese version of `operator_custom_with_cupy.ipynb` (#659) +- Fixed various documentation build issues and path references + +#### Citation & Acknowledgments +- Added BrainPy citation information (#770) +- Updated ACKNOWLEDGMENTS.md + +#### Installation +- Refined installation instructions (#767) +- Updated docstring and parameter formatting (#766) +- Updated README with ecosystem information + +### Performance & Memory Management +- Enabled `clear_buffer_memory()` to support clearing 
`array`, `compilation`, and `names` (#639) +- Cleaned taichi AOT caches and enabled `numpy_func_return` setting (#643) +- Made taichi caches more transparent (#596) +- Enabled BrainPy objects as pytree for direct use with `jax.jit` (#625) + +### Object-Oriented Transformations +- Standardized and generalized object-oriented transformations (#628) + +### Development & Contributing +- Updated CONTRIBUTING.md with new guidelines +- Added CODEOWNERS file +- Updated SECURITY.md +- License updated to Apache License 2.0 + +### Removed +- Removed Docker workflow +- Removed hard dependencies on `taichi` and `numba` (#635) +- Removed op register functionality (#700) +- Removed deprecated deprecation files and old module structure +- Removed unnecessary dependencies (#703) + +### Migration Guide + +For users upgrading from BrainPy 2.6.x: + +1. **Keep using BrainPy 2.x API**: Replace imports with `brainpy` + ```python + # Old code (BrainPy 2.x) + import brainpy as bp + + # New code (BrainPy 3.0 with backward compatibility) + import brainpy as bp + ``` + +2. **Adopt new BrainPy 3.0 API**: Explore the simplified API in the main `brainpy` namespace for new projects + +3. **Update dependencies**: Ensure `brainstate>=0.2.0`, `brainevent>=0.0.4`, and `braintools>=0.0.9` are installed + +4. **Review breaking changes**: Check if your code uses any of the reorganized internal modules + +### Notes +- This release maintains backward compatibility through `brainpy` +- The new API in the main `brainpy` namespace represents the future direction of the library +- Documentation for both versions is available on Read the Docs + + +