JAX
JAX has four main function transformations:
- grad() to automatically differentiate a function
- vmap() to automatically vectorize operations
- pmap() for parallel, single-program multiple-data (SPMD) computation across multiple devices
- jit() to transform a function into a just-in-time (JIT) compiled version
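Below is a minimal sketch of each of the four transformations applied to one toy function. The function `f`, the inputs, and the variable names are illustrative assumptions, not taken from the source. The pmap call uses one shard per locally available device, so it should also run on a machine with a single CPU device.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy function: f(x) = x**2 + 3x
def f(x):
    return x ** 2 + 3.0 * x

# grad(): returns a new function that computes df/dx = 2x + 3
df = jax.grad(f)
print(df(2.0))  # 7.0

# vmap(): maps f over the leading axis of a batch without a Python loop
xs = jnp.array([0.0, 1.0, 2.0])
batched = jax.vmap(f)(xs)
print(batched)  # [ 0.  4. 10.]

# jit(): compiles f with XLA; the first call traces and compiles,
# later calls with the same input shapes/dtypes reuse the compiled code
f_jit = jax.jit(f)
print(f_jit(2.0))  # 10.0

# pmap(): runs the same program on every device (SPMD); the leading
# axis of the input must not exceed the number of available devices
n = jax.local_device_count()
shards = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
parallel = jax.pmap(f)(shards)
print(parallel.shape)  # (n, 3): one row of results per device
```

Each transformation takes a function and returns a new function, which is what makes them stackable.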
These transformations are (mostly) composable, very powerful, and can potentially speed up your programs several times over.
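To illustrate composability, here is a minimal sketch, assuming a hypothetical per-example loss `loss(w, x) = (w * x - 1)**2` with a scalar weight `w`. grad computes the gradient for one example, vmap batches it over inputs while sharing `w`, and jit compiles the composed result.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss for a scalar weight w and scalar input x
def loss(w, x):
    return (w * x - 1.0) ** 2

# Gradient with respect to w for a single example: 2 * (w*x - 1) * x
grad_w = jax.grad(loss, argnums=0)

# vmap over x only (in_axes=(None, 0)): the same w, a batch of inputs;
# jit then compiles the whole composed function
per_example_grads = jax.jit(jax.vmap(grad_w, in_axes=(None, 0)))

xs = jnp.array([1.0, 2.0, 3.0])
g = per_example_grads(0.5, xs)
print(g)  # expected values: -1.0, 0.0, 3.0
```

The order matters: grad is applied to the single-example function first, so vmap yields one gradient per input rather than the gradient of a summed loss.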
https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/