Module adseq.benchmarks.loss2
Functions
def go()
Expand source code
def go(lam=0.001, thres=0.1, size_stop=100, n_events=100, n_drop=20,
       n_repeats=3, window=200):
    """Benchmark spike loss of ``implementations.FIFORing`` across ring sizes.

    For every ring size in ``range(1, size_stop)``, a window of ``window``
    consecutive integer delays is swept through :func:`spike_drop` (vmapped
    over the delay). The drop rate is averaged over ``n_repeats`` PRNG
    subkeys. While the largest averaged drop rate in the window stays below
    ``thres``, the window slides forward by ``window`` delays. Once it
    reaches ``thres``, the delay at which the rate first rises above
    ``thres`` is linearly interpolated, scaled by ``lam``, printed and
    recorded as ``E``.

    The delay window is carried over from one size to the next rather than
    reset. Presumably this is because a larger ring tolerates at least as
    long a delay — confirm.

    Args:
        lam: per-step event rate passed to ``spike_drop``.
        thres: drop-rate threshold whose crossing is located.
        size_stop: exclusive upper bound of the ring sizes tried.
        n_events: ``Nevents`` passed to ``spike_drop``.
        n_drop: ``drop`` (trailing steps ignored) passed to ``spike_drop``.
        n_repeats: number of PRNG subkeys averaged per window.
        window: number of delays evaluated per sweep, and the slide step.

    Returns:
        List of ``(size, E)`` pairs in the order they were printed.

    Raises:
        RuntimeError: if every delay in the current window already drops more
            than ``thres``, i.e. the crossing lies before the window.
    """
    delays = jnp.arange(1, window + 1)
    results = []
    for size in range(1, size_stop):
        queue_t = implementations.FIFORing.sized(size)
        # Re-seeded for every size so runs for different sizes start from
        # the same random draws.
        key = jax.random.PRNGKey(0)
        # One compiled sweep per size: vectorised over delay, shared key.
        sweep = jax.jit(jax.vmap(
            lambda delay, k: spike_drop(queue_t, lam=lam, delay=delay,
                                        Nevents=n_events, drop=n_drop, key=k),
            in_axes=(0, None)))
        while True:
            droprate, key = _mean_droprate(sweep, delays, key, n_repeats)
            if droprate.min() > thres:
                # The crossing lies before this window. An interpolated
                # estimate would just be clamped to the first delay, so stop
                # rather than report it.
                raise RuntimeError(
                    f"size {size}: drop rate exceeds thres={thres} at every "
                    f"delay in [{int(delays[0])}, {int(delays[-1])}]")
            if droprate.max() >= thres:
                break
            delays = delays + window
        print(droprate)
        E = _threshold_crossing(droprate, delays, thres, lam)
        print(size, E)
        results.append((size, E))
    return results


def _mean_droprate(sweep, delays, key, n_repeats):
    """Average ``sweep(delays, subkey)`` over ``n_repeats`` subkeys.

    Returns the element-wise mean drop rate and the advanced ``key``.
    """
    rates = []
    for _ in range(n_repeats):
        key, subkey = jax.random.split(key)
        rates.append(sweep(delays, subkey))
    return jnp.mean(jnp.vstack(rates), axis=0), key


def _threshold_crossing(droprate, delays, thres, lam):
    """Interpolate ``lam * delay`` where ``droprate`` first rises above ``thres``."""
    # jnp.interp assumes ``xp`` is increasing and does not check it, so a
    # noisy drop-rate curve would silently give an arbitrary answer. The
    # running maximum is non-decreasing and rises above ``thres`` at the same
    # index as the raw curve first does.
    monotone = jax.lax.cummax(droprate, axis=0)
    return jnp.interp(thres, monotone, lam * delays.astype('float32'))
Expand source code
def mkev(lam: float, Nevents: int, key = jax.random.PRNGKey(0)):
    """Draw a binary Poisson spike train.

    The stream is ``int(1 / lam * Nevents)`` steps long, so the expected
    total Poisson count is about ``Nevents``. Each step draws a Poisson
    count with rate ``lam`` and is marked 1 if at least one event occurred,
    else 0. Steps with several events therefore collapse into one spike.

    Args:
        lam: per-step event rate; must be positive, since ``1 / lam`` sets
            the stream length.
        Nevents: target number of events, used only to size the stream.
        key: JAX PRNG key.

    Returns:
        ``int32`` array of 0/1 values with shape ``(int(1 / lam * Nevents),)``.
    """
    T = int(1 / lam * Nevents)
    counts = jax.random.poisson(key, lam=lam, shape=(T,))
    return (counts > 0).astype(jnp.int32)
Expand source code
def spike_drop(QueueT, lam=10, delay=1, Nevents=100, drop=200, key=jax.random.PRNGKey(0)):
    """Fraction of spikes lost when a Poisson spike train is re-timed through a queue.

    A binary spike stream is drawn with :func:`mkev`. At every step ``t``,
    the queue is first popped for time ``t`` and its output is recorded.
    Then, if a spike occurs at ``t``, ``t + delay`` is enqueued. A lossless
    queue would therefore reproduce the stream shifted right by ``delay``
    steps.

    Args:
        QueueT: queue implementation exposing ``init(delay)``. The resulting
            object must provide ``pop(t) -> (queue, out)`` and
            ``enqueue(t) -> queue``, and be usable as a ``lax.scan`` carry.
        lam: per-step event rate for :func:`mkev`.
        delay: re-timing delay in steps (may be a traced value, e.g. under vmap).
        Nevents: target event count for :func:`mkev`.
        drop: number of trailing steps excluded from both the expected and
            the delivered spike counts.
        key: JAX PRNG key for the stream.

    Returns:
        ``(expected - delivered) / expected`` as a float scalar. This is
        ``nan`` when no spike is expected inside the compared range.
    """
    stream = mkev(lam, Nevents, key=key)
    n_steps = len(stream)
    steps = jnp.arange(n_steps)

    # No inner jax.jit needed: lax.scan already traces and compiles the step.
    def f_loop(queue, arg):
        t, ev = arg
        queue, out = queue.pop(t)
        queue = jax.lax.cond(ev, lambda: queue.enqueue(t + delay), lambda: queue)
        return queue, out

    _, trace = jax.lax.scan(f_loop, QueueT.init(delay), xs=(steps, stream))

    # Compare only steps in [0, n_steps - drop). Building the mask from
    # indices (rather than ``.at[-drop:]``) keeps drop=0 from masking the
    # whole array.
    in_range = steps < n_steps - drop
    # jnp.roll wraps the last ``delay`` spikes around to the front. Those
    # spikes are enqueued for a time >= n_steps, so they can never appear in
    # the trace and must not be counted as expected.
    expected = jnp.where(in_range & (steps >= delay), jnp.roll(stream, delay), 0)
    # Any nonzero queue output at a step counts as one delivered spike.
    got = jnp.where(in_range, trace.astype(bool), False)
    n_expected = expected.sum()
    n_dropped = n_expected - got.sum()
    return n_dropped / n_expected