Module adseq.benchmarks.profile_jaxley
Functions
def sim()
Expand source code
def sim():
    """Train the synaptic delays of a small, fully connected jaxley network.

    Builds ``num_cells`` copies of one cell (7 branches of 4 compartments each)
    and inserts Na, K and Leak channels. Every pair of distinct cells is
    connected by a ``DelaySynapse`` with a random delay and weight; self-edges
    are switched off.

    The first five cells receive a step current, and every cell is recorded at
    ``branch(0).loc(0.0)``. The synaptic delays (only) are then optimised with
    Adam for 20 steps to minimise the mean recorded value over the first half
    of the simulation.

    Traces before and after training are plotted (offset vertically) and saved
    to ``./img/delay_training.png``.
    """
    import os

    num_cells = 11
    num_edges = num_cells * (num_cells - 1)
    # Random (unseeded) initial values for the non-self edges.
    delays = jnp.array(2 + 3 * np.random.random(num_edges))   # ms, in [2, 5)
    weights = jnp.array(5 * np.random.random(num_edges))      # in [0, 5)

    comp = jx.Compartment()
    branch = jx.Branch(comp, ncomp=4)
    cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])

    i_delay = 3.0  # ms
    i_amp = 0.05  # nA
    i_dur = 2.0  # ms
    dt = 0.025  # ms
    t_max = 50.0  # ms

    net = jx.Network([cell for _ in range(num_cells)])
    pre = net.cell(range(num_cells))
    post = net.cell(range(num_cells))
    fully_connect(pre, post, DelaySynapse(), True)

    # NOTE(review): the edge indexing below assumes fully_connect creates
    # num_cells**2 edges ordered pre-major (edge = pre * num_cells + post),
    # including self-connections -- confirm against fully_connect.
    idx = np.arange(num_cells)
    self_edges = idx * num_cells + idx
    net.select(edges=self_edges).set('DelaySynapse_delay', 0)
    net.select(edges=self_edges).set('DelaySynapse_weight', 0)

    nonself = np.array([
        i * num_cells + j
        for i in range(num_cells)
        for j in range(num_cells)
        if i != j
    ])
    net.select(edges=nonself).set('DelaySynapse_delay', delays)
    net.select(edges=nonself).set('DelaySynapse_weight', weights)
    # Only the delays are trained. To train the weights as well, make
    # 'DelaySynapse_weight' trainable here and add a matching transform below.
    net.select(edges=nonself).make_trainable('DelaySynapse_delay')

    net.insert(Na())
    net.insert(K())
    net.insert(Leak())

    current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
    net.delete_stimuli()
    for stim_ind in range(5):
        net.cell(stim_ind).branch(0).loc(0.0).stimulate(current)
    net.delete_recordings()
    # Record every cell; previously hard-coded as range(11).
    net.cell(range(num_cells)).branch(0).loc(0.0).record()

    parameters = net.get_parameters()

    # Optimise in an unconstrained space; the sigmoid keeps delays in (0.1, 20) ms.
    transform = jx.ParamTransform([
        {'DelaySynapse_delay': jt.SigmoidTransform(.1, 20.0)},
    ])

    def loss_fn(opt_params):
        """Mean recorded value over the first half of axis 1 (assumed to be time)."""
        params = transform.forward(opt_params)
        s = jx.integrate(net, delta_t=dt, params=params)
        n = s.shape[1]
        return s[:, :n // 2].mean()

    opt_params = transform.inverse(parameters)
    optimizer = optax.adam(learning_rate=1)
    opt_state = optimizer.init(opt_params)
    loss_and_grad = jax.jit(jax.value_and_grad(loss_fn, argnums=0))
    params_before = transform.forward(opt_params)

    @jax.jit
    def step(opt_params, opt_state):
        """One Adam update; returns (loss before the update, new params, new state)."""
        loss_val, gradient = loss_and_grad(opt_params)
        updates, opt_state = optimizer.update(gradient, opt_state)
        opt_params = optax.apply_updates(opt_params, updates)
        return loss_val, opt_params, opt_state

    for i in range(20):
        loss_val, opt_params, opt_state = step(opt_params, opt_state)
        print(i, loss_val)

    params_after = transform.forward(opt_params)
    s_old = jx.integrate(net, delta_t=dt, params=params_before)
    s_new = jx.integrate(net, delta_t=dt, params=params_after)
    print(params_after)

    plt.plot(s_old.T, color='black')
    # Shift the trained traces down so the two sets do not overlap.
    plt.plot(s_new.T - 100, color='black')
    # savefig fails if the output directory is missing.
    os.makedirs('./img', exist_ok=True)
    plt.savefig('./img/delay_training.png')
    plt.show()
Classes
class DelaySynapse (name: str | None = None)
Expand source code
class DelaySynapse(Synapse):
    """Current-based synapse whose presynaptic spikes take effect after a delay.

    Subclass of the jaxley ``Synapse`` base class. Each synapse keeps a spike
    queue (``implementations.SingleSpike``) in its states, split into one state
    entry per pytree leaf.

    On every step, ``update_states`` does the following:

    * looks for a spike in the presynaptic voltage (threshold 1.0, via
      ``synapse.spike_detect``);
    * enqueues the resulting arrival;
    * pops any arrival due at the current time;
    * updates an exponentially decaying synaptic current trace ``isyn``.

    ``compute_current`` scales that trace by the synaptic weight.

    Parameters (keys are prefixed with the synapse name):
        ``_tau``: decay time constant of ``isyn`` in ms (default 2.0).
        ``_delay``: transmission delay in ms (default 10.0).
        ``_weight``: scale factor applied in ``compute_current`` (default 1.0).
    """

    def __init__(self, name: typing.Optional[str] = None):
        super().__init__(name)
        prefix = self._name
        # Alternative queue implementation:
        #   implementations.FIFORing.sized(1).init(None, grad=True)
        queue = implementations.SingleSpike.init(None, grad=True)
        # The queue pytree is split into its leaves, one state entry per leaf.
        # ``self.struct`` re-assembles them in ``update_states``.
        queue_leaves, self.struct = jax.tree.flatten(queue)
        self.struct_size = len(queue_leaves)
        self.synapse_params = {
            f'{prefix}_tau': 2.,  # ms, decay time constant of isyn
            f'{prefix}_delay': 10.,  # ms, transmission delay
            f'{prefix}_weight': 1.,  # scale factor of the current (not a time)
        }
        # Seed from the flattened leaves rather than by indexing the queue
        # object. The leaves line up one-to-one with what ``update_states``
        # unflattens and writes back, even if the queue is nested or has None
        # fields.
        self.synapse_states: dict[str, typing.Any] = {
            f'{prefix}_queue{i}': leaf for i, leaf in enumerate(queue_leaves)
        }
        self.synapse_states[f'{prefix}_vprev'] = 0.  # presynaptic voltage at the previous step
        self.synapse_states[f'{prefix}_isyn'] = 0.  # synaptic current trace
        self.synapse_states[f'{prefix}_ts'] = 0.  # elapsed time in ms

    def update_states(
        self,
        states: dict,
        delta_t: float,
        pre_voltage: float,
        post_voltage: float,
        params: dict,
    ) -> dict:
        """ODE update step.

        Advances every synapse (vectorised with ``jax.vmap``) by one time step:

        * detects a presynaptic spike between the previous and the current
          presynaptic voltage with ``synapse.spike_detect``;
        * enqueues its arrival time step if one was found;
        * pops the queue at the current time step;
        * updates ``isyn`` as ``exp(-delta_t / tau) * isyn`` plus the
          contribution of any popped arrival.

        Args:
            states: States of the synapse.
            delta_t: Time step in ms.
            pre_voltage: Voltage of the presynaptic compartment, shape ().
            post_voltage: Voltage of the postsynaptic compartment, shape ().
                Not used by this synapse.
            params: Parameters of the synapse (``_tau``, ``_delay``, ``_weight``).

        Returns:
            Updated states.
        """
        prefix = self._name
        vthres = 1.0  # spike-detection threshold on the presynaptic voltage

        queue_leaves = [states[f'{prefix}_queue{i}'] for i in range(self.struct_size)]
        batched_queues = self.struct.unflatten(queue_leaves)

        def _advance_one(t_now, queue, i_syn, v_before, v_after, delay, tau):
            """Advance a single synapse by one step; returns (queue, isyn)."""
            decay = jnp.exp(-delta_t / tau)
            t_arrival = synapse.spike_detect(delta_t, t_now, vthres, v_before, v_after, delay)

            def _push():
                return queue.enqueue(synapse.time_to_timestep_keep_gradient(t_arrival, delta_t))  # type: ignore

            # -1 signals "no spike detected this step"; only then leave the queue untouched.
            pending = jax.lax.cond(t_arrival != -1, _push, lambda: queue)
            pending, hit = pending.pop(synapse.time_to_timestep_keep_gradient(t_now, delta_t))
            i_syn = decay * i_syn + synapse.apply_recv_gradient(hit, tau)
            return pending, i_syn

        new_queues, new_isyn = jax.vmap(_advance_one)(
            states[f'{prefix}_ts'],
            batched_queues,
            states[f'{prefix}_isyn'],
            states[f'{prefix}_vprev'],
            pre_voltage,
            params[f'{prefix}_delay'],
            params[f'{prefix}_tau'],
        )

        leaves, _ = jax.tree.flatten(new_queues)
        result: dict[str, typing.Any] = {
            f'{prefix}_queue{i}': leaves[i] for i in range(self.struct_size)
        }
        result[f'{prefix}_vprev'] = pre_voltage
        result[f'{prefix}_isyn'] = new_isyn
        result[f'{prefix}_ts'] = states[f'{prefix}_ts'] + delta_t
        return result

    def compute_current(
        self, states: dict, pre_voltage: float, post_voltage: float, params: dict
    ) -> float:
        """Return current through one synapse in nA.

        Internally, we use ``jax.vmap`` to vectorize this function across many
        synapses. For this synapse the current is ``-0.01 * isyn * weight``.
        The voltages are not used.

        Args:
            states: States of the synapse.
            pre_voltage: Voltage of the presynaptic compartment, shape ().
            post_voltage: Voltage of the postsynaptic compartment, shape ().
            params: Parameters of the synapse.

        Returns:
            Current through the synapse in nA, shape ().
        """
        prefix = self._name
        trace = states[f'{prefix}_isyn']
        gain = params[f'{prefix}_weight']
        return -0.01 * trace * gain
As in NEURON, a `Synapse` is considered a point process, which means that its conductances are to be specified in `uS` and its currents are to be specified in `nA`.

Ancestors
- jaxley.synapses.synapse.Synapse
Methods
def compute_current(
    self, states: dict, pre_voltage: float, post_voltage: float, params: dict
) -> float:
    """Return current through one synapse in nA.

    Internally, we use ``jax.vmap`` to vectorize this function across many
    synapses. For this synapse the current is ``-0.01 * isyn * weight``.
    The voltages are not used.

    Args:
        states: States of the synapse.
        pre_voltage: Voltage of the presynaptic compartment, shape ().
        post_voltage: Voltage of the postsynaptic compartment, shape ().
        params: Parameters of the synapse. Conductances in uS.

    Returns:
        Current through the synapse in nA, shape ().
    """
    prefix = self._name
    trace = states[f'{prefix}_isyn']
    gain = params[f'{prefix}_weight']
    return -0.01 * trace * gain


def update_states(
    self,
    states: dict,
    delta_t: float,
    pre_voltage: float,
    post_voltage: float,
    params: dict,
) -> dict:
    """ODE update step.

    Advances every synapse (vectorised with ``jax.vmap``) by one time step:

    * detects a presynaptic spike between the previous and the current
      presynaptic voltage with ``synapse.spike_detect``;
    * enqueues its arrival time step if one was found;
    * pops the queue at the current time step;
    * updates ``isyn`` as ``exp(-delta_t / tau) * isyn`` plus the contribution
      of any popped arrival.

    Args:
        states: States of the synapse.
        delta_t: Time step in ms.
        pre_voltage: Voltage of the presynaptic compartment, shape ().
        post_voltage: Voltage of the postsynaptic compartment, shape ().
            Not used by this synapse.
        params: Parameters of the synapse. Conductances in uS.

    Returns:
        Updated states.
    """
    prefix = self._name
    vthres = 1.0  # spike-detection threshold on the presynaptic voltage

    queue_leaves = [states[f'{prefix}_queue{i}'] for i in range(self.struct_size)]
    batched_queues = self.struct.unflatten(queue_leaves)

    def _advance_one(t_now, queue, i_syn, v_before, v_after, delay, tau):
        """Advance a single synapse by one step; returns (queue, isyn)."""
        decay = jnp.exp(-delta_t / tau)
        t_arrival = synapse.spike_detect(delta_t, t_now, vthres, v_before, v_after, delay)

        def _push():
            return queue.enqueue(synapse.time_to_timestep_keep_gradient(t_arrival, delta_t))  # type: ignore

        # -1 signals "no spike detected this step"; only then leave the queue untouched.
        pending = jax.lax.cond(t_arrival != -1, _push, lambda: queue)
        pending, hit = pending.pop(synapse.time_to_timestep_keep_gradient(t_now, delta_t))
        i_syn = decay * i_syn + synapse.apply_recv_gradient(hit, tau)
        return pending, i_syn

    new_queues, new_isyn = jax.vmap(_advance_one)(
        states[f'{prefix}_ts'],
        batched_queues,
        states[f'{prefix}_isyn'],
        states[f'{prefix}_vprev'],
        pre_voltage,
        params[f'{prefix}_delay'],
        params[f'{prefix}_tau'],
    )

    leaves, _ = jax.tree.flatten(new_queues)
    result: dict[str, typing.Any] = {
        f'{prefix}_queue{i}': leaves[i] for i in range(self.struct_size)
    }
    result[f'{prefix}_vprev'] = pre_voltage
    result[f'{prefix}_isyn'] = new_isyn
    result[f'{prefix}_ts'] = states[f'{prefix}_ts'] + delta_t
    return result