JAX
Hey, this won’t be a tutorial on JAX, how to use it, or anything like that. It’s more about understanding the importance of the paradigm shift, and why it exists!
Strange useful terms
Here are some useful terms that you’ll find in the following post:
JIT
It stands for Just-In-Time and refers to the compilation process. It is also called dynamic translation or run-time compilation, and as the name suggests, it allows compilation during the execution of a program (i.e. at run-time) rather than before the execution.
Concretely, it’s a compiler; but it compiles only the pieces of the program that are needed, right before they are executed.
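To make this concrete, here’s a minimal sketch (any recent JAX version with jax.jit should do): the first call traces and compiles the function, later calls just reuse the compiled code.

import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) * jnp.cos(x) + x ** 2

x = jnp.arange(1_000_000, dtype=jnp.float32)

t0 = time.perf_counter()
f(x).block_until_ready()   # first call: trace + compile + run
print(f"first call:  {time.perf_counter() - t0:.4f}s")

t0 = time.perf_counter()
f(x).block_until_ready()   # second call: just run the compiled code
print(f"second call: {time.perf_counter() - t0:.4f}s")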
OpenXLA
OpenXLA is a project that develops and governs XLA (Accelerated Linear Algebra). It is an open-source compiler designed to improve the performance of machine learning models from popular frameworks like TensorFlow, PyTorch and guess who…. JAX!
The OpenXLA framework.
It improves and optimizes the low-level computation graph across multiple hardware platforms: GPUs, CPUs and ML accelerators. The best part? The major industry-leading ML hardware and software companies build it collaboratively!
Altogether, we get a framework that can be built and run anywhere, since it is integrated with the leading ML frameworks and supports various backends.
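You can actually peek at what JAX hands over to XLA. A small sketch (assuming a reasonably recent JAX, where jitted functions expose .lower()):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.dot(x, x) + 1.0

x = jnp.ones((4, 4))
# Print the intermediate representation that gets passed to the XLA compiler.
print(jax.jit(f).lower(x).as_text())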
Functional programming
I’ve found a super cool example on the web that clearly shows the paradigm shift from imperative programming (PyTorch) to functional programming (JAX):
def factorialImperative(n):
    acc = 1
    for i in range(1, n + 1):
        acc = acc * i
    return acc

def factorialFunctional(n):
    if n <= 1:  # also covers n == 0, like the imperative version
        return 1
    else:
        return n * factorialFunctional(n - 1)
There’s not much more to say about that: in the former paradigm you leverage objects and variables, while in the latter you focus on pure functions (i.e. no side effects), which leads to an easier way to reason about the code.
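JAX takes this seriously: its arrays are immutable, so an “update” is itself a pure function that returns a new array. A tiny sketch:

import jax.numpy as jnp

x = jnp.zeros(3)
# x[0] = 1.0  # would raise an error: JAX arrays are immutable
y = x.at[0].set(1.0)   # pure update: returns a NEW array, x is untouched
print(x)  # [0. 0. 0.]
print(y)  # [1. 0. 0.]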
SPMD
Also known as Single-Program Multiple-Data. It is a parallel programming model (I’m deep-diving into it, so you can probably expect a post about it) where a single program is executed by multiple processors independently, and each of these operates on a different part of the data.
It turns out that each processor may follow a different execution path through the shared code depending on the data; however, the code itself is the same.
Why does this matter? This model enables large-scale data processing by distributing the workloads, without the need to write a separate program for each processor. It is commonly used with libraries like MPI or OpenMP.
Just to give you a flavor, using pseudocode:
int id = get_process_id();               // PID
int n = total_elements / num_processes;  // size of the data block per process
for (int i = id * n; i < (id + 1) * n; i++) {
    C[i] = A[i] + B[i];
}
Here each process runs the very same code but works on different indices, determined by its id.
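For the record, this is essentially what jax.pmap (more on it below) gives you: the same function replicated across devices, each working on its own slice of the data. A minimal sketch (it runs even on a single-device machine, where local_device_count() is 1):

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
# Shard the data: one block per device along the leading axis.
A = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)
B = jnp.ones((n_dev, 4))

# The very same program runs on every device, each on its own block.
C = jax.pmap(lambda a, b: a + b)(A, B)
print(C)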
Cool, isn’t it?
JAX vs PyTorch
I’ve been a PyTorch user since (almost) forever. I’m deeply nested within its imperative programming style. When building AI, I always choose PyTorch Lightning, for its simplicity and scaling capabilities.
I mean, can you imagine having to write .to(device) every single time? How tedious!
JK (not really), kudos for PyTorch devs!
However, all this is to say that when I was faced with JAX and this crappy thing called functional programming, it opened up a whole new world. It is completely different from the usual definition of models and classes, and from the eager (i.e. not lazy) execution of operations.
But why do we need JAX, when we already have PyTorch? Is it just another framework that does the very same things as PyTorch but in a different way? Well, not exactly; there’s a bunch of truth in saying so, but it’s not the whole story.
PyTorch is more user-friendly; the imperative paradigm in general is more oriented toward making everything clear and explicitly defined, which directly translates into ease of use and attracts a bigger basin of users.
So, why does JAX exist?
Why should someone who (for example) works in PyTorch care about JAX? PyTorch works so well and is really intuitive, so why should anyone make their life harder with a functional framework like JAX?
It pushes computation closer to the compiler, with benefits in parallelization (SPMD was just a flavor) and hardware abstraction that are nearly impossible to obtain using imperative frameworks like PyTorch. JAX is built for performance!
That’s where the magic of things like jit, vmap, and pmap comes in. Instead of you optimizing the code by hand, JAX lets the compiler do the hard work.
Here’s an example to show this idea:
import jax
import jax.numpy as jnp

@jax.jit
def relu(x):
    return jnp.maximum(x, 0)

print(relu(jnp.ones((3, 3)) * -1))
So, what’s the point? Here JAX compiles this function once, and then runs it on whatever device you are planning to use, CPU, GPU or TPU, with (almost) no code changes! That’s a complete paradigm shift.
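No .to(device) anywhere: JAX dispatches to whatever backend is available, and you can check (or override) the placement explicitly if you really want. A quick sketch:

import jax
import jax.numpy as jnp

print(jax.devices())  # e.g. [CpuDevice(id=0)], or GPU/TPU devices if present

x = jnp.ones((3, 3)) * -1
x = jax.device_put(x, jax.devices()[0])  # explicit placement, rarely needed
print(jnp.maximum(x, 0))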
What makes it different
JAX takes a look at your code and applies:
- jit for compilation,
- grad for automatic differentiation,
- vmap for vectorization,
- pmap for distributed computation,
all without changing your core logic. You are not telling your machine what to do step-by-step; you define what needs to be computed, and JAX decides how to run it efficiently.
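And these transformations compose. A minimal sketch (the loss function here is just a made-up toy example): take the gradient of a per-example loss, vectorize it over a batch, and compile the whole thing.

import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((x @ w) ** 2)                # a toy per-example loss

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)             # a batch of 4 inputs

grad_fn = jax.grad(loss)                        # d(loss)/d(w), single example
batched = jax.vmap(grad_fn, in_axes=(None, 0))  # vectorize over the batch
fast = jax.jit(batched)                         # compile the composition
print(fast(w, xs).shape)                        # (4, 3): one gradient per example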
Should you learn JAX?
Do you:
- make large-scale experiments?
- perform heavy optimization?
- want to understand modern compiler-based systems?
If so, you should absolutely learn it!
And that’s the goal of this post: not to teach you JAX or make JAX feel easy, but to make its design decisions make sense.