Automatic differentiation and XLA

Alternatives to manual differentiation

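  • numerical differentiation: approximates derivatives with finite differences; inexact because of truncation and round-off error
  • symbolic differentiation: applies differentiation rules to an expression and returns a derivative expression, which can grow large (expression swell)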
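  • automatic differentiation (autodiff)
    • like symbolic differentiation, applies the chain rule using the known derivatives of elementary operations
    • unlike it, works on the program as it executes and returns numerical derivative values at a given point instead of a derivative expression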
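To make the numerical/symbolic contrast concrete, here is a minimal Python sketch. The function `f`, the point x = 2, and the step sizes are arbitrary choices for illustration. A central finite difference approximates f'(x), and its accuracy depends on the step size h, whereas the hand-derived (symbolic) derivative is exact.

```python
import math

def central_difference(f, x, h=1e-5):
    """Approximate f'(x) with the central difference (f(x+h) - f(x-h)) / (2h)."""
    return (f(x + h) - f(x - h)) / (2 * h)

def f(x):
    return x ** 3 + math.sin(x)

def f_prime_exact(x):
    # Derivative worked out by hand: the exact, symbolic result.
    return 3 * x ** 2 + math.cos(x)

x = 2.0
for h in (1e-1, 1e-5, 1e-12):
    approx = central_difference(f, x, h)
    error = abs(approx - f_prime_exact(x))
    print(f"h={h:g}  approx={approx:.10f}  error={error:.2e}")
```

A large h gives truncation error and a very small h gives floating-point round-off error, so the step size has to be tuned. Autodiff has no step size: it applies exact derivative rules to each operation, so its result is accurate to floating-point precision.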

Forward mode

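  • evaluates the function and its derivatives together in one forward pass: each intermediate value carries its partial derivative with respect to one chosen input (more generally, a directional derivative, i.e. a Jacobian-vector product)
  • a function with n inputs therefore needs n forward passes for the full Jacobian
  • each value is represented as a (primal, tangent) pair: the primal is the original value, the tangent is its derivative
  • JAX implements it with operator overloading: traced values intercept each primitive operation and propagate tangents alongside primals; the user-facing function is jax.jvp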
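The (primal, tangent) idea can be sketched with dual numbers in plain Python. This is a toy under stated assumptions, not JAX's implementation: the `Dual` class, the `sin` helper, and the example function `f` are hypothetical, and `jvp` only borrows its name from jax.jvp. Operator overloading lets each arithmetic operation compute the primal and apply the matching derivative rule to the tangent.

```python
import math
from dataclasses import dataclass

@dataclass
class Dual:
    """A (primal, tangent) pair: a value and its derivative along one input direction."""
    primal: float
    tangent: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

    __rmul__ = __mul__

def sin(x):
    # Chain rule for sin: (sin u)' = cos(u) * u'
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)

def f(x, y):
    return x * y + sin(x)

def jvp(f, primals, tangents):
    """Evaluate f and its directional derivative along `tangents` in one forward pass."""
    args = [Dual(p, t) for p, t in zip(primals, tangents)]
    out = f(*args)
    return out.primal, out.tangent

value, df_dx = jvp(f, (2.0, 3.0), (1.0, 0.0))  # seed the tangent on x
_, df_dy = jvp(f, (2.0, 3.0), (0.0, 1.0))       # a second pass for y
print(value, df_dx, df_dy)
```

Getting both partials of this two-input function takes two forward passes, one per seeded input direction, which is the cost noted above.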

Reverse mode

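  • a forward pass evaluates the function and stores the intermediate values and their dependencies (the computation graph) in memory
  • a backward pass then propagates adjoints from the output back to the inputs; the adjoint of an intermediate variable is the partial derivative of the output with respect to that variable
  • one backward pass yields the gradient of a scalar output with respect to all inputs (a vector-Jacobian product)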
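A minimal tape-style sketch of reverse mode in plain Python follows. The `Var` class, the `sin` and `backward` helpers, and the example function are hypothetical illustrations, and this is not how JAX implements reverse mode internally (JAX exposes it as jax.vjp and jax.grad). The forward pass records each node's parents together with the local partial derivatives; the backward pass visits nodes in reverse topological order and accumulates adjoints with the chain rule.

```python
import math

class Var:
    """A graph node: a value plus links to the nodes it was computed from."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (parent Var, local partial d(self)/d(parent))
        self.adjoint = 0.0      # d(output)/d(self), filled in by backward()

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value, ((self, other.value), (other, self.value)))

def sin(x):
    return Var(math.sin(x.value), ((x, math.cos(x.value)),))

def backward(output):
    """Propagate adjoints from the output back to every intermediate and input."""
    # Post-order DFS: each node is appended after all the nodes it depends on.
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(output)
    output.adjoint = 1.0
    # Reversed order processes every consumer of a node before the node itself.
    for node in reversed(order):
        for parent, local_grad in node.parents:
            # Chain rule: sum the contributions of every path through `node`.
            parent.adjoint += node.adjoint * local_grad

x, y = Var(2.0), Var(3.0)
z = x * y + sin(x)  # forward pass records the graph
backward(z)         # one backward pass fills in both partials
print(z.value, x.adjoint, y.adjoint)
```

Compared with the forward-mode sketch, the same gradient comes from a single backward pass regardless of the number of inputs, at the cost of keeping the recorded graph in memory.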