JAX = NumPy + Autograd. It uses XLA to compile and run NumPy code on accelerators (GPU/TPU).
Dispatch is asynchronous; to synchronise (e.g. for timing), call block_until_ready() on a result.
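A minimal sketch of async dispatch (the matrix size is an arbitrary assumption): `jnp.dot` returns before the device finishes the work, so a benchmark must call `block_until_ready()` or it only measures dispatch time.

```python
import jax.numpy as jnp

x = jnp.ones((5000, 5000))
y = jnp.dot(x, x)       # returns immediately; the computation runs asynchronously on the device
y.block_until_ready()   # blocks until the result is actually materialised
```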
- Notable functions (minimal sketch after this list):
  - jit() for compiling (and fusing) multiple computations with XLA
  - grad() for the autodiff transformation (reverse-mode gradients; jvp() gives Jacobian-vector products)
  - vmap() for auto-vectorisation over a batch axis
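A minimal sketch of the three core transformations; `f` is a hypothetical example function chosen for illustration.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

f_jit = jax.jit(f)        # compile f with XLA
df = jax.grad(f)          # reverse-mode gradient of a scalar-valued function
f_batched = jax.vmap(f)   # map f over a leading batch axis

x = jnp.arange(3.0)
print(f_jit(x))                       # 5.0
print(df(x))                          # [0. 2. 4.]
print(f_batched(jnp.stack([x, x])))   # [5. 5.]
```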
- Arrays are immutable in JAX: no in-place mutation, only functional updates
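A minimal sketch of immutability: in-place assignment raises, while the `.at[...].set(...)` syntax returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

x = jnp.zeros(3)
# x[0] = 1.0              # TypeError: JAX arrays are immutable
y = x.at[0].set(1.0)      # functional update: returns a copy with index 0 set
print(x)                  # [0. 0. 0.]  (original unchanged)
print(y)                  # [1. 0. 0.]
```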
- Treat functions as pure (no side effects, no dependence on external state) so they can be compiled with XLA; side effects only run while the function is being traced, not in the compiled version.
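A minimal sketch of why purity matters (`impure` is a hypothetical example): the `print` side effect fires only while jit traces the function; later calls with the same shape/dtype hit the cached compiled version and skip it.

```python
import jax
import jax.numpy as jnp

@jax.jit
def impure(x):
    print("tracing!")   # side effect: runs once, during tracing
    return x + 1

print(impure(jnp.float32(1.0)))  # prints "tracing!" then 2.0
print(impure(jnp.float32(2.0)))  # no "tracing!"; cached compiled code runs, prints 3.0
```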