Automatic differentiation and XLA
Alternative to manual differentiation (deriving gradients by hand)
Others:
- numerical (finite differences), symbolic
- autodiff
  - similar to symbolic in that derivatives are exact, but computed on demand for a given input
  - returns a numerical value instead of a derivative expression
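To make the contrast concrete, here is a minimal sketch of the numerical approach next to a hand-derived derivative. The function `g`, the step size `h`, and the evaluation point `x0` are arbitrary choices for illustration, not part of the notes.

```python
import math

def g(x):
    # Example function, chosen only for illustration.
    return x * math.sin(x)

def g_prime(x):
    # Manual differentiation: product rule applied by hand.
    return math.sin(x) + x * math.cos(x)

def central_difference(func, x, h=1e-5):
    # Numerical differentiation: approximate the slope from two nearby evaluations.
    return (func(x + h) - func(x - h)) / (2.0 * h)

x0 = 1.5
approx = central_difference(g, x0)
exact = g_prime(x0)
error = abs(approx - exact)  # small, but generally not exactly zero
```

The numerical result is only an approximation whose quality depends on `h`. Autodiff avoids this by applying exact derivative rules per operation, while still returning a number rather than an expression.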
Forward mode
- compute the partial derivative of each intermediate scalar with respect to an input alongside its value, in a single forward pass (one pass per input)
- each value is represented as a tuple of the original value (primal) and its derivative (tangent)
- JAX uses operator overloading.
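The forward-mode points above can be sketched with a small dual-number class. This is a pure-Python illustration of the (primal, tangent) pairing and of operator overloading, not JAX's actual implementation. The names `Dual`, `_lift`, `dual_sin`, and `f_fwd` are hypothetical.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class Dual:
    primal: float   # value of the original computation
    tangent: float  # derivative of that value w.r.t. the seeded input

    def __add__(self, other):
        other = _lift(other)
        # Sum rule: tangents add.
        return Dual(self.primal + other.primal, self.tangent + other.tangent)
    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: (uv)' = u'v + uv'.
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)
    __rmul__ = __mul__

def _lift(x):
    # Treat plain numbers as constants (tangent 0).
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)

def dual_sin(x):
    x = _lift(x)
    # Chain rule: d/dt sin(u) = cos(u) * u'.
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)

def f_fwd(x, y):
    return x * y + dual_sin(x)

x0, y0 = 2.0, 3.0
# Seed one input with tangent 1 per pass to get that input's partial derivative.
df_dx = f_fwd(Dual(x0, 1.0), Dual(y0, 0.0)).tangent  # y + cos(x)
df_dy = f_fwd(Dual(x0, 0.0), Dual(y0, 1.0)).tangent  # x
```

Each input needs its own forward pass, so the cost grows with the number of inputs. This is why forward mode suits functions with few inputs and many outputs.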
Reverse mode
- store values and dependencies of intermediate variables in memory
- after the forward pass, run a backward pass that computes the partial derivative of the output with respect to each intermediate variable (its adjoint)
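The two reverse-mode steps above can be sketched as a tiny graph-recording implementation. It is a simplified pure-Python illustration under the assumption of scalar values, and the names `Var`, `var_sin`, and `backward` are hypothetical.

```python
import math

class Var:
    def __init__(self, value, parents=()):
        self.value = value
        # Dependencies stored during the forward pass:
        # tuple of (parent Var, local partial derivative w.r.t. that parent).
        self.parents = parents
        self.adjoint = 0.0  # d(output)/d(this variable), filled in by backward()

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

def var_sin(x):
    return Var(math.sin(x.value), ((x, math.cos(x.value)),))

def backward(output):
    # Order the recorded graph so every variable comes after its parents.
    order, seen = [], set()
    def visit(v):
        if id(v) in seen:
            return
        seen.add(id(v))
        for parent, _ in v.parents:
            visit(parent)
        order.append(v)
    visit(output)

    output.adjoint = 1.0
    # Walk from the output back to the inputs, accumulating adjoints.
    for v in reversed(order):
        for parent, local in v.parents:
            parent.adjoint += v.adjoint * local

x, y = Var(2.0), Var(3.0)
z = x * y + var_sin(x)   # forward pass: values and dependencies are recorded
backward(z)              # backward pass: x.adjoint and y.adjoint are now set
```

A single backward pass yields the derivative with respect to every input, at the cost of keeping the intermediate values in memory.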