Module adseq.benchmarks.loss2

Functions

def go()
Expand source code
def go(lam=0.001, thres=0.1, max_size=100, window=200, Nevents=100, drop=20, n_repeats=3):
    """Estimate, for each FIFORing size, the delay at which the drop rate reaches ``thres``.

    For every queue size in ``range(1, max_size)`` a queue type is built with
    ``implementations.FIFORing.sized(size)``. ``spike_drop`` is then evaluated
    (vectorised with ``jax.vmap``) over a window of ``window`` consecutive integer
    delays. Each delay's drop rate is averaged over ``n_repeats`` independent PRNG keys.

    If the whole window stays below ``thres``, the window is advanced by ``window``
    delays and re-evaluated. Otherwise the crossing is interpolated, printed as
    ``lam * delay``, and recorded. ``lam`` is the per-step Poisson rate handed to
    ``spike_drop``.

    Args:
        lam: per-step event rate forwarded to ``spike_drop``.
        thres: drop-rate threshold whose crossing delay is searched for.
        max_size: sizes ``1 .. max_size - 1`` are benchmarked.
        window: number of delays evaluated per batch; also the step by which the
            window advances.
        Nevents: forwarded to ``spike_drop``.
        drop: forwarded to ``spike_drop`` (trailing steps excluded from comparison).
        n_repeats: number of random keys averaged per window.

    Returns:
        list of ``(size, E)`` tuples, in the order they are printed.

    Raises:
        RuntimeError: if the averaged drop rate already exceeds ``thres`` at every
            delay of the current window, so the crossing lies below the window and
            cannot be located.
    """
    results = []
    # ``dly`` is not reset between sizes: the search for each size resumes at the
    # window where the previous size's crossing was found. This relies on the
    # crossing delay not decreasing with size (otherwise RuntimeError below).
    dly = jnp.arange(1, window + 1)
    for size in range(1, max_size):
        imp = implementations.FIFORing.sized(size)
        key = jax.random.PRNGKey(0)

        def rates(delays, k):
            # One spike_drop run per delay in the window, sharing the same key.
            return jax.vmap(
                lambda d: spike_drop(imp, lam=lam, delay=d, Nevents=Nevents, drop=drop, key=k)
            )(delays)

        # A fresh jitted function per size, because ``imp`` changes between sizes.
        f = jax.jit(rates)

        while True:
            samples = []
            for _ in range(n_repeats):
                key, sub = jax.random.split(key)
                samples.append(f(dly, sub))
            droprate = jnp.stack(samples).mean(axis=0)

            if droprate.min() > thres:
                raise RuntimeError(
                    f"drop rate exceeds thres={thres} over the whole delay window "
                    f"starting at {int(dly[0])} for size={size}; the threshold "
                    f"crossing lies below the window"
                )
            if droprate.max() < thres:
                dly = dly + window
                continue

            print(droprate)
            # jnp.interp requires increasing xp, but the Monte-Carlo drop rate is
            # noisy. Its running maximum is non-decreasing and crosses ``thres``
            # exactly where ``droprate`` first does.
            envelope = jax.lax.cummax(droprate)
            E = jnp.interp(thres, envelope, lam * dly.astype('float32'))
            break

        print(size, E)
        results.append((size, E))
    return results
def mkev(lam: float, Nevents: int, key=Array([0, 0], dtype=uint32))
Expand source code
def mkev(lam: float, Nevents: int, key=None):
    """Generate a binary event train with Poisson arrivals per time step.

    The train has ``T = int(1/lam * Nevents)`` steps, and each step draws a Poisson
    count with rate ``lam``. The expected total number of arrivals is therefore
    about ``Nevents``.

    Args:
        lam: per-step Poisson rate; must be > 0.
        Nevents: target mean number of arrivals; determines the train length.
        key: JAX PRNG key. ``None`` (the default) means ``jax.random.PRNGKey(0)``.

    Returns:
        int32 array of shape ``(T,)``: 1 where at least one arrival occurred,
        0 otherwise. Steps with several arrivals collapse into a single 1.

    Raises:
        ValueError: if ``lam <= 0``.
    """
    if lam <= 0:
        raise ValueError(f"lam must be positive, got {lam}")
    if key is None:
        key = jax.random.PRNGKey(0)
    T = int(1 / lam * Nevents)
    counts = jax.random.poisson(key, lam=lam, shape=(T,))
    return (counts > 0).astype(jnp.int32)
def spike_drop(QueueT, lam=10, delay=1, Nevents=100, drop=200, key=Array([0, 0], dtype=uint32))
Expand source code
def spike_drop(QueueT, lam=10, delay=1, Nevents=100, drop=200, key=None):
    """Measure the fraction of events a delay queue fails to deliver.

    An event train is generated with ``mkev(lam, Nevents, key=key)`` and fed
    through a queue created by ``QueueT.init(delay)``. At every step ``t`` the
    queue is first popped at ``t``. Then, if an event occurs at ``t``, it is
    enqueued with release time ``t + delay``.

    An ideal delay line would output the input train shifted by ``delay`` steps.
    The result is the fraction of those expected outputs that are missing from
    the queue's trace.

    Args:
        QueueT: queue type. ``QueueT.init(delay)`` must return a queue whose
            ``pop(t)`` returns ``(queue, out)`` and whose ``enqueue(time)``
            returns a queue. The queue is threaded through ``jax.lax.scan`` and
            ``jax.lax.cond``, so it is presumably a JAX pytree (TODO confirm
            against the implementations).
        lam, Nevents: forwarded to ``mkev``.
        delay: output delay in steps. It may be a traced value (e.g. under ``vmap``).
        drop: number of trailing steps excluded from both the expected and the
            observed trains. ``drop=0`` excludes nothing.
        key: JAX PRNG key. ``None`` (the default) means ``jax.random.PRNGKey(0)``.

    Returns:
        Scalar ``(n_expected - n_observed) / n_expected`` over the compared window.
        It is negative if the queue emits more than expected, and nan if no event
        is expected in the window.
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    stream = mkev(lam, Nevents, key=key)
    T = stream.shape[0]
    t_idx = jnp.arange(T)

    def step(queue, arg):
        t, ev = arg
        queue, out = queue.pop(t)
        queue = jax.lax.cond(ev, lambda: queue.enqueue(t + delay), lambda: queue)
        return queue, out

    _, trace = jax.lax.scan(step, QueueT.init(delay), xs=(t_idx, stream))

    # Mask instead of ``.at[-drop:]``: with drop == 0, ``[-0:]`` would select the
    # whole array and mask everything.
    keep = t_idx < T - drop
    # jnp.roll wraps the last ``delay`` events to the front. Those events are only
    # released at or after T (outside the simulated horizon), so they must not be
    # counted as expected output.
    expected = jnp.where((t_idx >= delay) & keep, jnp.roll(stream, delay), 0)
    # Any nonzero output counts as one delivered event.
    got = jnp.where(keep, trace.astype(bool), False)

    n_expected = expected.sum()
    n_missing = n_expected - got.sum()
    return n_missing / n_expected