Module adseq.benchmarks.profile_lossy_queues

Functions

def make_plot(imp, ax)
Expand source code
def make_plot(imp, ax, *, lams=(0.01,), n_trials=10, Nevents=100, drop=20, legend=False):
    """Plot the drop rate of queue implementation ``imp`` against normalized delay.

    For every rate ``lam`` in ``lams`` the delay ``d`` is swept over
    ``jnp.arange(1, round(0.5 / lam), 0.1)``. Each point is the mean, over
    ``n_trials`` independent PRNG keys, of ``spike_drop(imp, ...)``. Points are
    plotted as ``d * lam`` (x) against the drop rate in percent (y).

    Args:
        imp: queue class forwarded to ``spike_drop`` as ``QueueT``; its
            ``__name__`` is used as the plot title.
        ax: Axes-like object (uses ``plot``, ``set_title``, ``set_ylabel``,
            ``set_xlabel``, ``set_ylim`` and optionally ``legend``); presumably
            a matplotlib ``Axes`` -- TODO confirm.
        lams: Poisson rates to sweep; one curve is drawn per rate. Each must be
            a concrete Python number, because the stream length is derived
            from it.
        n_trials: number of independent keys averaged per curve.
        Nevents: forwarded to ``spike_drop``.
        drop: forwarded to ``spike_drop``.
        legend: if True, draw the axes legend (one entry per rate).
    """
    for lam in lams:
        delays = jnp.arange(1, int(round(0.5 / lam)), 0.1)
        # The closure captures the current `lam`; it is compiled and called
        # within this iteration, so late binding cannot bite. `lam` must stay
        # a closed-over Python value (not a jit argument) so it is not traced.
        batched_drop = jax.jit(
            lambda ds, k: jax.vmap(
                lambda d: spike_drop(imp, lam=lam, delay=d, Nevents=Nevents, drop=drop, key=k)
            )(ds)
        )
        key = jax.random.PRNGKey(0)
        trials = []
        for _ in range(n_trials):
            key, subkey = jax.random.split(key)
            trials.append(batched_drop(delays, subkey))
        mean_rate = jnp.mean(jnp.stack(trials), axis=0)
        ax.plot(delays * lam, mean_rate * 100, label=f'$\\lambda$={lam}')
    ax.set_title(imp.__name__)
    if legend:
        ax.legend()
    ax.set_ylabel('Drop rate (%)')
    ax.set_xlabel(r'$d \lambda$')
    ax.set_ylim(0, 50)
def mkev(lam: float, Nevents: int, key=Array([0, 0], dtype=uint32))
Expand source code
def mkev(lam: float, Nevents: int, key = jax.random.PRNGKey(0)):
    """Generate a binary event stream from a per-step Poisson process.

    The stream has ``T = int(1 / lam * Nevents)`` steps, so it holds about
    ``Nevents`` events on average. Each step draws a Poisson count with rate
    ``lam``, and the count is collapsed to 1 if at least one event occurred.

    Args:
        lam: per-step Poisson rate. Must be a concrete Python number, because
            ``T`` is computed with ``int()``.
        Nevents: target number of events, used to size the stream.
        key: JAX PRNG key for the Poisson draw.

    Returns:
        int32 array of shape ``(T,)`` with values in {0, 1}.

    Side effects:
        Prints, via ``jax.debug.print``, how many events were lost by
        collapsing multi-event steps to a single event.
    """
    T = int(1/lam * Nevents)
    counts = jax.random.poisson(key, lam=lam, shape=(T,))
    # A step holding c >= 1 events keeps exactly one after binarization, so
    # c - 1 events are lost there. The previous `((counts > 1) * counts)`
    # also counted the kept event of every colliding step.
    dropped = jnp.maximum(counts - 1, 0).sum()
    jax.debug.print('generator lam={} dropped={} N={}', lam, dropped, Nevents)
    return (counts > 0).astype(jnp.int32)
def spike_drop(QueueT, lam=10, delay=1, Nevents=100, drop=20, key=Array([0, 0], dtype=uint32))
Expand source code
def spike_drop(QueueT, lam=10, delay=1, Nevents=100, drop=20, key=jax.random.PRNGKey(0)):
    """Measure the fraction of events a lossy queue fails to deliver.

    A binary event stream is generated with ``mkev``. The queue is popped once
    per step, and each event at step ``t`` is enqueued with release time
    ``t + delay``. The number of events the queue emits is compared with the
    number in the ideal, delay-shifted stream.

    Args:
        QueueT: queue class exposing ``init(delay)``, ``pop(t) -> (queue, out)``
            and ``enqueue(t) -> queue``. It is carried through
            ``jax.lax.scan``, so instances presumably must be JAX pytrees --
            TODO confirm.
        lam: per-step Poisson rate forwarded to ``mkev``. Must be concrete.
        delay: delay applied to every event; may be a traced scalar. For the
            reference stream it is truncated to int32, as ``jnp.roll`` did
            previously.
        Nevents: forwarded to ``mkev``.
        drop: number of trailing steps excluded from both the reference and the
            queue output, so events still in flight at the end are not counted
            as lost. ``drop=0`` compares the whole stream.
        key: PRNG key for the event stream.

    Returns:
        JAX scalar ``(expected - delivered) / expected``. This is nan when the
        comparison window contains no expected events.
    """
    stream = mkev(lam, Nevents, key=key)
    n_steps = len(stream)

    def f_loop(queue, arg):
        t, ev = arg
        queue, out = queue.pop(t)
        # Only steps where an event occurred schedule a release.
        queue = jax.lax.cond(ev, lambda: queue.enqueue(t + delay), lambda: queue)
        return queue, out

    # lax.scan traces and compiles the body itself; no separate jax.jit needed.
    _, trace = jax.lax.scan(f_loop, QueueT.init(delay), xs=(jnp.arange(n_steps), stream))

    steps = jnp.arange(n_steps)
    # Ideal output: the input shifted right by `delay`, zero-filled. jnp.roll
    # wrapped the last `delay` events to the front of the window; the queue
    # can never emit those, so they were wrongly counted as drops.
    src = steps - jnp.asarray(delay).astype(jnp.int32)
    valid_src = (src >= 0) & (src < n_steps)
    expected = jnp.where(valid_src, stream[jnp.clip(src, 0, n_steps - 1)], 0)

    # Exclude the trailing `drop` steps. `x[-drop:]` would select the whole
    # array when drop == 0, so the window is built from n_steps - drop instead.
    in_window = steps < n_steps - drop
    expected = jnp.where(in_window, expected, 0)
    # assert trace.max() == 1
    got = trace.astype(bool) & in_window
    n_dropped = expected.sum() - got.sum()
    # assert (1-expected[got]).sum() == 0 # No FP
    return n_dropped / expected.sum()