---
date: '2022-11-07'
description: numpy with autograd and xla compilation for accelerators, featuring jit compilation, automatic differentiation, and vectorization with immutable arrays.
id: Jax
modified: 2026-06-05 15:08:07 GMT-04:00
tags:
  - seed
  - ml
title: Jax
created: '2022-11-07'
published: '2022-11-07'
pageLayout: default
slug: thoughts/Jax
permalink: https://aarnphm.xyz/thoughts/Jax.md
generator:
  quartz: v4.6.0
  hostedProvider: Cloudflare
  baseUrl: aarnphm.xyz
full: https://aarnphm.xyz/llms-full.txt
---
Numpy + [[thoughts/Autograd|Autograd]]. Use [[thoughts/XLA|XLA]] to compile and run NumPy code on accelerators.

- notable functions:
  - `jit()` for just-in-time compilation of a function into fused XLA computations
  - `grad()` for automatic differentiation (reverse-mode; related transformations such as `jvp()` compute [[thoughts/Vector calculus#Jacobian matrix|Jacobian]]-vector products)
  - `vmap()` for auto-vectorisation
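
A minimal sketch of how these three transformations compose. The `loss` function and the input shapes are made up for illustration and are not from any particular codebase:

```python
import jax
import jax.numpy as jnp


def loss(w, x):
  """Toy scalar loss (illustrative only): squared norm of an elementwise product."""
  return jnp.sum((w * x) ** 2)


# grad differentiates with respect to the first argument (w) by default.
grad_loss = jax.grad(loss)

# vmap maps over the leading axis of x while sharing w across the batch.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# jit compiles the composed transformation with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.ones((4, 3))
grads = fast_batched_grad(w, xs)
print(grads.shape)  # (4, 3): one gradient per row of xs
```

Each transformation takes a function and returns a function, so the order of composition can be changed freely as long as the result is what you want to differentiate, batch, or compile.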

> Arrays are **immutable** in Jax

- Treat functions as pure (no side effects, no in-place mutation) so that they can be traced and compiled with [[thoughts/XLA]]
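
A small sketch of what immutability means in practice: in-place assignment is rejected, and updates go through the functional `.at[...]` syntax, which returns a new array.

```python
import jax.numpy as jnp

x = jnp.zeros(3)

try:
  x[0] = 1.0  # in-place assignment is not supported on JAX arrays
except TypeError as err:
  print("in-place update rejected:", type(err).__name__)

# Functional update: returns a new array and leaves x untouched.
y = x.at[0].set(1.0)
print(x)
print(y)
```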

```python title="entropix/dslider.py"
from functools import partial
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp
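import jax.scipy as jsp

# Small constant used by normalize_logits below. This excerpt omits its
# definition; 1e-8 is an assumed value.
EPS = 1e-8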


@jax.jit
def kl_divergence(logp: jnp.ndarray, logq: jnp.ndarray) -> jnp.ndarray:
  """Compute KL divergence between two log probability distributions."""
  p = jnp.exp(logp)
  return jnp.sum(jnp.where(p > 0, p * (logp - logq), 0.0), axis=-1)


@jax.jit
def ent_varent(logp: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Compute entropy and varentropy from log probabilities."""
  p = jnp.exp(logp)
  ent = -jnp.sum(p * logp, axis=-1)
  diff = logp + ent[..., None]
  varent = jnp.sum(p * diff**2, axis=-1)
  return ent, varent


@jax.jit
def normalize_logits(logits: jnp.ndarray, noise_floor: float) -> jnp.ndarray:
  """Normalize logits to log probabilities with noise floor truncation."""
  shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
  normalized = shifted - jax.nn.logsumexp(
    shifted + EPS, axis=-1, keepdims=True
  )
  # noise floor calculated for bfloat16
  return jnp.where(normalized < noise_floor, jnp.log(EPS), normalized)
```

_references: [github](https://github.com/xjdr-alt/entropix/blob/main/entropix/dslider.py)_
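
As a sanity check on the entropy and varentropy definitions, here is a hedged worked example. With $H = -\sum_i p_i \log p_i$ and varentropy $\sum_i p_i (\log p_i + H)^2$, a uniform distribution should give $H = \log n$ and zero varentropy. The function is repeated so the snippet runs on its own, and the logits are arbitrary illustrative values:

```python
import jax
import jax.numpy as jnp


def ent_varent(logp):
  """Same computation as `ent_varent` above, repeated so this snippet is self-contained."""
  p = jnp.exp(logp)
  ent = -jnp.sum(p * logp, axis=-1)
  diff = logp + ent[..., None]
  varent = jnp.sum(p * diff**2, axis=-1)
  return ent, varent


# Uniform distribution over 4 outcomes: every log-probability equals -log(4).
uniform = jax.nn.log_softmax(jnp.zeros(4))
# A peaked distribution built from arbitrary illustrative logits.
peaked = jax.nn.log_softmax(jnp.array([4.0, 0.0, 0.0, 0.0]))

ent_u, varent_u = ent_varent(uniform)
ent_p, varent_p = ent_varent(peaked)
print(ent_u, varent_u)  # entropy close to log(4), varentropy close to 0
print(ent_p, varent_p)  # lower entropy, non-zero varentropy
```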

## asynchronous dispatch

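JAX dispatches work to the device asynchronously: a call returns an array handle before the computation has finished. To synchronise (for example when timing code), call `block_until_ready()` on the result.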

```python shell
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (10,))
jnp.dot(x, x.T).block_until_ready()
```
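
A rough sketch of why this matters when benchmarking. The matrix size is arbitrary, and the measured gap depends on the backend (it may be small on CPU), so no timings are claimed here:

```python
import time

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (1000, 1000))

# Warm up so that one-off compilation cost is not part of the measurement.
jnp.dot(x, x).block_until_ready()

start = time.perf_counter()
y = jnp.dot(x, x)  # may return as soon as the work is enqueued
dispatch_time = time.perf_counter() - start

start = time.perf_counter()
y = jnp.dot(x, x).block_until_ready()  # waits until the result is computed
total_time = time.perf_counter() - start

print(f"dispatch only: {dispatch_time:.6f}s, with sync: {total_time:.6f}s")
```

Without the explicit sync, a timer can end up measuring only the dispatch overhead rather than the computation itself.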

## control flow

see also [Common Gotchas in JAX: Python control flow and `jit`](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit)

The following works, because the loop bounds are known at trace time and the loops are simply unrolled:

```python /jax.jit/
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x


print(f(3))


@jax.jit
def g(x):
  y = 0.0
  for i in range(x.shape[0]):
    y = y + x[i]
  return y


print(g(jnp.array([1.0, 2.0, 3.0])))
```
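
Python loops over static ranges are unrolled during tracing, so long loops produce large programs. One alternative, sketched here under the assumption that the loop body is the same at every step, is `jax.lax.fori_loop`, which traces the body once (`g_loop` is a hypothetical name for this example):

```python
import jax
import jax.numpy as jnp


@jax.jit
def g_loop(x):
  # The carry starts as a scalar zero with the same dtype as x.
  init = jnp.zeros((), dtype=x.dtype)
  # The body is traced once; i is a traced index, acc is the running sum.
  return jax.lax.fori_loop(0, x.shape[0], lambda i, acc: acc + x[i], init)


total = g_loop(jnp.array([1.0, 2.0, 3.0]))
print(total)
```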

> \[!warning\]- doesn't work
>
> ```python {2-4}
> @jax.jit
> def fail(x):
>   if x < 3: return 3. * x ** 2
>   else    : return -4 * x
>
> fail(2)  # raises an error: x is traced, so `x < 3` has no concrete boolean value
> ```

Reasoning: `jit` traces the function on a `ShapedArray` abstraction, where each abstract value represents the set of all array values with a fixed shape and dtype.

> \[!tip\]+ type coercion tradeoff
>
> If we trace a Python function on a `ShapedArray((), jnp.float32)` that isn’t committed to a specific concrete value,
> when we hit a line like `if x < 3`, the expression `x < 3` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set `{True, False}`. Python cannot coerce that to a single `bool`, so tracing raises an error.
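
To make the tracing step visible, here is a small sketch (the function name `show_trace` is made up for this example). The `print` inside the jitted function runs while tracing, where `x` is an abstract tracer rather than a number:

```python
import jax
import jax.numpy as jnp


@jax.jit
def show_trace(x):
  # Runs only while tracing: x is an abstract tracer here, not a concrete value.
  print("traced value:", x)
  return x < 3


result = show_trace(jnp.float32(2.0))
print(result)  # concrete boolean array, computed by the compiled function

# make_jaxpr prints the traced program, in which the comparison is a primitive
# applied to an abstract float32 scalar.
print(jax.make_jaxpr(lambda x: x < 3)(jnp.float32(2.0)))
```

Returning `x < 3` is fine because the comparison stays inside the traced program; the failure only occurs when Python itself needs a concrete `True` or `False`, as in an `if` statement.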

Fix: use `static_argnums` to specify which arguments should be treated as static. Static arguments must be hashable, and each new value triggers a recompilation.

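```python
from functools import partial

import jax


# partial passes the option to jax.jit in a way that works as a decorator
@partial(jax.jit, static_argnums=(0,))
def f(x):
  if x < 3:
    return 3.0 * x**2
  else:
    return -4 * x
```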
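
When the branch depends on a value that changes often, recompiling per value can be costly. A hedged alternative, assuming both branches return the same shape and dtype, is to keep the branch inside the traced program with `jax.lax.cond` or `jnp.where` (`f_cond` and `f_where` are hypothetical names for this sketch):

```python
import jax
import jax.numpy as jnp


@jax.jit
def f_cond(x):
  x = jnp.asarray(x, dtype=jnp.float32)
  # Both branches are traced; the choice between them is made at run time.
  return jax.lax.cond(x < 3, lambda v: 3.0 * v**2, lambda v: -4.0 * v, x)


@jax.jit
def f_where(x):
  x = jnp.asarray(x, dtype=jnp.float32)
  # Evaluates both expressions and selects elementwise, so it also works on arrays.
  return jnp.where(x < 3, 3.0 * x**2, -4.0 * x)


print(f_cond(2.0), f_cond(5.0))
print(f_where(jnp.array([2.0, 5.0])))
```

`lax.cond` needs a scalar predicate, while `jnp.where` computes both sides for every element; which one fits depends on how expensive each branch is.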

## buffers

> \[!question\] How does JAX handle memory buffers?

[fast replay buffers](https://github.com/instadeepai/flashbax)


