In the burgeoning field of machine learning and deep learning, choosing the right framework is akin to selecting the right tool for a craftsman. It defines ease of use, efficiency, and the boundaries of what can be achieved.

Google’s JAX is an emerging library that promises to redefine how we approach machine learning programming with its unique features and capabilities.

For those who have closely followed the evolution of ML libraries — from TensorFlow to PyTorch and now JAX — this shift is akin to the arrival of a new ace in a deck of cards.

Today, we’re diving deep into what makes JAX stand out and how it can be leveraged for sophisticated ML tasks.

Installation and setup

Before we venture into the rich functionalities of JAX, it’s important to set up your environment. Getting JAX up and running is straightforward with pip:


pip install --upgrade jax jaxlib

For those looking to harness GPU capabilities, ensure you have CUDA and cuDNN installed, then install the jaxlib wheel that matches your CUDA version (adjust the version numbers in the command below to your setup):


pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

Now, let’s see JAX in action alongside NumPy for some initial comparisons:


import numpy as np
import jax.numpy as jnp
import timeit

Exploring the core of JAX

JAX enhances numerical computations with an API that closely mirrors NumPy, making it instantly familiar. Yet it adds a layer of hardware acceleration through its DeviceArray type (unified into jax.Array in recent releases), enabling computations on GPUs and TPUs without altering your high-level code.

Consider the simple operation of creating a large random array with NumPy and converting it into a JAX array:


x = np.random.rand(7000, 7000)
y = jnp.array(x)  # Creating a JAX DeviceArray

While the syntax is nearly identical, the two arrays are handled very differently at the computational level: y is managed by JAX's backend and can be placed on a GPU or TPU without any change to the surrounding code.
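
If you want to verify which backend JAX has actually picked up (CPU, GPU, or TPU), one quick check is to list the available devices; on a CPU-only machine this prints a single CPU device entry. The snippet assumes nothing beyond the jax package installed earlier.


import jax

# List the devices JAX can place arrays and computations on
print(jax.devices())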

Speed, JIT, and automatic differentiation

Performance is a domain where JAX shines, often outpacing traditional CPU-bound NumPy operations by a wide margin. An illustrative example, timing the dot product of a large matrix, showcases JAX’s prowess:


# Defining the functions for timing
def numpy_dot():
    np.dot(x, x)

def jax_dot():
    jnp.dot(y, y).block_until_ready()

# Timing the NumPy operation
numpy_time = timeit.timeit(numpy_dot, number=1)

# Timing the JAX operation
jax_time = timeit.timeit(jax_dot, number=1)

print(f"NumPy dot product time: {numpy_time} seconds")
print(f"JAX dot product time: {jax_time} seconds")

The difference in execution time is stark, demonstrating JAX’s ability to make full use of hardware acceleration.

On my M1 Mac, the results look like this:

  • NumPy dot product time: 3.64 seconds
  • JAX dot product time: 1.72 seconds
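
Single-run timings like these are noisy, and the very first JAX dispatch carries some warm-up cost, so treat them as a rough indication rather than a benchmark. A steadier (and still informal) variant reuses the same functions and averages over a few runs:


# Average over several runs for a steadier comparison
runs = 3
numpy_avg = timeit.timeit(numpy_dot, number=runs) / runs
jax_avg = timeit.timeit(jax_dot, number=runs) / runs

print(f"NumPy average over {runs} runs: {numpy_avg:.3f} s")
print(f"JAX average over {runs} runs: {jax_avg:.3f} s")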

But JAX is more than a one-trick pony focused on array computations; it revolutionizes other aspects of machine learning development, most notably automatic differentiation and just-in-time (JIT) compilation.

Automatic differentiation is a cornerstone of deep learning, allowing for seamless backpropagation. With JAX’s grad function, obtaining derivatives of complex functions becomes as straightforward as:


from jax import grad

def f(x):
  return 3 * x ** 2 + 2 * x + 5

grad(f)(1.0)  # Derivative 6x + 2 evaluated at x = 1.0, returning 8.0
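
Because grad returns an ordinary Python function, it composes with itself and with other transforms. As a small sketch using the same f as above, grad(grad(f)) gives the second derivative, and jax.value_and_grad returns both the value and the gradient in one call:


from jax import grad, value_and_grad

d2f = grad(grad(f))            # second derivative of 3x^2 + 2x + 5 is the constant 6
print(d2f(1.0))                # prints 6.0
print(value_and_grad(f)(1.0))  # prints the value f(1.0) = 10.0 and the derivative 8.0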

JIT compilation via the jit decorator (or function) accelerates things further by tracing a Python function and compiling it with XLA into highly optimized machine code:


from jax import jit

@jit
def myFunction(x):
  # Example body: a chain of elementwise operations that XLA can fuse into one kernel
  return jnp.tanh(x) * jnp.exp(-x ** 2) + 3 * x
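
To get a rough, hardware-dependent sense of what jit buys you, call the function once so tracing and compilation happen, then time the compiled version. The array xs below is just an illustrative input.


xs = jnp.linspace(-3.0, 3.0, 1_000_000)

myFunction(xs).block_until_ready()  # first call traces and compiles the function

jit_time = timeit.timeit(lambda: myFunction(xs).block_until_ready(), number=10)
print(f"10 calls to the jit-compiled function: {jit_time:.4f} seconds")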

Parallel computation and vectorization

For tasks that demand distribution across devices or vectorization over batches, JAX provides pmap (parallel map) and vmap (vectorizing map), enabling effortless scaling of computations.

Here’s a glance at pmap in action:


from jax import pmap

@pmap
def myParallelFunction(x):
  # This body runs on every available device, each operating on its own slice of x
  return x ** 2
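
A usage sketch, with the caveat that pmap expects the leading axis of its input to match the number of participating devices (on a CPU-only machine that is typically a single device):


import jax

n_devices = jax.local_device_count()
batch = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)  # one row per device
print(myParallelFunction(batch))  # each device squares its own row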

And vmap adds a layer of batch processing capability to any function designed for single inputs:


from jax import vmap

@vmap
def myVectorizedFunction(x):
  # Written for a single vector; vmap lifts it to operate over a batch dimension
  return jnp.dot(x, x)
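
As a small usage sketch (the batch array here is just an illustration), the vmapped function now accepts a stack of vectors and returns one squared norm per row:


batch = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 vectors of length 3
print(myVectorizedFunction(batch))      # shape (4,): one result per vector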

Conclusion

JAX represents a significant leap forward in the machine learning and deep learning toolkit, merging the intuitive nature of NumPy with the raw power of GPUs and TPUs. Whether it’s the ease of performing complex mathematical operations, the efficiency of JIT compilation, or the sophistication of automatic differentiation, JAX has set a new benchmark for what’s possible.

As the community around JAX grows and its ecosystem expands, adopting JAX not only future-proofs your projects but also opens doors to computational possibilities previously considered challenging or outright infeasible.

For those eager to dive deeper, exploring the rich set of libraries built on JAX, like Flax for neural networks or Optax for optimization, will further unlock the potential of this powerful library. With JAX, the future of machine learning programming looks brighter, more structured, and infinitely more exciting.

Last Update: 09/03/2024