NumPy + Autograd; uses XLA to compile and run NumPy code on accelerators.
Asynchronous dispatch by default; to synchronise, call block_until_ready() on a result
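A minimal sketch (array size and names are illustrative, not from the notes):

```python
import jax.numpy as jnp

x = jnp.ones((1000, 1000))
y = jnp.dot(x, x)       # returns immediately: the work is dispatched asynchronously
y.block_until_ready()   # blocks until the result is actually materialised on device
```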
- notable functions (a combined sketch follows this list):
    - jit() for compiling multiple computations together with XLA
    - grad() for autodiff transformations (gradients, Jacobian-vector products)
    - vmap() for automatic vectorisation
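A combined sketch of the three transforms; the loss function and values are illustrative assumptions, not from the notes:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((w * x - 1.0) ** 2)

fast_loss = jax.jit(loss)                      # compile the whole computation with XLA
print(fast_loss(2.0, jnp.ones(3)))             # 3.0
print(jax.grad(loss)(2.0, jnp.ones(3)))        # d(loss)/dw via autodiff: 6.0
batched = jax.vmap(loss, in_axes=(0, None))    # vectorise over a batch of w values
print(batched(jnp.arange(4.0), jnp.ones(3)))   # [ 3.  0.  3. 12.]
```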
Arrays are immutable in JAX
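Sketch of the functional-update API (.at is real JAX; the values are illustrative):

```python
import jax.numpy as jnp

x = jnp.zeros(3)
# x[0] = 1.0            # TypeError: JAX arrays are immutable
y = x.at[0].set(1.0)    # functional update: returns a new array, x is unchanged
```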
- Treat functions as pure so they can be traced and compiled with XLA
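A hedged sketch of why impurity bites: jit caches the compiled function, so Python side effects run only at trace time (the function here is illustrative):

```python
import jax

counter = 0

@jax.jit
def impure(x):
    global counter
    counter += 1        # side effect: executes only while tracing
    return x + 1

impure(1.0)
impure(2.0)             # same shape/dtype -> cached compilation, no retrace
print(counter)          # 1, not 2
```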
references: GitHub
control flow
see also link
The following works:
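The actual snippet didn't survive in these notes; a standard example (close to the one in the JAX docs) is Python control flow under grad, which traces on concrete values:

```python
import jax

def f(x):
    if x < 3:
        return 3.0 * x ** 2
    return -4.0 * x

print(jax.grad(f)(2.0))   # 12.0 -- the branch resolves on a concrete value
```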
The following doesn't work:
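Presumably the same function under jit; tracing then happens on abstract values, so the Python branch can't be resolved:

```python
import jax

def f(x):
    if x < 3:             # x is an abstract tracer here: its truth value is unknown
        return 3.0 * x ** 2
    return -4.0 * x

jax.jit(f)(2.0)           # raises a concretization error (TracerBoolConversionError in recent JAX)
```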
Reasoning: jit traces code at the ShapedArray abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. Tradeoff of tracing at this level: if we trace a Python function on a ShapedArray((), jnp.float32) that isn't committed to a specific concrete value, then when we hit a line like if x < 3, the expression x < 3 evaluates to an abstract ShapedArray((), jnp.bool_) that represents the set {True, False}, so Python can't decide which branch to take.
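A quick way to see the abstraction (the printed form is approximate and version-dependent):

```python
import jax

@jax.jit
def f(x):
    print(x)   # e.g. Traced<ShapedArray(float32[], weak_type=True)> -- no concrete value
    return x + 1

f(2.0)
```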
Fix: you can use static_argnums to specify which arguments should be treated as static; jit then traces on their concrete values (and recompiles when a static value changes)
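Sketch of the fix applied to the failing function above (static args must be hashable; recompilation happens per distinct static value):

```python
from functools import partial
import jax

@partial(jax.jit, static_argnums=(0,))
def f(x):
    if x < 3:             # x is now a concrete Python value at trace time
        return 3.0 * x ** 2
    return -4.0 * x

print(f(2.0))             # 12.0; calling f(5.0) would trigger a fresh compile
```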
buffers
How does JAX handle memory buffers?