Module adseq.benchmarks.profile_recurrent_snn
Functions
def benchmark_grad(jac=<function jacfwd>, n=100)-
Expand source code
def benchmark_grad(jac=jax.jacfwd, n=100): key = jax.random.PRNGKey(0) weight = jnp.sqrt(23)/jnp.sqrt(n) * 0.05 * zero_diagonal(jax.random.normal(key, (n,n)))**2 out = {} for Q in qs: try: t = jax.jit(jac(lambda w: sim(n, w, Q=Q)[1][1].sum())) runner = benchmarks.mkrunner(t, weight) deltas = [] for _ in range(5): a = time.time() o = runner() b = time.time() if not isinstance(o, Exception): deltas.append(b - a) tmean = jnp.mean(jnp.array(deltas)) tmean = float(tmean) / 10000 * 1e6 out[Q.__name__] = tmean print(Q.__name__.ljust(20), tmean, 'us') except Exception as ex: print(ex) return out def benchmark_regular(n=100)-
Expand source code
def benchmark_regular(n=100): key = jax.random.PRNGKey(0) weight = jnp.sqrt(23)/jnp.sqrt(n) * 0.05 * zero_diagonal(jax.random.normal(key, (n,n)))**2 out = {} for Q in qs: t = jax.jit(lambda w: sim(n, w, Q=Q)[1][1].sum()) runner = benchmarks.mkrunner(t, weight) deltas = [] for _ in range(5): a = time.time() runner() b = time.time() deltas.append(b - a) tmean = jnp.mean(jnp.array(deltas)) tmean = float(tmean) / 10000 * 1e6 out[Q.__name__] = tmean print(Q.__name__.ljust(20), tmean, 'us') return out def lif_step(U: jax.Array, I: jax.Array, tau_mem: float, dt: float, vth: float = 1)-
Expand source code
def lif_step(U: jax.Array, I: jax.Array, tau_mem: float, dt: float, vth: float =1): S = superspike(U - vth) beta = jnp.exp(-dt/tau_mem) U_next = (1 - S) * (beta * U + I*dt) return U_next, S def main()-
Expand source code
def main(): results = { 'regular': {}, 'forward': {}, 'reverse': {} } dev_name, results['host'] = benchmarks.get_device_id() # print('=== regular ===') # results['regular'].update(benchmark_regular()) print('=== forward ===') results['forward'].update(benchmark_grad(jax.jacfwd)) print('=== reverse ===') results['reverse'].update(benchmark_grad(jax.jacrev)) # with open(f'benchmarks/{dev_name}_grad.json', 'w') as f: # json.dump(results, f) def memory()-
Expand source code
def memory(): n = 100 key = jax.random.PRNGKey(0) weight = jnp.sqrt(23)/jnp.sqrt(n) * 0.05 * zero_diagonal(jax.random.normal(key, (n,n)))**2 for jac in [jax.jacrev, jax.jacfwd]: print(jac) for Q in qs: try: f = jax.jit(jac(lambda w: sim(n, w, Q=Q)[1][1].sum())) f = jax.jit(f).lower(weight).compile() b = sum(v for k, v in f.cost_analysis().items() if 'bytes' in k) except: pass print(Q.__name__.ljust(20), str(b).rjust(20)) def plot_sim(n)-
Expand source code
def plot_sim(n):
    """Simulate an ``n``-neuron network and show three diagnostic figures.

    Example: ``plot_sim(32)``.

    1. Total spike count per time step (y-axis fixed to 0..30).
    2. Recorded membrane trace of every neuron.
    3. Spike raster: one row per neuron, x = step index of each spike.
    """
    ts, (trace, spikes) = sim(n)

    population_count = spikes.sum(1)
    plt.plot(ts, population_count)
    plt.ylim(0, 30)
    plt.show()

    plt.plot(ts, trace)
    plt.show()

    for row, spike_train in enumerate(tqdm.tqdm(spikes.T)):
        (steps,) = jnp.where(spike_train)
        plt.plot(steps, row * jnp.ones_like(steps), 'o')
    plt.show()
def sim(n=23,
weight: None | jax.Array = None,
Q=adseq.implementations.single_spike.SingleSpike)-
Expand source code
def sim( n = 23, weight: None | jax.Array = None, Q = implementations.SingleSpike ): dt = 0.025 tau_syn = 2. tau_mem = 10. vthres = 1.0 key = jax.random.PRNGKey(0) if weight is None: weight = jnp.sqrt(23)/jnp.sqrt(n) * 0.05 * zero_diagonal(jax.random.normal(key, (n,n)))**2 delays = 4 + .5*jax.random.normal(key, (n,n)).flatten() syn = synapse.mk_synapses(Q, # type: ignore delay_ms=delays, dt_ms=dt, vthres=vthres, tau_syn_ms=tau_syn, n=n*n, max_delay_ms=7 ) syn_step = jax.jit(type(syn).timestep_spike_detect_pre) v = jnp.zeros(n) state = v, syn def step(state, t): v, syn = state isyn = (weight @ jnp.roll(syn.isyn.reshape((n,n)).sum(0), 1)).at[:].add(1. * (t < 2)) # ring: # isyn = (weight * jnp.roll(syn.isyn, 1)).at[0].add(1. * (t < 2)) vnext, s = lif_step(v, isyn, tau_mem, dt, vthres) syn = syn_step(syn, ts=t, v=jnp.repeat(v, n), vnext=jnp.repeat(vnext, n)) state = vnext, syn return state, (v, s) ts = jnp.arange(10000) * dt _, trace = jax.lax.scan(step, state, xs=ts) return ts, trace def superspike(x)-
Expand source code
@jax.custom_jvp
def superspike(x):
    """Heaviside spike nonlinearity with a SuperSpike surrogate gradient.

    Forward pass: 1.0 where ``x >= 0`` and 0.0 where ``x < 0``. The JVP rule
    registered via ``superspike.defjvp`` replaces the zero-almost-everywhere
    derivative with the surrogate ``1 / (|x| + 1)**2``.

    Reference: Zenke & Ganguli (2018), "SuperSpike: Supervised learning in
    multilayer spiking neural networks", Neural Computation,
    https://doi.org/10.1162/neco_a_01086
    """
    return jnp.where(x < 0, 0.0, 1.0)
def superspike_jvp(primals, tangents)-
Expand source code
@superspike.defjvp def superspike_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents primal_out = jnp.where(x < 0, 0.0, 1.0) tangent_out = x_dot / (jnp.abs(x)+1)**2 return primal_out, tangent_out def zero_diagonal(arr)-
Expand source code
def zero_diagonal(arr):
    """Return ``arr`` with its main diagonal zeroed.

    Expects a square ``(n, n)`` array; other shapes only work where they
    broadcast against an ``(n, n)`` mask with ``n = arr.shape[0]``.
    The result is an element-wise product with an off-diagonal mask of ones,
    so a non-finite diagonal entry (inf/nan) becomes nan rather than 0.
    """
    off_diagonal_mask = 1 - jnp.eye(arr.shape[0])
    return arr * off_diagonal_mask