Module adseq.benchmarks.profile_jaxley

Functions

def sim()
Expand source code
def sim():
    """Train the synaptic delays of a small delay-coupled network and plot the result.

    Builds ``num_cells`` identical cells (a 7-branch tree of 4-compartment
    branches with Na, K and Leak channels) and fully connects them with
    ``DelaySynapse``. Self-connections get zero delay and zero weight. The
    first five cells are driven with a step current. The function then runs
    ``num_steps`` Adam steps on the sigmoid-transformed ``DelaySynapse_delay``
    parameters, minimising the mean recorded trace over the first half of
    the simulation.

    Traces before and after training are plotted, the trained ones shifted
    down by 100, and saved to ``./img/delay_training.png``. The directory is
    created if needed.

    Initial delays and weights come from ``np.random`` without a fixed seed,
    so results differ from run to run.
    """
    import os  # only needed to create the output directory for the figure

    num_cells = 11
    num_stimulated = 5
    num_steps = 20
    # One delay/weight per directed edge between distinct cells.
    num_edges = num_cells * (num_cells - 1)
    delays = jnp.array(2 + 3 * np.random.random(num_edges))  # ms, in [2, 5)
    weights = jnp.array(5 * np.random.random(num_edges))  # in [0, 5)

    comp = jx.Compartment()
    branch = jx.Branch(comp, ncomp=4)
    cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])

    i_delay = 3.0  # ms
    i_amp = 0.05  # nA
    i_dur = 2.0  # ms
    dt = 0.025  # ms
    t_max = 50.0  # ms

    net = jx.Network([cell for _ in range(num_cells)])
    pre = net.cell(range(num_cells))
    post = net.cell(range(num_cells))
    fully_connect(pre, post, DelaySynapse(), True)

    # NOTE(review): the edge indexing below assumes fully_connect numbers the
    # edge (i -> j) as i * num_cells + j, so self-connections sit at
    # i * num_cells + i — confirm against fully_connect.
    idx = np.arange(num_cells)
    self_edges = idx * num_cells + idx
    net.select(edges=self_edges).set('DelaySynapse_delay', 0)
    net.select(edges=self_edges).set('DelaySynapse_weight', 0)  # silence autapses
    nonself = np.array([
        i * num_cells + j
        for i in range(num_cells)
        for j in range(num_cells)
        if i != j
    ])
    net.select(edges=nonself).set('DelaySynapse_delay', delays)
    net.select(edges=nonself).set('DelaySynapse_weight', weights)
    net.select(edges=nonself).make_trainable('DelaySynapse_delay')
    # To train the weights as well, enable this and the matching transform entry:
    # net.select(edges=nonself).make_trainable('DelaySynapse_weight')

    net.insert(Na())
    net.insert(K())
    net.insert(Leak())

    current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
    net.delete_stimuli()
    for stim_ind in range(num_stimulated):
        net.cell(stim_ind).branch(0).loc(0.0).stimulate(current)
    net.delete_recordings()
    net.cell(range(num_cells)).branch(0).loc(0.0).record()

    parameters = net.get_parameters()
    # Optimise in an unconstrained space; the sigmoid transform maps back to
    # the bounded delay range (0.1, 20.0).
    transform = jx.ParamTransform([
        {'DelaySynapse_delay': jt.SigmoidTransform(.1, 20.0)},
        # {'DelaySynapse_weight': jt.SigmoidTransform(0.0, 5.0)}
    ])

    def loss_fn(opt_params):
        """Mean of the recorded traces over the first half of the time axis."""
        params = transform.forward(opt_params)
        s = jx.integrate(net, delta_t=dt, params=params)
        n = s.shape[1]
        return s[:, :n // 2].mean()  # alternative term: - s[:, n // 2:].mean() / 100

    opt_params = transform.inverse(parameters)
    optimizer = optax.adam(learning_rate=1)
    opt_state = optimizer.init(opt_params)
    value_and_grad = jax.jit(jax.value_and_grad(loss_fn, argnums=0))
    old_params = transform.forward(opt_params)

    @jax.jit
    def step(opt_params, opt_state):
        """Run one Adam update; return (loss, new params, new optimizer state)."""
        loss_value, gradient = value_and_grad(opt_params)
        updates, opt_state = optimizer.update(gradient, opt_state)
        opt_params = optax.apply_updates(opt_params, updates)
        return loss_value, opt_params, opt_state

    for i in range(num_steps):
        loss_value, opt_params, opt_state = step(opt_params, opt_state)
        print(i, loss_value)

    new_params = transform.forward(opt_params)
    s_old = jx.integrate(net, delta_t=dt, params=old_params)
    s_new = jx.integrate(net, delta_t=dt, params=new_params)
    print(new_params)
    plt.plot(s_old.T, color='black')
    plt.plot(s_new.T - 100, color='black')  # offset so both trace sets are visible
    # savefig does not create missing directories.
    os.makedirs('./img', exist_ok=True)
    plt.savefig('./img/delay_training.png')
    plt.show()

Classes

class DelaySynapse (name: str | None = None)
Expand source code
class DelaySynapse(Synapse):
    """Synapse that delivers detected presynaptic spikes after a per-synapse delay.

    On every step, ``update_states`` does three things:

    - calls ``synapse.spike_detect`` on the previous and current presynaptic
      voltage, with threshold ``vthres`` and delay ``<name>_delay``;
    - enqueues the returned time in a spike queue when it is not ``-1``;
    - pops the queue at the current time.

    The synaptic trace ``<name>_isyn`` decays by ``exp(-delta_t / <name>_tau)``
    each step and is incremented by ``synapse.apply_recv_gradient`` applied to
    the popped value. ``compute_current`` returns
    ``-0.01 * isyn * <name>_weight``.

    Parameters: ``<name>_tau`` (ms), ``<name>_delay`` (ms), ``<name>_weight``.
    States: the flattened queue leaves ``<name>_queue{i}``, plus ``<name>_vprev``,
    ``<name>_isyn`` and ``<name>_ts`` (elapsed time, advanced by ``delta_t``).
    """

    def __init__(self, name: typing.Optional[str] = None):
        super().__init__(name)
        prefix = self._name
        # Alternative queue implementation:
        # queue = implementations.FIFORing.sized(1).init(None, grad=True)
        queue = implementations.SingleSpike.init(None, grad=True)
        # The queue is a pytree. Store each *flattened leaf* under its own
        # state key and keep the tree structure, so update_states can rebuild
        # the queue with self.struct.unflatten and flatten it back the same way.
        queue_leaves, self.struct = jax.tree.flatten(queue)
        self.struct_size = len(queue_leaves)
        self.synapse_params = {
            f'{prefix}_tau': 2.,  # ms, decay time constant of isyn
            f'{prefix}_delay': 10.,  # ms
            f'{prefix}_weight': 1.,  # scales the current in compute_current (not ms)
        }
        self.synapse_states: dict[str, typing.Any] = {
            f'{prefix}_queue{i}': leaf
            for i, leaf in enumerate(queue_leaves)
        }
        self.synapse_states[f'{prefix}_vprev'] = 0.
        self.synapse_states[f'{prefix}_isyn'] = 0.
        self.synapse_states[f'{prefix}_ts'] = 0.

    def update_states(
        self,
        states: dict,
        delta_t: float,
        pre_voltage: float,
        post_voltage: float,
        params: dict,
    ) -> dict:
        """Advance every synapse of this type by one step of length ``delta_t``.

        Args:
            states: Current synapse states; per-synapse entries are vmapped over.
            delta_t: Time step in ms.
            pre_voltage: Presynaptic voltage at this step.
            post_voltage: Postsynaptic voltage (unused).
            params: Synapse parameters (``_delay``, ``_tau`` are read here).

        Returns:
            New states: updated queue leaves, ``_vprev`` set to ``pre_voltage``,
            updated ``_isyn`` and ``_ts`` advanced by ``delta_t``.
        """
        prefix = self._name
        queues = self.struct.unflatten([
            states[f'{prefix}_queue{i}']
            for i in range(self.struct_size)
            ])
        ts = states[f'{prefix}_ts']
        vthres = 1.0
        delay_ms = params[f'{prefix}_delay']
        tau_syn_ms = params[f'{prefix}_tau']

        def timestep(ts, queue, isyn, v, vnext, delay_ms, tau_syn_ms):
            # Per-synapse update; jax.vmap below maps it over all synapses.
            alpha = jnp.exp(-delta_t / tau_syn_ms)  # recomputed every step (inefficient)
            tpost = synapse.spike_detect(delta_t, ts, vthres, v, vnext, delay_ms)
            # spike_detect signals "no spike" with -1; only enqueue real spikes.
            queue = jax.lax.cond(tpost != -1,
                 lambda: queue.enqueue(synapse.time_to_timestep_keep_gradient(tpost, delta_t)), # type: ignore
                 lambda: queue)
            queue, post_hit = queue.pop(synapse.time_to_timestep_keep_gradient(ts, delta_t))
            isyn = alpha * isyn + \
                   synapse.apply_recv_gradient(post_hit, tau_syn_ms)
            return (queue, isyn)

        vprev = states[f'{prefix}_vprev']
        isyn = states[f'{prefix}_isyn']
        queues, isyn = jax.vmap(timestep)(
                ts,
                queues, isyn,
                vprev, pre_voltage, delay_ms, tau_syn_ms)
        # Flatten back to the same per-leaf layout __init__ created.
        queue_parts, _struct = jax.tree.flatten(queues)
        state_out: dict[str, typing.Any] = {
            f'{prefix}_queue{i}': queue_parts[i]
            for i in range(self.struct_size)
        }
        state_out[f'{prefix}_vprev'] = pre_voltage
        state_out[f'{prefix}_isyn'] = isyn
        state_out[f'{prefix}_ts'] = ts + delta_t
        return state_out

    def compute_current(
        self, states: dict, pre_voltage: float, post_voltage: float, params: dict
    ) -> float:
        """Return the synaptic current ``-0.01 * isyn * weight`` (nA).

        ``pre_voltage`` and ``post_voltage`` are unused: the current depends
        only on the ``_isyn`` trace and the ``_weight`` parameter.
        """
        prefix = self._name
        return -0.01 * states[f'{prefix}_isyn'] * params[f'{prefix}_weight']

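Synapse that delivers detected presynaptic spikes through a spike queue after a per-synapse delay (parameter `<name>_delay`, in ms). It derives from jaxley's `Synapse` base class, whose documentation follows.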

As in NEURON, a Synapse is considered a point process, which means that its conductances are to be specified in uS and its currents are to be specified in nA.

Ancestors

  • jaxley.synapses.synapse.Synapse

Methods

def compute_current(self, states: dict, pre_voltage: float, post_voltage: float, params: dict) -> float
Expand source code
def compute_current(
    self, states: dict, pre_voltage: float, post_voltage: float, params: dict
) -> float:
    """Return the synaptic current as ``-0.01 * isyn * weight``.

    The voltages are not used; only the ``_isyn`` state and the ``_weight``
    parameter of this synapse (keyed by ``self._name``) matter.
    """
    prefix = self._name
    trace = states[f'{prefix}_isyn']
    gain = params[f'{prefix}_weight']
    scaled = -0.01 * trace
    return scaled * gain

Return current through one synapse in nA.

Internally, we use jax.vmap to vectorize this function across many synapses.

Args

states
States of the synapse.
pre_voltage
Voltage of the presynaptic compartment, shape ().
post_voltage
Voltage of the postsynaptic compartment, shape ().
params
Parameters of the synapse. Conductances in uS.

Returns

Current through the synapse in nA, shape ().

def update_states(self, states: dict, delta_t: float, pre_voltage: float, post_voltage: float, params: dict) -> dict
Expand source code
def update_states(
    self,
    states: dict,
    delta_t: float,
    pre_voltage: float,
    post_voltage: float,
    params: dict,
) -> dict:
    """Advance all synapses of this type by one step of length ``delta_t`` (ms).

    The spike queue is rebuilt from its flattened per-leaf state entries. Each
    synapse is then updated independently under ``jax.vmap``:

    - detect a presynaptic spike and enqueue it;
    - pop the queue at the current time;
    - decay and increment the ``_isyn`` trace.

    The result is flattened back into the same per-leaf keys. ``post_voltage``
    is not used.
    """
    prefix = self._name
    leaves_in = [states[f'{prefix}_queue{i}'] for i in range(self.struct_size)]
    queues = self.struct.unflatten(leaves_in)
    t_now = states[f'{prefix}_ts']
    vthres = 1.0  # threshold handed to synapse.spike_detect
    delays = params[f'{prefix}_delay']
    taus = params[f'{prefix}_tau']

    def advance_one(t, queue, trace, v_before, v_after, delay_ms, tau_syn_ms):
        # Per-synapse update; vmapped across all synapses below.
        decay = jnp.exp(-delta_t / tau_syn_ms)
        t_arrival = synapse.spike_detect(delta_t, t, vthres, v_before, v_after, delay_ms)

        def with_spike():
            return queue.enqueue(synapse.time_to_timestep_keep_gradient(t_arrival, delta_t))  # type: ignore

        def without_spike():
            return queue

        # spike_detect reports "no spike" as -1.
        queue = jax.lax.cond(t_arrival != -1, with_spike, without_spike)
        queue, hit = queue.pop(synapse.time_to_timestep_keep_gradient(t, delta_t))
        trace = decay * trace + synapse.apply_recv_gradient(hit, tau_syn_ms)
        return (queue, trace)

    v_prev = states[f'{prefix}_vprev']
    trace_in = states[f'{prefix}_isyn']
    new_queues, new_trace = jax.vmap(advance_one)(
        t_now, queues, trace_in, v_prev, pre_voltage, delays, taus
    )
    leaves_out, _ = jax.tree.flatten(new_queues)

    state_out: dict[str, typing.Any] = {}
    for i in range(self.struct_size):
        state_out[f'{prefix}_queue{i}'] = leaves_out[i]
    state_out[f'{prefix}_vprev'] = pre_voltage
    state_out[f'{prefix}_isyn'] = new_trace
    state_out[f'{prefix}_ts'] = t_now + delta_t
    return state_out

ODE update step.

Args

states
States of the synapse.
delta_t
Time step in ms.
pre_voltage
Voltage of the presynaptic compartment, shape ().
post_voltage
Voltage of the postsynaptic compartment, shape ().
params
Parameters of the synapse. Conductances in uS.

Returns

Updated states.