Module adseq.bridges.jaxley_bridge
Classes
class DelaySynapse (name: str | None = None,
Q=adseq.implementations.single_spike.SingleSpike,
vthres=10.0)-
Expand source code
class DelaySynapse(Synapse):
    """Synapse that delivers presynaptic spikes through a delay queue.

    Each synapse keeps a spike queue (created by ``Q.init``), the previous
    presynaptic voltage, a running time ``ts`` and two exponentially decaying
    traces ``isyn1`` / ``isyn2``. The output current is the difference of the
    two traces, normalised by the value that difference takes at ``t_peak``
    and scaled by the ``weight`` parameter (see ``compute_current``).

    The queue is stored in the Jaxley state dict as its flattened pytree
    leaves, one ``{name}_queue{i}`` entry per leaf, and is rebuilt from those
    leaves with the saved tree structure on every update.

    Args:
        name: Optional synapse name, forwarded to ``Synapse.__init__``; the
            resulting ``self._name`` prefixes every parameter and state key.
        Q: Queue implementation. Must provide ``init(delay, grad=...)`` and
            return an object with ``enqueue`` and ``pop`` methods.
        vthres: Threshold passed to ``synapse2.spike_detect``.
    """

    def __init__(self, name: typing.Optional[str] = None, Q=SingleSpike, vthres=10.):
        super().__init__(name)
        prefix = self._name

        queue = Q.init(None, grad=True)
        queue_list, self.struct = jax.tree.flatten(queue)
        self.struct_size = len(queue_list)

        self.synapse_params = {
            f'{prefix}_tau1': 0.5,  # ms
            f'{prefix}_tau2': 2,  # ms
            f'{prefix}_delay': 20.,  # ms
            f'{prefix}_weight': 0.01,
        }

        # Store the flattened leaves themselves (not ``queue[i]``) so the
        # state layout always matches what ``update_states`` unflattens,
        # even for queue types that are not flat tuples.
        self.synapse_states: dict[str, typing.Any] = {
            f'{prefix}_queue{i}': leaf for i, leaf in enumerate(queue_list)
        }
        self.synapse_states[f'{prefix}_vprev'] = 0.
        self.synapse_states[f'{prefix}_isyn1'] = 0.
        self.synapse_states[f'{prefix}_isyn2'] = 0.
        self.synapse_states[f'{prefix}_ts'] = 0.
        self.vthres = vthres

    def update_states(
        self,
        states: dict,
        delta_t: float,
        pre_voltage: float,
        post_voltage: float,
        params: dict,
    ) -> dict:
        """Advance all synapse states by one time step.

        Args:
            states: Current synapse states, keyed with the ``self._name``
                prefix. The per-synapse step is applied with ``jax.vmap``,
                so every entry carries a leading synapse axis.
            delta_t: Time step (same unit as the ``tau``/``delay``
                parameters, which are annotated as ms in ``__init__``).
            pre_voltage: Presynaptic voltage at the current step.
            post_voltage: Postsynaptic voltage; not used by this synapse.
            params: Synapse parameters, keyed with the ``self._name`` prefix.

        Returns:
            A new state dict with updated queue leaves, ``vprev`` set to
            ``pre_voltage``, the updated traces and ``ts`` advanced by
            ``delta_t``.
        """
        prefix = self._name
        queues = self.struct.unflatten([
            states[f'{prefix}_queue{i}'] for i in range(self.struct_size)
        ])
        ts = states[f'{prefix}_ts']
        delay_ms = params[f'{prefix}_delay']
        tau1_syn_ms = params[f'{prefix}_tau1']
        tau2_syn_ms = params[f'{prefix}_tau2']

        def timestep(ts, queue, isyn1, isyn2, v, vnext,
                     delay_ms, tau1_syn_ms, tau2_syn_ms):
            # Per-step decay factors of the two traces.
            alpha = jnp.exp(-delta_t / tau1_syn_ms)  # inefficient
            beta = jnp.exp(-delta_t / tau2_syn_ms)  # inefficient
            # ``-1`` is treated as "no spike" below; any other value is
            # converted to a timestep and enqueued.
            # NOTE(review): presumably tpost is the delayed delivery time --
            # confirm against synapse2.spike_detect.
            tpost = synapse2.spike_detect(delta_t, ts, self.vthres, v, vnext, delay_ms)
            queue = jax.lax.cond(
                tpost != -1,  # must be a better solution
                lambda: queue.enqueue(synapse2.time_to_timestep_keep_gradient(tpost, delta_t)),  # type: ignore
                lambda: queue)
            queue, post_hit = queue.pop(synapse2.time_to_timestep_keep_gradient(ts, delta_t))
            jump1 = synapse2.apply_recv_gradient(post_hit, tau1_syn_ms)
            jump2 = synapse2.apply_recv_gradient(post_hit, tau2_syn_ms)
            isyn1 = alpha * isyn1 + jump1
            isyn2 = beta * isyn2 + jump2
            return (queue, isyn1, isyn2)

        vprev = states[f'{prefix}_vprev']
        isyn1 = states[f'{prefix}_isyn1']
        isyn2 = states[f'{prefix}_isyn2']

        queues, isyn1, isyn2 = jax.vmap(timestep)(
            ts,  # ugly
            queues, isyn1, isyn2,
            v=vprev, vnext=pre_voltage, delay_ms=delay_ms,
            tau1_syn_ms=tau1_syn_ms, tau2_syn_ms=tau2_syn_ms)

        queue_parts, _struct = jax.tree.flatten(queues)
        state_out: dict[str, typing.Any] = {
            f'{prefix}_queue{i}': queue_parts[i]
            for i in range(self.struct_size)
        }
        state_out[f'{prefix}_vprev'] = pre_voltage
        state_out[f'{prefix}_isyn1'] = isyn1
        state_out[f'{prefix}_isyn2'] = isyn2
        state_out[f'{prefix}_ts'] = states[f'{prefix}_ts'] + delta_t
        return state_out

    def compute_current(
        self, states: dict, pre_voltage: float, post_voltage: float, params: dict
    ) -> float:
        """Return the synaptic current from the two traces.

        The result is ``-(isyn2 - isyn1) / denom * weight`` where ``denom``
        is ``exp(-t_peak / tau2) - exp(-t_peak / tau1)`` and ``t_peak`` is
        ``tau1 * tau2 / (tau2 - tau1) * log(tau2 / tau1)``.

        ``tau1`` and ``tau2`` must differ: equal values make both ``t_peak``
        and ``denom`` divide by zero.

        Args:
            states: Synapse states; only ``isyn1`` and ``isyn2`` are read.
            pre_voltage: Unused.
            post_voltage: Unused.
            params: Synapse parameters; ``tau1``, ``tau2`` and ``weight``
                are read.

        Returns:
            The normalised, weighted trace difference.
        """
        prefix = self._name
        tau_syn1_ms = params[f'{prefix}_tau1']
        tau_syn2_ms = params[f'{prefix}_tau2']
        t_peak = (tau_syn1_ms * tau_syn2_ms / (tau_syn2_ms - tau_syn1_ms)
                  * jnp.log(tau_syn2_ms / tau_syn1_ms))
        denom = (jnp.exp(-t_peak / tau_syn2_ms)
                 - jnp.exp(-t_peak / tau_syn1_ms))
        return -(states[f'{prefix}_isyn2']-states[f'{prefix}_isyn1'])/denom * params[f'{prefix}_weight']
As in NEURON, a `Synapse` is considered a point process, which means that its conductances are to be specified in `uS` and its currents are to be specified in `nA`.
Ancestors
- jaxley.synapses.synapse.Synapse
Methods
def compute_current(self, states: dict, pre_voltage: float, post_voltage: float, params: dict) ‑> float-
Expand source code
def compute_current( self, states: dict, pre_voltage: float, post_voltage: float, params: dict ) -> float: prefix = self._name tau_syn1_ms = params[f'{prefix}_tau1'] tau_syn2_ms = params[f'{prefix}_tau2'] t_peak = (tau_syn1_ms * tau_syn2_ms / (tau_syn2_ms - tau_syn1_ms) * jnp.log(tau_syn2_ms / tau_syn1_ms)) denom = (jnp.exp(-t_peak / tau_syn2_ms) - jnp.exp(-t_peak / tau_syn1_ms)) return -(states[f'{prefix}_isyn2']-states[f'{prefix}_isyn1'])/denom * params[f'{prefix}_weight']Return current through one synapse in
`nA`.
Internally, we use `jax.vmap` to vectorize this function across many synapses.
Args
states - States of the synapse.
pre_voltage - Voltage of the presynaptic compartment, shape `()`.
post_voltage - Voltage of the postsynaptic compartment, shape `()`.
params - Parameters of the synapse. Conductances in `uS`.
Returns
Current through the synapse in `nA`, shape `()`.
def update_states(self,
states: dict,
delta_t: float,
pre_voltage: float,
post_voltage: float,
params: dict) ‑> dict-
Expand source code
def update_states( self, states: dict, delta_t: float, pre_voltage: float, post_voltage: float, params: dict, ) -> dict: prefix = self._name queues = self.struct.unflatten([ states[f'{prefix}_queue{i}'] for i in range(self.struct_size) ]) ts = states[f'{prefix}_ts'] delay_ms = params[f'{prefix}_delay'] tau1_syn_ms = params[f'{prefix}_tau1'] tau2_syn_ms = params[f'{prefix}_tau2'] def timestep(ts, queue, isyn1, isyn2, v, vnext, delay_ms, tau1_syn_ms, tau2_syn_ms): alpha = jnp.exp(-delta_t / tau1_syn_ms) # inefficient beta = jnp.exp(-delta_t / tau2_syn_ms) # inefficient tpost = synapse2.spike_detect(delta_t, ts, self.vthres, v, vnext, delay_ms) queue = jax.lax.cond(tpost != -1, # must be a better solution lambda: queue.enqueue(synapse2.time_to_timestep_keep_gradient(tpost, delta_t)), # type: ignore lambda: queue) queue, post_hit = queue.pop(synapse2.time_to_timestep_keep_gradient(ts, delta_t)) jump1 = synapse2.apply_recv_gradient(post_hit, tau1_syn_ms) jump2 = synapse2.apply_recv_gradient(post_hit, tau2_syn_ms) isyn1 = alpha * isyn1 + jump1 isyn2 = beta * isyn2 + jump2 return (queue, isyn1, isyn2) vprev = states[f'{prefix}_vprev'] isyn1 = states[f'{prefix}_isyn1'] isyn2 = states[f'{prefix}_isyn2'] queues, isyn1, isyn2 = jax.vmap(timestep)( ts, #ugly queues, isyn1, isyn2, v=vprev, vnext=pre_voltage, delay_ms=delay_ms, tau1_syn_ms=tau1_syn_ms, tau2_syn_ms=tau2_syn_ms) queue_parts, _struct = jax.tree.flatten(queues) state_out: dict[str, typing.Any] = { f'{prefix}_queue{i}': queue_parts[i] for i in range(self.struct_size) } state_out[f'{prefix}_vprev'] = pre_voltage # jax.debug.print('{} {} {}', pre_voltage, states[f'{prefix}_vprev'], state_out[f'{prefix}_vprev']) state_out[f'{prefix}_isyn1'] = isyn1 state_out[f'{prefix}_isyn2'] = isyn2 state_out[f'{prefix}_ts'] = states[f'{prefix}_ts'] + delta_t return state_outODE update step.
Args
states - States of the synapse.
delta_t - Time step in `ms`.
pre_voltage - Voltage of the presynaptic compartment, shape `()`.
post_voltage - Voltage of the postsynaptic compartment, shape `()`.
params - Parameters of the synapse. Conductances in `uS`.
Returns
Updated states.
class FIFORing (buffer: jax.Array, head: int | jax.Array, size: int | jax.Array)-
Expand source code
class FIFORing(typing.NamedTuple): buffer: jax.Array head: int | jax.Array size: int | jax.Array @classmethod def init(cls, delay, capacity=None, grad=False): return cls( jnp.full(delay if capacity is None else capacity, INT_MAX, floatx if grad else 'int32'), 0, 0 ) @classmethod def sized(cls, n): "wish I could use __class_getitem__" return type(f'{cls.__name__}[{n}]', cls.__bases__, {**cls.__dict__, "init": functools.partial(cls.init, capacity=n)}) def enqueue(self, n): return _enqueue(self, n) def pop(self, n): return _pop(self, n)FIFORing(buffer, head, size)
Ancestors
- builtins.tuple
Static methods
def init(delay, capacity=None, grad=False)
def sized(n)-
wish I could use __class_getitem__
Instance variables
var buffer : jax.Array-
Expand source code
class FIFORing(typing.NamedTuple): buffer: jax.Array head: int | jax.Array size: int | jax.Array @classmethod def init(cls, delay, capacity=None, grad=False): return cls( jnp.full(delay if capacity is None else capacity, INT_MAX, floatx if grad else 'int32'), 0, 0 ) @classmethod def sized(cls, n): "wish I could use __class_getitem__" return type(f'{cls.__name__}[{n}]', cls.__bases__, {**cls.__dict__, "init": functools.partial(cls.init, capacity=n)}) def enqueue(self, n): return _enqueue(self, n) def pop(self, n): return _pop(self, n)Alias for field number 0
var head : int | jax.Array-
Expand source code
class FIFORing(typing.NamedTuple): buffer: jax.Array head: int | jax.Array size: int | jax.Array @classmethod def init(cls, delay, capacity=None, grad=False): return cls( jnp.full(delay if capacity is None else capacity, INT_MAX, floatx if grad else 'int32'), 0, 0 ) @classmethod def sized(cls, n): "wish I could use __class_getitem__" return type(f'{cls.__name__}[{n}]', cls.__bases__, {**cls.__dict__, "init": functools.partial(cls.init, capacity=n)}) def enqueue(self, n): return _enqueue(self, n) def pop(self, n): return _pop(self, n)Alias for field number 1
var size : int | jax.Array-
Expand source code
class FIFORing(typing.NamedTuple): buffer: jax.Array head: int | jax.Array size: int | jax.Array @classmethod def init(cls, delay, capacity=None, grad=False): return cls( jnp.full(delay if capacity is None else capacity, INT_MAX, floatx if grad else 'int32'), 0, 0 ) @classmethod def sized(cls, n): "wish I could use __class_getitem__" return type(f'{cls.__name__}[{n}]', cls.__bases__, {**cls.__dict__, "init": functools.partial(cls.init, capacity=n)}) def enqueue(self, n): return _enqueue(self, n) def pop(self, n): return _pop(self, n)Alias for field number 2
Methods
def enqueue(self, n)-
Expand source code
def enqueue(self, n): return _enqueue(self, n)
def pop(self, n)-
Expand source code
def pop(self, n):
    """Forward to the module-level ``_pop`` helper and return its result.

    ``DelaySynapse.update_states`` unpacks the result of ``pop`` on its queue
    as a ``(queue, hit)`` pair.
    """
    popped = _pop(self, n)
    return popped
class SingleSpike (last_spike: jax.Array)-
Expand source code
class SingleSpike(typing.NamedTuple):
    """Queue state consisting of a single ``last_spike`` value.

    The queue logic lives in the module-level ``_enqueue`` and ``_pop``
    helpers; the methods here only forward to them.
    """

    # Scalar array; ``init`` sets it to ``INT_MAX``.
    last_spike: jax.Array

    @classmethod
    def init(cls, delay, grad=False):
        """Create the initial state.

        Args:
            delay: Accepted for signature compatibility with other queue
                types and discarded.
            grad: If true the initial value is ``float(INT_MAX)``, otherwise
                the integer ``INT_MAX``.

        Returns:
            A new instance wrapping the initial value in ``jnp.array``.
        """
        del delay
        initial = float(INT_MAX) if grad else INT_MAX
        return cls(jnp.array(initial))

    def enqueue(self, n):
        """Forward to ``_enqueue(self, n)`` and return its result."""
        return _enqueue(self, n)

    def pop(self, n):
        """Forward to ``_pop(self, n)`` and return its result."""
        return _pop(self, n)
Ancestors
- builtins.tuple
Static methods
def init(delay, grad=False)
Instance variables
var last_spike : jax.Array-
Expand source code
class SingleSpike(typing.NamedTuple): last_spike: jax.Array @classmethod def init(cls, delay, grad=False): del delay return cls(jnp.array(INT_MAX if not grad else float(INT_MAX))) def enqueue(self, n): return _enqueue(self, n) def pop(self, n): return _pop(self, n)Alias for field number 0
Methods
def enqueue(self, n)-
Expand source code
def enqueue(self, n): return _enqueue(self, n)
def pop(self, n)-
Expand source code
def pop(self, n):
    """Forward to the module-level ``_pop`` helper and return its result.

    ``DelaySynapse.update_states`` unpacks the result of ``pop`` on its queue
    as a ``(queue, hit)`` pair.
    """
    popped = _pop(self, n)
    return popped