Module adseq.bridges.flax_bridge
Functions
def superspike(x)-
Expand source code
def _superspike_step(x):
    """Heaviside step shared by superspike's forward pass and its JVP rule.

    Returns 0.0 where ``x < 0`` and 1.0 otherwise (``x == 0`` counts as a spike).
    """
    return jnp.where(x < 0, 0.0, 1.0)


@jax.custom_jvp
def superspike(x):
    """Spike nonlinearity: exact Heaviside forward, SuperSpike surrogate backward.

    Forward: 0.0 where ``x < 0`` and 1.0 otherwise.
    Derivative (see ``superspike_jvp``): the step's derivative is replaced by
    the surrogate ``1 / (|x| + 1) ** 2``.

    Reference: Zenke & Ganguli, "SuperSpike: Supervised Learning in Multilayer
    Spiking Neural Networks", Neural Computation (2018),
    https://doi.org/10.1162/neco_a_01086
    """
    return _superspike_step(x)


@superspike.defjvp
def superspike_jvp(primals, tangents):
    """JVP rule for ``superspike``.

    Returns the exact step as the primal output, and the input tangent scaled
    by the surrogate slope ``1 / (|x| + 1) ** 2`` as the tangent output.
    """
    (x,), (x_dot,) = primals, tangents
    # Use the same helper as the forward pass so the two cannot drift apart.
    primal_out = _superspike_step(x)
    # Surrogate slope: largest (1.0) at the threshold x == 0 and decaying with |x|.
    tangent_out = x_dot / (jnp.abs(x) + 1) ** 2
    return primal_out, tangent_out
Classes
class DelayedStaticSynapse (dt: float,
tau_syn1_ms: float = 0.5,
tau_syn2_ms: float = 2.0,
max_delay: float = 20.0,
delay_init: jax.nn.initializers.Initializer | Callable[..., typing.Any] = <function normal.<locals>.init>,
queue: type[BaseQueue] = adseq.implementations.fifo_ring.FIFORing[4],
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class DelayedStaticSynapse(nn.Module):
    """Delayed synapse driven by static (non-neuron-produced) spike input.

    Every input channel has its own learnable delay: the raw parameter
    ``'delay'`` (shape of the input) is passed through ``delay_activation``.
    The carry is ``(synapse, step)``. Each call advances the synapse by one
    ``dt`` and returns the synaptic current read *before* that advance.
    """

    dt: float
    tau_syn1_ms: float = 0.5
    tau_syn2_ms: float = 2.0
    max_delay: float = 20.
    delay_init: nn.initializers.Initializer = nn.initializers.normal(1.)
    # Maps the unconstrained 'delay' parameter into (0, 2 * max_delay).
    # NOTE(review): init_carry sizes the synapse with max_delay_ms=max_delay,
    # i.e. half of this range — confirm the queue tolerates delays above max_delay.
    delay_activation = lambda self, x: self.max_delay * (1 + nn.tanh(x))
    queue: type[implementations.BaseQueue] = implementations.FIFORing.sized(4)  # type: ignore

    def init_carry(self, s) -> DelaySynapseCarry:
        """Return the initial ``(synapse, step)`` carry; ``s`` is an example input (only its length is used)."""
        synapse = synapse2.mk_synapse2s(
            self.queue,
            vthres=0.0,
            tau_syn1_ms=self.tau_syn1_ms,
            tau_syn2_ms=self.tau_syn2_ms,
            dt_ms=self.dt,
            n=len(s),
            max_delay_ms=self.max_delay,
        )
        step = 0
        return synapse, step

    @nn.compact
    def __call__(self, carry: DelaySynapseCarry, s: jax.Array) -> tuple[DelaySynapseCarry, jax.Array]:
        """Advance one step.

        ``s`` is a per-channel binary spike indicator (1 = spike, 0 = no spike).
        Returns ``((synapse, step + 1), current)``, where ``current`` is the
        synaptic current of the incoming synapse state.
        """
        assert len(carry) == 2
        synapse, step = carry
        raw_delay = self.param('delay', self.delay_init, s.shape)
        delay_ms = self.delay_activation(raw_delay)
        current = synapse.isyn
        advanced = synapse.timestep_static_spike(t_ms=self.dt * step, s=s, delay_ms=delay_ms)
        return (advanced, step + 1), current
Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var dt : floatvar max_delay : floatvar name : str | Nonevar parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | Nonevar queue : type[BaseQueue]-
FIFORing(buffer, head, size)
var tau_syn1_ms : floatvar tau_syn2_ms : float
Methods
def delay_activation(self, x)-
Expand source code
delay_activation = lambda self, x: self.max_delay * (1 + nn.tanh(x)) def delay_init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact | None = None,
out_sharding: OutShardingType = None) ‑> jax.Array-
Expand source code
def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact | None = dtype, out_sharding: OutShardingType = None) -> Array: dtype = dtypes.default_float_dtype() if dtype is None else dtype return random.normal(key, shape, dtype, out_sharding=out_sharding) * jnp.array(stddev, dtype) def init_carry(self, s) ‑> tuple[StaticMultiSynapse, int] | tuple[StaticMultiSynapse, int, jax.Array]-
Expand source code
def init_carry(self, s) -> DelaySynapseCarry: 'Example input s' syn = synapse2.mk_synapse2s(self.queue, vthres=0.0, tau_syn1_ms=self.tau_syn1_ms, tau_syn2_ms=self.tau_syn2_ms, dt_ms=self.dt, n=len(s), max_delay_ms=self.max_delay) return syn, 0Example input s
class DelayedThresholdSynapse (dt: float,
tau_syn1_ms: float = 0.5,
tau_syn2_ms: float = 2.0,
max_delay: float = 20.0,
vthres: float = 1.0,
delay_init: jax.nn.initializers.Initializer | Callable[..., typing.Any] = <function normal.<locals>.init>,
queue: type[BaseQueue] = adseq.implementations.fifo_ring.FIFORing[4],
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class DelayedThresholdSynapse(nn.Module):
    """Delayed synapse that detects presynaptic spikes from input voltages.

    The synapse is stepped with a voltage pair ``(v, vnext)`` checked against
    ``vthres``. Each input channel has its own learnable delay: the raw
    parameter ``'delay'`` is passed through ``delay_activation``.

    Two carry layouts are supported:

    * ``(synapse, step)`` when the caller supplies ``vnext`` explicitly;
    * ``(synapse, step, v_prev)`` when ``vnext`` is omitted. The previous input
      is then stored, and the pair becomes ``(previous, current)``, so detection
      lags one step.
    """

    dt: float
    tau_syn1_ms: float = 0.5
    tau_syn2_ms: float = 2.0
    max_delay: float = 20.
    vthres: float = 1.0
    delay_init: nn.initializers.Initializer = nn.initializers.normal(1.)
    # Maps the unconstrained 'delay' parameter into (0, 2 * max_delay).
    # NOTE(review): init_carry sizes the synapse with max_delay_ms=max_delay,
    # i.e. half of this range — confirm the queue tolerates delays above max_delay.
    delay_activation = lambda self, x: self.max_delay * (1 + nn.tanh(x))
    queue: type[implementations.BaseQueue] = implementations.FIFORing.sized(4)  # type: ignore

    def init_carry(self, v: jax.Array, vnext: jax.Array | None = None) -> DelaySynapseCarry:
        """Return the initial carry for a 1-D voltage ``v``.

        If ``vnext`` is given (same shape as ``v``), the carry is
        ``(synapse, 0)``. Otherwise it is ``(synapse, 0, zeros_like(v))`` and
        detection runs one timestep behind.
        """
        assert len(v.shape) == 1
        synapse = synapse2.mk_synapse2s(
            self.queue,
            vthres=self.vthres,
            tau_syn1_ms=self.tau_syn1_ms,
            tau_syn2_ms=self.tau_syn2_ms,
            dt_ms=self.dt,
            n=len(v),
            max_delay_ms=self.max_delay,
        )
        if vnext is not None:
            assert vnext.shape == v.shape
            return synapse, 0  # type: ignore
        # No look-ahead sample available: carry a zero "previous" voltage.
        return synapse, 0, 0*v  # type: ignore

    @nn.compact
    def __call__(self, carry: DelaySynapseCarry, v: jax.Array, vnext: jax.Array | None = None) -> tuple[DelaySynapseCarry, jax.Array]:
        """Advance one step; returns ``(new_carry, current)``.

        ``current`` is the synaptic current of the incoming synapse state.
        """
        lagged = vnext is None
        if lagged:
            assert len(carry) == 3
            synapse, step, v_prev = carry
            # Detect on (previous, current) and remember the current sample for next step.
            v_now, v_after = v_prev, v
        else:
            assert len(carry) == 2
            synapse, step = carry
            v_now, v_after = v, vnext
        delay_ms = self.delay_activation(self.param('delay', self.delay_init, v_now.shape))
        current = synapse.isyn
        synapse = synapse.timestep_spike_detect_pre(t_ms=self.dt * step, v=v_now, vnext=v_after, delay_ms=delay_ms)
        if lagged:
            return (synapse, step + 1, v_after), current
        return (synapse, step + 1), current
Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var dt : floatvar max_delay : floatvar name : str | Nonevar parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | Nonevar queue : type[BaseQueue]-
FIFORing(buffer, head, size)
var tau_syn1_ms : floatvar tau_syn2_ms : floatvar vthres : float
Methods
def delay_activation(self, x)-
Expand source code
delay_activation = lambda self, x: self.max_delay * (1 + nn.tanh(x)) def delay_init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact | None = None,
out_sharding: OutShardingType = None) ‑> jax.Array-
Expand source code
def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact | None = dtype, out_sharding: OutShardingType = None) -> Array: dtype = dtypes.default_float_dtype() if dtype is None else dtype return random.normal(key, shape, dtype, out_sharding=out_sharding) * jnp.array(stddev, dtype) def init_carry(self, v: jax.Array, vnext: None | jax.Array = None) ‑> tuple[StaticMultiSynapse, int] | tuple[StaticMultiSynapse, int, jax.Array]-
Expand source code
def init_carry(self, v:jax.Array, vnext: jax.Array|None=None) -> DelaySynapseCarry: 'if vnext is none, we delay one timestep' assert len(v.shape) == 1 syn = synapse2.mk_synapse2s(self.queue, vthres=self.vthres, tau_syn1_ms=self.tau_syn1_ms, tau_syn2_ms=self.tau_syn2_ms, dt_ms=self.dt, n=len(v), max_delay_ms=self.max_delay) if vnext is None: return syn, 0, 0*v # type: ignore else: assert vnext.shape == v.shape return syn, 0 # type: ignoreif vnext is none, we delay one timestep
class Dense (dt: float,
nout: int | None = None,
weight_init: jax.nn.initializers.Initializer | Callable[..., typing.Any] = <function uniform.<locals>.init>,
delay_init: jax.nn.initializers.Initializer | Callable[..., typing.Any] = <function normal.<locals>.init>,
queue: type[BaseQueue] = adseq.implementations.fifo_ring.FIFORing[4],
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class Dense(nn.Module):
    """Fully connected layer of delayed, threshold-detecting synapses.

    Implemented as a ``Sequential`` of three stages:

    1. ``Explode`` repeats the input vector once per output target;
    2. ``DelayedThresholdSynapse`` gives every (input, output) pair its own
       learnable delay and detects spikes from the input voltages;
    3. ``LTIReduce`` applies one learnable weight per pair and sums the
       currents per output, giving a length-``nout`` vector.

    If ``nout`` is None, the layer is square (``nout == nin``).

    ``tau_syn1_ms``, ``tau_syn2_ms``, ``max_delay`` and ``vthres`` are
    forwarded to the synapse stage. Their defaults equal that stage's defaults,
    so layers built without them behave exactly as before.
    """

    dt: float
    nout: int | None = None
    weight_init: nn.initializers.Initializer = nn.initializers.uniform(1.5)
    delay_init: nn.initializers.Initializer = nn.initializers.normal(1.)
    queue: type[implementations.BaseQueue] = implementations.FIFORing.sized(4)  # type: ignore
    # Synapse hyperparameters; appended after the existing fields so
    # positional construction stays backward compatible.
    tau_syn1_ms: float = 0.5
    tau_syn2_ms: float = 2.0
    max_delay: float = 20.
    vthres: float = 1.0

    def setup(self):
        self.model = Sequential([
            Explode(self.nout),
            DelayedThresholdSynapse(
                self.dt,
                tau_syn1_ms=self.tau_syn1_ms,
                tau_syn2_ms=self.tau_syn2_ms,
                max_delay=self.max_delay,
                vthres=self.vthres,
                delay_init=self.delay_init,
                queue=self.queue,
            ),
            LTIReduce(self.nout, self.weight_init),
        ])

    def init_carry(self, x):
        """Return the per-stage carry list for example input ``x``."""
        return self.model.init_carry(x)

    def __call__(self, carry, x):
        """Advance one step; returns ``(new_carry, per_output_current)``."""
        return self.model(carry, x)
.init at 0x737fc501e5c0>, delay_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[…, Any]] = .init at 0x737fc501e660>, queue: type[adseq.implementations.base.BaseQueue] = , parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = , name: Optional[str] = None) Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var dt : floatvar name : str | Nonevar nout : int | Nonevar parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | Nonevar queue : type[BaseQueue]-
FIFORing(buffer, head, size)
Methods
def delay_init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact | None = None,
out_sharding: OutShardingType = None) ‑> jax.Array-
Expand source code
def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact | None = dtype, out_sharding: OutShardingType = None) -> Array: dtype = dtypes.default_float_dtype() if dtype is None else dtype return random.normal(key, shape, dtype, out_sharding=out_sharding) * jnp.array(stddev, dtype) def init_carry(self, x)-
Expand source code
def init_carry(self, x): return self.model.init_carry(x) def setup(self)-
Expand source code
def setup(self): self.model = Sequential([ Explode(self.nout), DelayedThresholdSynapse(self.dt, delay_init=self.delay_init, queue=self.queue), LTIReduce(self.nout, self.weight_init) ]) def weight_init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact | None = None,
out_sharding: OutShardingType = None) ‑> jax.Array-
Expand source code
def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact | None = dtype, out_sharding: OutShardingType = None) -> Array: dtype = dtypes.default_float_dtype() if dtype is None else dtype return random.uniform(key, shape, dtype, out_sharding=out_sharding) * jnp.array(scale, dtype)
class DenseInput (dt: float,
nout: int | None = None,
weight_init: jax.nn.initializers.Initializer | Callable[..., typing.Any] = <function uniform.<locals>.init>,
delay_init: jax.nn.initializers.Initializer | Callable[..., typing.Any] = <function normal.<locals>.init>,
queue: type[BaseQueue] = adseq.implementations.fifo_ring.FIFORing[4],
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class DenseInput(nn.Module):
    """Fully connected input layer of delayed synapses driven by static spikes.

    Same structure as ``Dense`` (Explode -> delayed synapse -> LTIReduce), but
    uses ``DelayedStaticSynapse``. The input is therefore a per-channel binary
    spike indicator, not a voltage. If ``nout`` is None, the layer is square
    (``nout == nin``).

    ``tau_syn1_ms``, ``tau_syn2_ms`` and ``max_delay`` are forwarded to the
    synapse stage. Their defaults equal that stage's defaults, so layers built
    without them behave exactly as before.
    """

    dt: float
    nout: int | None = None
    weight_init: nn.initializers.Initializer = nn.initializers.uniform(1.5)
    delay_init: nn.initializers.Initializer = nn.initializers.normal(1.)
    queue: type[implementations.BaseQueue] = implementations.FIFORing.sized(4)  # type: ignore
    # Synapse hyperparameters; appended after the existing fields so
    # positional construction stays backward compatible.
    tau_syn1_ms: float = 0.5
    tau_syn2_ms: float = 2.0
    max_delay: float = 20.

    def setup(self):
        self.model = Sequential([
            Explode(self.nout),
            DelayedStaticSynapse(
                self.dt,
                tau_syn1_ms=self.tau_syn1_ms,
                tau_syn2_ms=self.tau_syn2_ms,
                max_delay=self.max_delay,
                delay_init=self.delay_init,
                queue=self.queue,
            ),
            LTIReduce(self.nout, self.weight_init),
        ])

    def init_carry(self, x):
        """Return the per-stage carry list for example input ``x``."""
        return self.model.init_carry(x)

    def __call__(self, carry, x):
        """Advance one step; returns ``(new_carry, per_output_current)``."""
        return self.model(carry, x)
.init at 0x737fc501dc60>, delay_init: Union[jax.nn.initializers.Initializer, collections.abc.Callable[…, Any]] = .init at 0x737fc501dd00>, queue: type[adseq.implementations.base.BaseQueue] = , parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = , name: Optional[str] = None) Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var dt : floatvar name : str | Nonevar nout : int | Nonevar parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | Nonevar queue : type[BaseQueue]-
FIFORing(buffer, head, size)
Methods
def delay_init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact | None = None,
out_sharding: OutShardingType = None) ‑> jax.Array-
Expand source code
def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact | None = dtype, out_sharding: OutShardingType = None) -> Array: dtype = dtypes.default_float_dtype() if dtype is None else dtype return random.normal(key, shape, dtype, out_sharding=out_sharding) * jnp.array(stddev, dtype) def init_carry(self, x)-
Expand source code
def init_carry(self, x): return self.model.init_carry(x) def setup(self)-
Expand source code
def setup(self): self.model = Sequential([ Explode(self.nout), DelayedStaticSynapse(self.dt, delay_init=self.delay_init, queue=self.queue), LTIReduce(self.nout, self.weight_init) ]) def weight_init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact | None = None,
out_sharding: OutShardingType = None) ‑> jax.Array-
Expand source code
def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact | None = dtype, out_sharding: OutShardingType = None) -> Array: dtype = dtypes.default_float_dtype() if dtype is None else dtype return random.uniform(key, shape, dtype, out_sharding=out_sharding) * jnp.array(scale, dtype)
class Explode (nout: int | None = None,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class Explode(nn.Module):
    """Helper module for Dense LTI synapses: repeat the input once per output target.

    A 1-D vector ``v`` of length ``nin`` becomes ``nout`` back-to-back copies
    (length ``nin * nout``), so element ``o * nin + i`` equals ``v[i]``.
    If ``nout`` is None, ``v`` is repeated ``len(v)`` times (square layout).
    """

    nout: int | None = None

    @nn.compact
    def __call__(self, v: jax.Array):
        assert len(v.shape) == 1
        repeats = v.shape[-1] if self.nout is None else self.nout
        return jnp.tile(v, repeats)
Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var name : str | Nonevar nout : int | Nonevar parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None
class LIF (dt: float,
tau_mem: float = 10.0,
vthres: float = 1.0,
reset_gradient: Literal['surrogate'] = 'surrogate',
output: Literal['voltage'] | Literal['single_spike'] | Literal['ttfs'] | Literal['superspike'] = 'voltage',
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class LIF(nn.Module):
    """Leaky integrate-and-fire layer with a selectable per-step output.

    The membrane is simulated by ``SurrogateLIF`` (the only supported
    ``reset_gradient`` is ``'surrogate'``). ``output`` selects what is emitted:

    * ``'voltage'``: the membrane voltage. The carry is the neuron carry.
    * ``'superspike'``: surrogate-gradient spikes (``SurrogateSpikeFilter``).
    * ``'single_spike'``: voltage until the first threshold crossing, then held
      (``SingleSpikeFilter``).
    * ``'ttfs'``: differentiable first-spike time (``TTFSFilter``).

    For every output other than ``'voltage'``, the carry is
    ``(neuron_carry, filter_carry)``.

    Raises:
        ValueError: (at setup) if ``output`` is not one of the values above.
    """

    dt: float
    tau_mem: float = 10.
    vthres: float = 1.0
    reset_gradient: typing.Literal['surrogate'] = 'surrogate'
    output: typing.Literal['voltage'] | typing.Literal['single_spike'] | typing.Literal['ttfs'] | typing.Literal['superspike'] = 'voltage'

    def setup(self):
        assert self.reset_gradient == 'surrogate'
        self.model = SurrogateLIF(self.dt, self.tau_mem, self.vthres)
        if self.output == 'voltage':
            self.model_output = None
        elif self.output == 'superspike':
            self.model_output = SurrogateSpikeFilter(self.dt, self.vthres)
        elif self.output == 'single_spike':
            self.model_output = SingleSpikeFilter(self.dt, self.vthres)
        elif self.output == 'ttfs':
            self.model_output = TTFSFilter(self.dt, self.vthres)
        else:
            # Previously model_output was left unassigned here, which surfaced
            # later as an unrelated-looking AttributeError.
            raise ValueError(
                f"unknown LIF output {self.output!r}; expected one of "
                "'voltage', 'superspike', 'single_spike', 'ttfs'"
            )

    def init_carry(self, isyn):
        """Build the initial carry from an example input current ``isyn``."""
        neuron_carry = self.model.init_carry(isyn)
        if self.model_output is None:
            return neuron_carry
        # Run the neuron once to obtain an example voltage for the filter's carry.
        _carry, v = self.model(neuron_carry, isyn)
        filter_carry = self.model_output.init_carry(v)
        return neuron_carry, filter_carry

    def __call__(self, carry, isyn):
        """Advance one step; returns ``(new_carry, output)``."""
        if self.model_output is None:
            carry, out = self.model(carry, isyn)
        else:
            c0, c1 = carry
            c0, v = self.model(c0, isyn)
            c1, out = self.model_output(c1, v)
            carry = c0, c1
        return carry, out
, name: Optional[str] = None) Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var dt : floatvar name : str | Nonevar output : Literal['voltage'] | Literal['single_spike'] | Literal['ttfs'] | Literal['superspike']var parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | Nonevar reset_gradient : Literal['surrogate']var tau_mem : floatvar vthres : float
Methods
def init_carry(self, isyn)-
Expand source code
def init_carry(self, isyn): carry = self.model.init_carry(isyn) if self.model_output is None: return carry _carry, v = self.model(carry, isyn) carry_output = self.model_output.init_carry(v) return carry, carry_output def setup(self)-
Expand source code
def setup(self): assert self.reset_gradient == 'surrogate' self.model = SurrogateLIF(self.dt, self.tau_mem, self.vthres) if self.output == 'voltage': self.model_output = None elif self.output == 'superspike': self.model_output = SurrogateSpikeFilter(self.dt, self.vthres) elif self.output == 'single_spike': self.model_output = SingleSpikeFilter(self.dt, self.vthres) elif self.output == 'ttfs': self.model_output = TTFSFilter(self.dt, self.vthres)
class LTIReduce (nout: int | None = None,
weight_init: jax.nn.initializers.Initializer | Callable[..., typing.Any] = <function normal.<locals>.init>,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class LTIReduce(nn.Module):
    """Helper module for Dense LTI synapses: weighted sum down to ``nout`` features.

    Expects a 1-D vector of ``nin * nout`` synaptic currents laid out as
    ``nout`` consecutive blocks of ``nin`` entries (the layout produced by
    tiling the input ``nout`` times). Each entry has its own learnable weight
    (parameter ``'weight'``, same shape as the input). Each block's weighted
    currents are summed, returning a length-``nout`` vector. If ``nout`` is
    None, the layout is assumed square (``nin == nout``).
    """

    nout: int | None = None
    weight_init: nn.initializers.Initializer = nn.initializers.normal()

    @nn.compact
    def __call__(self, isyn: jax.Array):
        # Check the rank before indexing the shape: a 0-d input used to fail
        # with an IndexError instead of this assertion. Batched input is not supported.
        assert len(isyn.shape) == 1, f"LTIReduce expects a 1-D input, got shape {isyn.shape}"
        nsyn = isyn.shape[-1]
        if self.nout is not None:
            nout = self.nout
            nin = nsyn // nout
        else:
            nin = nout = int(nsyn ** 0.5)
        assert nsyn == nin * nout, f"input length {nsyn} is not nin * nout (nin={nin}, nout={nout})"
        weight = self.param('weight', self.weight_init, isyn.shape)
        # Row o of the (nout, nin) view holds the currents feeding output o.
        return (isyn * weight).reshape(nout, nin).sum(1)
Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var name : str | Nonevar nout : int | Nonevar parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None
Methods
def weight_init(key: Array,
shape: core.Shape,
dtype: DTypeLikeInexact | None = None,
out_sharding: OutShardingType = None) ‑> jax.Array-
Expand source code
def init(key: Array, shape: core.Shape, dtype: DTypeLikeInexact | None = dtype, out_sharding: OutShardingType = None) -> Array: dtype = dtypes.default_float_dtype() if dtype is None else dtype return random.normal(key, shape, dtype, out_sharding=out_sharding) * jnp.array(stddev, dtype)
class Sequential (layers: List[flax.linen.module.Module],
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class Sequential(nn.Module):
    """Chain of layers, each threaded through its own recurrent carry.

    Layers that define ``init_carry`` are called as
    ``layer(carry, x) -> (carry, x)``. All others are called as
    ``layer(x) -> x`` and get a ``None`` carry slot.
    """

    layers: typing.List[nn.Module]

    def init_carry(self, x):
        """Build one carry per layer from example input ``x``.

        Each stateful layer's carry is built from the input that layer will
        see. Obtaining that input means running the layer once on ``x``; the
        carry produced by that run is discarded.
        """
        carries = []
        for layer in self.layers:
            if hasattr(layer, 'init_carry'):
                carry = layer.init_carry(x)
                _carry, x = layer(carry, x)
            else:
                carry = None
                x = layer(x)
            carries.append(carry)
        return carries

    def trace(self, xs, output_all=False):
        """Run the chain over the leading axis of ``xs``; return the stacked per-step outputs."""
        carry = self.init_carry(xs[0])
        # NOTE(review): this scans Flax submodule calls with a raw jax.lax.scan
        # rather than a lifted nn.scan; depending on mode/version Flax may reject
        # that (JaxTransformError) — confirm this path is exercised.
        _, ys = jax.lax.scan(lambda c, x: self.__call__(c, x, output_all), carry, xs)
        return ys

    def __call__(self, carry, x, output_all=False):
        """Advance every layer by one step.

        ``carry`` is the per-layer list from ``init_carry``, or None to build it
        from ``x``. Returns ``(new_carry, y)``: ``y`` is the last layer's output,
        or the list of every layer's output when ``output_all`` is True.

        Raises:
            ValueError: if ``carry`` does not have exactly one entry per layer.
        """
        if carry is None:
            carry = self.init_carry(x)
        carry_out = []
        outputs = []
        # strict=True: a carry whose length differs from the layer list is a bug;
        # plain zip would silently skip the trailing layers.
        for c, layer in zip(carry, self.layers, strict=True):
            if c is None:
                x = layer(x)
            else:
                c, x = layer(c, x)
            carry_out.append(c)
            outputs.append(x)
        if output_all:
            return carry_out, outputs
        return carry_out, x
, name: Optional[str] = None) Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var layers : List[flax.linen.module.Module]var name : str | Nonevar parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None
Methods
def init_carry(self, x)-
Expand source code
def init_carry(self, x): cs = [] for layer in self.layers: if hasattr(layer, 'init_carry'): carry = layer.init_carry(x) _carry, x = layer(carry, x) else: carry = None x = layer(x) cs.append(carry) return cs def trace(self, xs, output_all=False)-
Expand source code
def trace(self, xs, output_all=False): carry = self.init_carry(xs[0]) carry, ys = jax.lax.scan(lambda c, x: self.__call__(c, x, output_all), carry, xs) return ys
class SingleSpikeFilter (dt: float = None,
vthres: float = 1.0,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class SingleSpikeFilter(nn.Module):
    """Pass the voltage through until the first spike, then hold it.

    Emits the input voltage up to and including the first step where it
    reaches ``vthres``. From then on it emits that first supra-threshold value,
    held with ``stop_gradient``, so no gradient flows after the spike.

    Carry layouts follow ``init_carry``: ``(vhold, step)`` when ``vnext`` is
    passed explicitly, or ``(vhold, step, v_prev)`` when it is omitted (one-step
    lag).
    """

    # Not used by this filter; accepted so LIF can construct every filter as Filter(dt, vthres).
    dt: float | None = None
    vthres: float = 1.0

    def init_carry(self, v: jax.Array, vnext: jax.Array | None = None) -> TTFSCarry:
        """Return the initial carry for a 1-D voltage ``v``; without ``vnext`` the filter lags one timestep."""
        assert len(v.shape) == 1
        if vnext is None:
            return 0*v, 0, 0*v
        else:
            assert vnext.shape == v.shape
            return 0*v, 0

    @nn.compact
    def __call__(self, carry: TTFSCarry, v: jax.Array, vnext: jax.Array | None = None) -> tuple[TTFSCarry, jax.Array]:
        """Advance one step; returns ``(new_carry, out)``."""
        if vnext is None:
            assert len(carry) == 3
            vhold, ts, vprev = carry
            v, vnext = vprev, v
        else:
            assert len(carry) == 2
            vhold, ts = carry
        out = jnp.where(vhold >= self.vthres, vhold, v)
        # Latch the first supra-threshold value; earlier values are overwritten each step.
        vhold = jnp.where(vhold >= self.vthres, vhold, jax.lax.stop_gradient(v))
        if len(carry) == 3:
            return (vhold, ts+1, vnext), out
        else:
            return (vhold, ts+1), out
Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var dt : floatvar name : str | Nonevar parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | Nonevar vthres : float
Methods
def init_carry(self, v: jax.Array, vnext: None | jax.Array = None) ‑> tuple[jax.Array, int] | tuple[jax.Array, int, jax.Array]-
Expand source code
def init_carry(self, v:jax.Array, vnext: jax.Array|None=None) -> TTFSCarry: 'if vnext is none, we delay one timestep' assert len(v.shape) == 1 if vnext is None: return 0*v, 0, 0*v else: assert vnext.shape == v.shape return 0*v, 0if vnext is none, we delay one timestep
class StaticMultiSynapse-
Expand source code
class StaticMultiSynapse(abc.ABC): @property @abc.abstractmethod def isyn(self) -> jax.Array: ... @abc.abstractmethod def timestep_spike_detect_pre(self, t_ms, v, vnext, delay_ms) -> typing.Self: ... # type: ignore @abc.abstractmethod def timestep_static_spike(self, t_ms, s, delay_ms) -> typing.Self: ... # type: ignoreHelper class that provides a standard way to create an ABC using inheritance.
Ancestors
- abc.ABC
Instance variables
prop isyn : jax.Array-
Expand source code
@property @abc.abstractmethod def isyn(self) -> jax.Array: ...
Methods
def timestep_spike_detect_pre(self, t_ms, v, vnext, delay_ms) ‑> Self-
Expand source code
@abc.abstractmethod def timestep_spike_detect_pre(self, t_ms, v, vnext, delay_ms) -> typing.Self: ... # type: ignore def timestep_static_spike(self, t_ms, s, delay_ms) ‑> Self-
Expand source code
@abc.abstractmethod def timestep_static_spike(self, t_ms, s, delay_ms) -> typing.Self: ... # type: ignore
class SurrogateLIF (dt: float,
tau_mem: float = 10.0,
vthres: float = 1.0,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class SurrogateLIF(nn.Module):
    """Leaky integrate-and-fire membrane whose reset is differentiated through a surrogate.

    Per step, with ``v`` the carried voltage:
    ``spike = superspike(v - vthres)`` and
    ``v_new = (1 - spike) * (exp(-dt / tau_mem) * v + isyn * dt)``.
    The carry is the voltage, and the emitted value is ``v`` (the voltage
    before this step's update).
    """

    dt: float
    tau_mem: float = 10.
    vthres: float = 1.0

    def init_carry(self, isyn):
        """Zero membrane voltage with the shape of ``isyn``."""
        return isyn * 0

    @nn.compact
    def __call__(self, carry: LIFCarry, isyn: jax.Array) -> tuple[LIFCarry, jax.Array]:
        v = carry
        spiked = superspike(v - self.vthres)
        decay = jnp.exp(-self.dt / self.tau_mem)
        integrated = decay * v + isyn * self.dt
        # Multiplicative reset, so the reset's gradient flows through superspike's surrogate.
        v_new = (1 - spiked) * integrated
        return v_new, v
, name: Optional[str] = None) Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var dt : floatvar name : str | Nonevar parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | Nonevar tau_mem : floatvar vthres : float
Methods
def init_carry(self, isyn)-
Expand source code
def init_carry(self, isyn): return isyn*0
class SurrogateSpikeFilter (dt: float = None,
vthres: float = 1.0,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class SurrogateSpikeFilter(nn.Module):
    """Stateless filter mapping voltages to spikes with a surrogate gradient.

    Emits ``superspike(v - vthres)``: 1.0 where ``v >= vthres`` and 0.0
    otherwise, with the SuperSpike surrogate derivative. It carries no state;
    the carry is None.
    """

    # Not used by this filter; accepted so LIF can construct every filter as Filter(dt, vthres).
    dt: float | None = None
    vthres: float = 1.0

    def init_carry(self, v):
        """No state to initialise; always returns None."""
        return None

    def __call__(self, carry, v):
        """Return ``(carry, spikes)``, with the carry passed through unchanged."""
        S = superspike(v - self.vthres)
        return carry, S
, name: Optional[str] = None) Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var dt : floatvar name : str | Nonevar parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | Nonevar vthres : float
Methods
def init_carry(self, v)-
Expand source code
def init_carry(self, v): return None
class TTFSFilter (dt: float,
vthres: float = 1.0,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None)-
Expand source code
class TTFSFilter(nn.Module):
    """Receive voltages and output a differentiable time-to-first-spike per channel.

    Each step, ``synapse2.spike_detect`` is applied per channel to the pair
    ``(v, vnext)`` at time ``dt * step``. A result of ``-1`` means no spike. The
    earliest detected time is carried and emitted; the output stays ``-1``
    until the first spike.

    Carry layouts follow ``init_carry``: ``(ttfs, step)`` when ``vnext`` is
    passed explicitly, or ``(ttfs, step, v_prev)`` when it is omitted (one-step
    lag).
    """

    dt: float
    vthres: float = 1.0

    def init_carry(self, v: jax.Array, vnext: jax.Array | None = None) -> TTFSCarry:
        """Return the initial carry (ttfs = -1) for a 1-D voltage ``v``; without ``vnext`` the filter lags one timestep."""
        assert len(v.shape) == 1
        if vnext is None:
            return -1 + 0*v, 0, 0*v
        else:
            assert vnext.shape == v.shape
            return -1 + 0*v, 0

    @nn.compact
    def __call__(self, carry: TTFSCarry, v: jax.Array, vnext: jax.Array | None = None) -> tuple[TTFSCarry, jax.Array]:
        """Advance one step; returns ``(new_carry, ttfs)``."""
        if vnext is None:
            assert len(carry) == 3
            ttfs, ts, vprev = carry
            v, vnext = vprev, v
        else:
            assert len(carry) == 2
            ttfs, ts = carry
        ttfs: jax.Array
        tpost = jax.vmap(synapse2.spike_detect, in_axes=[None, None, None, 0, 0, None])(self.dt, self.dt*ts, self.vthres, v, vnext, 0.)
        # Keep the earliest spike time: adopt tpost if nothing was recorded yet,
        # otherwise take the minimum with any new detection.
        ttfs = jnp.where((ttfs != -1), jnp.where((tpost != -1), jnp.minimum(ttfs, tpost), ttfs), tpost)
        if len(carry) == 3:
            return (ttfs, ts+1, vnext), ttfs
        else:
            return (ttfs, ts+1), ttfs
Ancestors
- flax.linen.module.Module
- flax.linen.module.ModuleBase
Class variables
var scope
Instance variables
var dt : floatvar name : str | Nonevar parent : flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | Nonevar vthres : float
Methods
def init_carry(self, v: jax.Array, vnext: None | jax.Array = None) ‑> tuple[jax.Array, int] | tuple[jax.Array, int, jax.Array]-
Expand source code
def init_carry(self, v:jax.Array, vnext: jax.Array|None=None) -> TTFSCarry: 'if vnext is none, we delay one timestep' assert len(v.shape) == 1 if vnext is None: return -1 + 0*v, 0, 0*v else: assert vnext.shape == v.shape return -1 + 0*v, 0if vnext is none, we delay one timestep