JAX

JAX has four main function transformations:

  • grad() to automatically differentiate a function

  • vmap() to automatically vectorize operations

  • pmap() to run SPMD (single program, multiple data) programs in parallel across devices

  • jit() to transform a function into a JIT-compiled version
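A minimal sketch of three of these transformations applied to one function. The function f and the sample inputs are hypothetical and chosen only for illustration. pmap() is left out because it is only meaningful with several devices available.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical scalar function for illustration: f(x) = x^2 + 3x
    return x ** 2 + 3.0 * x


# grad() returns a new function that computes df/dx (here 2x + 3).
df = jax.grad(f)

# vmap() returns a new function that maps f over the leading axis of its input.
batched_f = jax.vmap(f)

# jit() returns a version of f that is compiled on first call.
fast_f = jax.jit(f)

x = jnp.array(2.0)
xs = jnp.arange(4.0)  # [0., 1., 2., 3.]

grad_at_2 = df(x)          # derivative at x = 2
batch_out = batched_f(xs)  # f applied to each element of xs
jit_out = fast_f(x)        # same value as f(x), computed by compiled code
```

Each transformation takes a function and returns a new function, so f itself is left unchanged.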

These transformations are (mostly) composable and very powerful, and they can potentially speed up your programs several times over.
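As a rough illustration of composability, the transformations can be nested. The sketch below computes per-example gradients by wrapping grad() in vmap() and then jit(). The loss function, weights, and inputs are hypothetical, not taken from the source.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical loss for illustration: squared dot product of w and x.
    return jnp.dot(w, x) ** 2


# grad(loss) differentiates with respect to the first argument, w.
# vmap(..., in_axes=(None, 0)) keeps w fixed and maps over rows of xs.
# jit(...) compiles the resulting batched gradient function.
per_example_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0],
                [1.0, 1.0]])

# One gradient row per example; analytically each row is 2 * (w . x) * x.
grads = per_example_grad(w, xs)
```

The order of nesting matters: here the gradient is taken per example first, and the batching and compilation are layered on top.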

https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/