Module adseq.benchmarks.profile_lossy_queues
Functions
def make_plot(imp, ax)-
Expand source code
def make_plot(imp, ax):
    """Plot the averaged drop rate of queue implementation ``imp`` on ``ax``.

    For each event rate ``lam`` a sweep of delays is evaluated with
    ``spike_drop`` (100 events, last 20 samples ignored), repeated for ten
    different random keys and averaged.  The curve is drawn as drop rate in
    percent against the product ``delay * lam``.

    Args:
        imp: Queue type forwarded to ``spike_drop`` as ``QueueT``; its
            ``__name__`` is used as the plot title.
        ax: Object providing ``plot``, ``set_title``, ``set_ylabel``,
            ``set_xlabel`` and ``set_ylim`` (presumably a matplotlib Axes).
    """
    n_repeats = 10
    rates = [0.01]  # only one rate is currently swept

    for lam in rates:
        # Delays from 1 up to (excluding) round(0.5 / lam), in steps of 0.1.
        delays = jnp.arange(1, int(round(0.5 / lam)), 0.1)

        def drop_for_delay(delay, key):
            return spike_drop(
                imp, lam=lam, delay=delay, Nevents=100, drop=20, key=key
            )

        # Vectorise over the delay axis only; the key is shared by all delays.
        sweep = jax.jit(jax.vmap(drop_for_delay, in_axes=(0, None)))

        key = jax.random.PRNGKey(0)
        runs = []
        for _ in range(n_repeats):
            key, subkey = jax.random.split(key)
            runs.append(sweep(delays, subkey))

        mean_rate = jnp.vstack(runs).mean(axis=0)
        ax.plot(delays * lam, mean_rate * 100, label=f'$\\lambda$={lam}')

    ax.set_title(imp.__name__)
    ax.set_ylabel('Drop rate (%)')
    ax.set_xlabel(r'$d \lambda$')
    ax.set_ylim(0, 50)
Expand source code
def mkev(lam: float, Nevents: int, key=None):
    """Generate a binary event stream from per-sample Poisson counts.

    The stream has ``int(1/lam * Nevents)`` samples.  Each sample draws a
    Poisson count with rate ``lam`` (so the expected total count is about
    ``Nevents``) and is reported as 1 if the count is positive, else 0.

    A debug line is printed with ``lam``, ``Nevents`` and ``dropped``, where
    ``dropped`` is the sum of the counts over all samples holding more than
    one event (those samples collapse to a single 1 in the output).

    Args:
        lam: Poisson rate per sample; must be a concrete Python number since
            it determines the stream length.
        Nevents: Target number of events, used to size the stream.
        key: JAX PRNG key.  ``None`` (the default) uses
            ``jax.random.PRNGKey(0)``.

    Returns:
        1-D ``int32`` array of zeros and ones with ``int(1/lam * Nevents)``
        entries.
    """
    # Build the default key lazily: a PRNGKey in the signature would run a
    # JAX computation as a side effect of importing the module.
    if key is None:
        key = jax.random.PRNGKey(0)

    T = int(1 / lam * Nevents)
    counts = jax.random.poisson(key, lam=lam, shape=(T,))
    dropped = ((counts > 1) * counts).sum()
    jax.debug.print('generator lam={} dropped={} N={}', lam, dropped, Nevents)
    return (counts > 0).astype(jnp.int32)
Expand source code
def spike_drop(QueueT, lam=10, delay=1, Nevents=100, drop=20, key=None):
    """Measure the fraction of events a delay queue fails to deliver.

    An event stream is generated with ``mkev(lam, Nevents, key=key)`` and fed
    through a queue, one sample per time step: at step ``t`` the queue is
    popped with ``t`` and, if the stream holds an event at ``t``, ``t + delay``
    is enqueued.  The pop outputs are compared against the input stream
    shifted by ``delay``; the last ``drop`` samples of both are ignored.

    Args:
        QueueT: Queue type.  It must provide ``init(delay)`` returning the
            initial queue state, and the state must provide ``pop(t)``
            returning ``(queue, out)`` and ``enqueue(time)`` returning the
            updated queue.  The state has to be usable as a
            ``jax.lax.scan`` carry.
        lam: Event rate passed to ``mkev``.
        delay: Delay added to the event time on enqueue and used to shift
            the reference stream.
        Nevents: Target number of events passed to ``mkev``.
        drop: Number of trailing samples excluded from the comparison.
        key: JAX PRNG key.  ``None`` (the default) uses
            ``jax.random.PRNGKey(0)``.

    Returns:
        ``(expected - got) / expected`` where ``expected`` and ``got`` are
        the event counts of the shifted input and of the queue output.
    """
    # Build the default key lazily: a PRNGKey in the signature would run a
    # JAX computation as a side effect of importing the module.
    if key is None:
        key = jax.random.PRNGKey(0)

    stream = mkev(lam, Nevents, key=key)

    @jax.jit
    def f_loop(queue, arg):
        t, ev = arg
        # Pop first, then enqueue the event (if any) that arrived at step t.
        queue, out = queue.pop(t)
        queue = jax.lax.cond(ev, lambda: queue.enqueue(t + delay), lambda: queue)
        return queue, out

    _, trace = jax.lax.scan(
        f_loop, QueueT.init(delay), xs=(jnp.arange(len(stream)), stream)
    )

    # NOTE(review): jnp.roll wraps around, so events from the end of the
    # stream reappear at the start of the reference -- confirm this is
    # intended.
    expected = jnp.roll(stream, delay).at[-drop:].set(False)

    # Sanity checks the author left disabled: the trace should never exceed 1
    # (trace.max() == 1), and every delivered event should also be expected
    # (no false positives).
    got = trace.astype(bool).at[-drop:].set(False)

    n_expected = expected.sum()
    n_missing = n_expected - got.sum()
    return n_missing / n_expected