Module adseq.benchmarks.profile_poisson
Functions
def mkev(lam: float, Nevents: int, key=Array([0, 0], dtype=uint32))-
Expand source code
def mkev(lam: float, Nevents: int, key=jax.random.PRNGKey(0)):
    """Sample one boolean event stream with Poisson-distributed inter-event gaps.

    ``Nevents`` gaps are drawn from ``Poisson(lam)`` and accumulated into event
    times. The stream is ``int(lam * Nevents)`` timesteps long and is ``True``
    at every event time.

    Args:
        lam: mean gap between consecutive events, in timesteps (int or float).
        Nevents: number of gaps (events) to draw.
        key: JAX PRNG key used for the draw.

    Returns:
        Boolean array of shape ``(int(lam * Nevents),)``.

    Note:
        Fewer than ``Nevents`` entries may be ``True``: a zero gap places two
        events on the same timestep, and events whose time falls at or past the
        end of the stream are dropped (the default out-of-bounds behaviour of
        JAX's ``.at[...].set``).
    """
    # A shape must be a Python int; ``lam * Nevents`` is a float whenever
    # ``lam`` is, which jnp.zeros would reject.
    n_steps = int(lam * Nevents)
    event_stream = jnp.zeros((n_steps,), dtype=bool)
    gaps = jax.random.poisson(key, lam, (Nevents,))
    # Poisson draws are integers, so their running sum is already an index.
    ts = jnp.cumsum(gaps).astype(int)
    return event_stream.at[ts].set(True)


# def mkevs(lam, Nevents, num, key=Array([0, 0], dtype=uint32))-
Expand source code
def mkevs(lam, Nevents, num, key=jax.random.PRNGKey(0)):
    """Sample ``num`` independent event streams with ``mkev``.

    ``key`` is split into ``num`` subkeys and ``mkev(lam, Nevents, subkey)``
    is vectorised over them with ``jax.vmap``, so the result is stacked with
    one stream per leading index.
    """
    def _single_stream(subkey):
        return mkev(lam, Nevents, subkey)

    subkeys = jax.random.split(key, num)
    batched_streams = jax.vmap(_single_stream)
    return batched_streams(subkeys)


# def run()-
Expand source code
def _time_each_implementation(timer, show_prelim=False):
    """Time every queue in ``check`` with ``timer``; return ``{name: us_per_ts}`` fastest first.

    Implementations whose timing raises are reported with ``repr`` and skipped.
    The sorted timings are printed one per line before being returned.
    """
    timed = []
    for imp in tqdm.tqdm(check):
        print('###', imp.__name__)
        try:
            t = timer(imp)
            if show_prelim:
                print(' (prelim)', imp.__name__.ljust(20), f'{t: 10.7f}us/ts')
        except Exception as ex:
            # Best effort: a failing implementation must not abort the sweep.
            print(repr(ex))
        else:
            timed.append((t, imp))
    # Sort on the time alone: classes do not normally define `<`, so sorting
    # the raw (time, class) tuples raises TypeError when two times tie.
    timed.sort(key=lambda pair: pair[0])
    summary = {}
    for t, imp in timed:
        print(imp.__name__.ljust(20), f'{t: 10.7f}us/ts')
        summary[str(imp.__name__)] = float(t)
    return summary


def run():
    """Benchmark every queue in ``check`` in single and batched mode and save the results.

    Each implementation is timed with ``time_queue_single`` and then with
    ``time_queue_batched`` (default arguments). Timings, in microseconds per
    timestep, are printed fastest first and written to
    ``benchmarks/<device>.json`` together with the host description returned by
    ``benchmarks.get_device_id()``. Blocks on ``input('done')`` before returning.
    """
    results = {'single': {}, 'batched': {}}
    dev_name, results['host'] = benchmarks.get_device_id()

    print('Single')
    results['single'] = _time_each_implementation(time_queue_single)
    print()
    print('Batched')
    results['batched'] = _time_each_implementation(time_queue_batched, show_prelim=True)

    with open(f'benchmarks/{dev_name}.json', 'w') as f:
        json.dump(results, f)
    input('done')


# def run_increasing_caps()-
Expand source code
def run_increasing_caps():
    """Time batched queues of growing capacity and log each result as it arrives.

    For every capacity from 1 to 39 and each of four queue implementations,
    ``<Queue>.sized(capacity)`` is timed with
    ``time_queue_batched(num=10000, nevents=100, lam=10, delay=5)``.
    Every successful timing is written to ``benchmarks/caps_<device>.json`` as
    a whitespace-separated ``capacity name time`` line (plain text, despite the
    extension) and flushed immediately. Failed timings are printed and skipped.
    """
    queue_types = [
        implementations.BinaryHeap,
        implementations.LossyRing,
        implementations.FIFORing,
        implementations.SortedArray,
    ]
    dev_name, _host = benchmarks.get_device_id()
    log_path = f'benchmarks/caps_{dev_name}.json'
    with open(log_path, 'w') as log:
        for capacity in range(1, 40):
            for base_queue in tqdm.tqdm(queue_types):
                sized_queue = base_queue.sized(capacity)
                label = sized_queue.__name__
                print('###', label)
                try:
                    us_per_ts = time_queue_batched(
                        sized_queue, num=10000, nevents=100, lam=10, delay=5
                    )
                    print(capacity, label, us_per_ts, file=log, flush=True)
                    print(' (prelim)', capacity, label.ljust(20), f'{us_per_ts: 10.7f}us/ts')
                except Exception as ex:
                    print(repr(ex))


# def run_increasing_sizes()-
Expand source code
def run_increasing_sizes():
    """Time every queue in ``check`` at increasing batch sizes and log each result.

    Batch sizes follow a coarsening grid: every 1 000 up to 10 000, every
    10 000 up to 100 000, then every 100 000 up to 900 000. Each implementation
    is timed with ``time_queue_batched(imp, num=n, nevents=10)``. Successful
    timings are written to ``benchmarks/sizes_<device>.json`` as
    whitespace-separated ``n name time`` lines (plain text, despite the
    extension) and flushed immediately. Failed timings are printed and skipped.
    """
    dev_name, _host = benchmarks.get_device_id()
    with open(f'benchmarks/sizes_{dev_name}.json', 'w') as f:
        for n in range(1000, 1_000_000, 1000):
            # The grid filter depends only on n, so apply it before creating a
            # progress bar; otherwise a bar is built and run for every skipped n.
            if n > 10_000 and n % 10_000 != 0:
                continue
            if n > 100_000 and n % 100_000 != 0:
                continue
            for imp in tqdm.tqdm(check):
                print('###', imp.__name__)
                try:
                    t = time_queue_batched(imp, num=n, nevents=10)
                    print(n, imp.__name__, t, file=f, flush=True)
                    print(' (prelim)', n, imp.__name__.ljust(20), f'{t: 10.7f}us/ts')
                except Exception as ex:
                    print(repr(ex))


# def time_queue_batched(QueueT: type[BaseQueue],
# num=10000,
# nevents=100,
# lam=400,
# delay=80)-
Expand source code
def time_queue_batched(
    QueueT: type[implementations.BaseQueue],
    num=10000,
    nevents=100,
    # assume dt = 0.025 ms
    lam=400,  # 100 Hz
    delay=80,  # 2 ms
):
    """Return the mean cost, in microseconds per timestep, of stepping ``num`` queues in lockstep.

    Every timestep each queue pops at time ``t``; then, independently with
    probability ``1 / lam``, it enqueues ``t + delay``. A run lasts
    ``nevents * lam`` timesteps (so about ``nevents`` events per queue) and is
    repeated ``NREPEATS`` times. Runs whose runner returns an exception are
    printed and excluded from the mean.

    Args:
        QueueT: queue implementation; ``QueueT.init(delay)`` builds one queue.
        num: number of queues simulated side by side via ``jax.vmap``.
        nevents: expected number of events per queue over one run.
        lam: mean number of timesteps between events.
        delay: offset added to ``t`` when an event is enqueued.

    Returns:
        Mean elapsed seconds per run divided by the timestep count, times 1e6.
    """
    n_steps = nevents * lam
    key = jax.random.PRNGKey(0)

    @jax.jit
    def f_loop(carry, _):
        qs, total, step_key, t = carry
        step_key, key_next = jax.random.split(step_key)
        evs = jax.random.uniform(step_key, shape=(num,)) < 1 / lam
        qs, out = jax.vmap(lambda q: q.pop(t))(qs)
        # Enqueue into the queues returned by pop so that the pops persist.
        qs = jax.vmap(
            lambda e, q: jax.lax.cond(e, lambda: q.enqueue(t + delay), lambda: q)
        )(evs, qs)
        total = total + out.sum()
        return (qs, total, key_next, t + 1), None

    init = jax.vmap(lambda _: QueueT.init(delay))(jnp.full(num, 0))  # type: ignore
    runner = benchmarks.mkrunner_loop(
        f_loop,
        init=(init, 0, key, 0),
        length=n_steps,
        groq_unroll=20,
    )

    runs = []
    for _ in range(NREPEATS):
        # NOTE(review): assumes runner() blocks until the computation finishes — confirm.
        start = time.perf_counter()
        o = runner()
        elapsed = time.perf_counter() - start
        if isinstance(o, Exception):
            print(repr(o))
        else:
            runs.append(elapsed)
    return np.mean(np.array(runs)) / n_steps * 1e6
    # implementations.BinaryHeap.sized(7),
    # implementations.LossyRing.sized(4),
    # implementations.FIFORing.sized(4),
    # implementations.SortedArray.sized(4),


# def time_queue_single(QueueT: type[BaseQueue])-
Expand source code
def time_queue_single(QueueT: type[implementations.BaseQueue]):
    """Return the mean cost, in microseconds per timestep, of driving one queue with a Poisson event stream.

    A stream of ``lam * Nevents`` timesteps is sampled with ``mkev``. At each
    step the queue pops at ``t`` and, on an event, enqueues ``t + delay``. The
    run is repeated ``NREPEATS`` times; runs whose runner returns an exception
    are printed (with ``QueueT``) and excluded from the mean.

    Args:
        QueueT: queue implementation; ``QueueT.init(delay)`` builds the queue.

    Returns:
        Mean elapsed seconds per run divided by the stream length, times 1e6.
    """
    lam = 400  # in units of dt
    delay = 80  # units of dt
    Nevents = 100
    stream = mkev(lam, Nevents)

    def f_loop(carry, arg):
        queue, total = carry
        # NOTE(review): assumes mkrunner_loop feeds (t, stream[t]) pairs for xs — confirm.
        t, ev = arg
        queue, out = queue.pop(t)
        queue = jax.lax.cond(ev, lambda: queue.enqueue(t + delay), lambda: queue)
        total = total + out
        return (queue, total), None

    runner = benchmarks.mkrunner_loop(
        f_loop,
        init=(QueueT.init(delay), 0),
        xs=stream,
        groq_unroll=100,
    )

    runs = []
    for _ in range(NREPEATS):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        o = runner()
        elapsed = time.perf_counter() - start
        if isinstance(o, Exception):
            print(QueueT)
            print(repr(o))
        else:
            runs.append(elapsed)
    return np.mean(np.array(runs)) / stream.shape[0] * 1e6