NumPy + Autograd. Uses XLA to compile and run NumPy code on accelerators. Dispatch is asynchronous; to synchronise, call block_until_ready() on the result.

import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
x = random.normal(key, (10,))
jnp.dot(x, x.T).block_until_ready()  # wait for the async computation to finish
  • Notable functions:
    • jit() compiles a function with XLA, so that its operations are fused and run together
    • grad() for automatic differentiation (related transformations such as jvp() compute Jacobian-vector products)
    • vmap() for auto-vectorisation
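A minimal sketch of the three transformations listed above. The function loss and the arrays w, x, xs are made up for illustration; only jit, grad and vmap come from JAX.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def loss(w, x):
    # hypothetical scalar function: squared dot product
    return jnp.dot(w, x) ** 2

w = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.5, -1.0, 2.0])

fast_loss = jit(loss)                         # compiled with XLA on first call
dloss_dw = grad(loss)                         # gradient w.r.t. the first argument (w)
batched_loss = vmap(loss, in_axes=(None, 0))  # share w, map over rows of x

xs = jnp.stack([x, 2 * x])        # a batch of two inputs
loss_val = fast_loss(w, x)        # scalar
grad_val = dloss_dw(w, x)         # same shape as w
batch_vals = batched_loss(w, xs)  # one loss per row of xs
```

The transformations compose, so something like jit(vmap(grad(loss), in_axes=(None, 0))) should also work for per-example gradients.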

Arrays are immutable in JAX
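A small sketch of what immutability means in practice, assuming a toy array x: in-place assignment fails, and the .at[...].set(...) syntax returns an updated copy instead.

```python
import jax.numpy as jnp

x = jnp.arange(5)

# In-place item assignment is rejected for JAX arrays
try:
    x[0] = 10
except TypeError:
    in_place_failed = True
else:
    in_place_failed = False

# Functional update: returns a new array, x itself is unchanged
y = x.at[0].set(10)
```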

  • Treat functions as pure (no side effects, output depends only on inputs) so that they can be compiled with XLA
import jax.numpy as jnp
from jax import jit
def diff(a, b, w=0):
    return jnp.dot(a, b) + w
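Why purity matters, as a rough sketch: under jit, the Python body runs only while the function is being traced, so a side effect does not repeat on later calls. The names impure_add and trace_log are hypothetical.

```python
import jax.numpy as jnp
from jax import jit

trace_log = []

@jit
def impure_add(a, b):
    # Side effect: executed during tracing only, not on cached calls
    trace_log.append("traced")
    return a + b

a = jnp.ones(3)
b = jnp.ones(3)

impure_add(a, b)   # first call: traces and compiles, side effect runs
impure_add(a, b)   # same shapes and dtypes: cached, Python body is skipped
n_traces = len(trace_log)
```

Calling the function with a different input shape or dtype would trigger a new trace, and the side effect would run again.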