Wednesday, January 8, 2025

JAX Fundamentals with NNX Flax: A PyTorch Dev’s Guide

NNX Flax

NNX Flax (Flax NNX) is a newer API for Flax, a powerful library for building neural networks in JAX, created to make Flax easier to use. It offers tools designed to simplify model creation, experimentation, and training in machine learning workflows. In keeping with JAX’s functional programming model, NNX offers utilities that enhance Flax’s modularity.

The PyTorch developer’s guide to JAX fundamentals

Like many PyTorch users, you may have heard wonderful things about JAX, including its robust built-in support for parallel processing, its excellent performance, and the elegance of its functional programming approach. But you might have also had trouble finding the resources you need to get started. This simple, easy-to-follow tutorial will help you grasp the fundamentals of JAX by relating its novel ideas to the PyTorch building blocks you already know.

This lesson examines the fundamentals of the JAX ecosystem from a PyTorch user’s perspective by training a rudimentary neural network in both frameworks on the well-known machine learning (ML) challenge of predicting which passengers survived the Titanic disaster. Along the way, it introduces JAX by showing how many things, from model definition and instantiation to training, map to their PyTorch equivalents.

Modularity with JAX

If you’re a PyTorch user, JAX’s highly modular ecosystem may seem quite odd at first. JAX’s main goal is to be a high-performance library for numerical computation with support for automatic differentiation. Unlike PyTorch, it does not try to provide explicit built-in support for defining neural networks, optimisers, and so on. Instead, JAX is designed to be flexible, so you can extend its capabilities with the frameworks of your choice.

This tutorial uses two widely used and well-supported libraries: the Flax neural network library and the Optax optimisation library. For a truly PyTorch-like experience, we first demonstrate how to train a neural network using the new NNX Flax API. We then show how to accomplish the same task using the older but still popular Linen API.

Functional programming

Before starting the course, let’s discuss why JAX uses functional programming rather than the object-oriented style of PyTorch and other frameworks. In a nutshell, functional programming centres on pure functions, which always return the same output for the same input because they cannot change state or produce side effects. In JAX, this shows up as extensive use of immutable arrays and composable functions.
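To make the idea of immutability concrete, here is a small sketch using jax.numpy (the array values and the scaled_sum helper are illustrative only). It shows that JAX arrays reject in-place assignment and that updates go through the functional .at[...].set(...) syntax, which returns a new array.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# In-place assignment, as you would write in PyTorch or NumPy, is rejected.
try:
    x[0] = 10.0
except TypeError as err:
    print("In-place update rejected:", type(err).__name__)

# The functional alternative returns a *new* array and leaves x untouched.
y = x.at[0].set(10.0)
print(x, y)

# A pure function: the result depends only on the inputs, with no side effects,
# so it composes freely with JAX transformations such as jax.jit or jax.grad.
def scaled_sum(a, scale):
    return jnp.sum(a * scale)

total = scaled_sum(x, 2.0)
```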

The predictability of pure functions and functional programming enables many of JAX’s advantages, including Just-In-Time (JIT) compilation, where the XLA compiler can heavily optimise code on GPUs or TPUs for considerable speedups. Pure functions also make sharding and parallelisation in JAX much simpler.
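A minimal sketch of JIT compilation with jax.jit follows; the selu function and its constants are just an illustrative pure function, not part of the tutorial’s model. Because the function is pure, JAX can trace it once for a given input shape and dtype and hand the traced computation to XLA.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lam=1.05):
    # Pure function: the output depends only on the inputs.
    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.jit traces selu once per input shape/dtype and compiles it with XLA.
selu_jit = jax.jit(selu)

x = jnp.arange(5.0)
# The first call compiles; later calls with the same shape reuse the compiled code.
# block_until_ready() waits for the asynchronously dispatched result.
out = selu_jit(x).block_until_ready()
```

This is also why purity matters: a Python side effect such as print inside a jitted function runs only while JAX traces it, not on every call.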

NNX Flax conceals a lot of functional programming beneath common Pythonic conventions, so don’t be discouraged if you’re new to it.

Data loading

Loading data in JAX is straightforward: simply keep your PyTorch workflow. All JAX processing operates on NumPy-like arrays, so you can use a PyTorch Dataset/DataLoader together with a simple collate_fn that converts each batch to NumPy arrays.
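The tutorial’s own data pipeline is not reproduced here, so the following is a minimal sketch of such a collate function, assuming each dataset item is a (features, label) pair of NumPy-compatible values. The numpy_collate name and the three toy feature columns are hypothetical.

```python
import numpy as np

def numpy_collate(batch):
    """Stack a list of (features, label) samples into NumPy arrays.

    Hypothetical helper: intended to be passed as collate_fn to a PyTorch
    DataLoader so batches come out as NumPy arrays, which JAX accepts directly.
    """
    features, labels = zip(*batch)
    return np.stack(features), np.asarray(labels)

# Hypothetical toy samples standing in for preprocessed Titanic rows.
samples = [
    (np.array([3.0, 22.0, 1.0], dtype=np.float32), 0),
    (np.array([1.0, 38.0, 0.0], dtype=np.float32), 1),
]
x_batch, y_batch = numpy_collate(samples)
print(x_batch.shape, y_batch.shape)
```

With a real dataset, you would pass it as `torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=numpy_collate)`; the batch size here is an arbitrary example.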

Model definition

The NNX API in Flax closely resembles PyTorch when it comes to defining neural networks. Here, we start with PyTorch and define a basic two-layer multilayer perceptron (MLP) in both frameworks.

The PyTorch and NNX model definitions are fairly comparable. Both use __init__ to define the model’s layers, while __call__ in NNX plays the role of PyTorch’s forward.

Model initialization and usage

In NNX, model initialisation is almost the same as in PyTorch. In both frameworks, when you instantiate the model class, the model parameters are eagerly (as opposed to lazily) initialised and bound to the instance itself. The only difference with NNX is that you must pass in a pseudorandom number generator (PRNG) key when instantiating the model. Consistent with its functional nature, JAX avoids implicit global random state and requires you to pass PRNG keys explicitly. This makes random number generation easy to vectorise, parallelise, and reproduce.

In practice, using the models to process a batch of data is identical in the two frameworks.

Training step and backpropagation

PyTorch and NNX Flax training loops differ in a few significant ways. Let’s gradually build up to the complete NNX training loop to illustrate.

In both frameworks, you declare an optimisation algorithm and create an optimiser. Unlike PyTorch, which requires you to pass in the model parameters, NNX Flax lets you pass in the model directly and manages all interactions with the underlying Optax optimiser.

Forward + backward pass

How a complete forward/backward pass is executed is arguably the most significant difference between PyTorch and JAX. In PyTorch, you compute the gradients by calling loss.backward(), which causes autograd to compute them by following the computation graph back from loss.

Instead, JAX’s automatic differentiation is much closer to the raw math of taking gradients of functions. In particular, jax.grad/nnx.grad and nnx.value_and_grad take in a function, here loss_fn, and return a new function, grad_fn. Calling grad_fn then returns the gradient of loss_fn’s output with respect to its input (by default, its first argument); with nnx.value_and_grad, grad_fn also returns the loss value itself.
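The “gradient of a function” idea is easiest to see on a plain scalar function. In this sketch, f is an arbitrary example, and jax.grad returns a new function that evaluates its derivative.

```python
import jax

# An ordinary Python function: f(x) = x**3 + 2x, so df/dx = 3x**2 + 2.
def f(x):
    return x ** 3 + 2.0 * x

# jax.grad takes a function and returns a *new function* computing df/dx.
df = jax.grad(f)
slope = df(2.0)  # evaluates 3 * 2**2 + 2

# jax.value_and_grad returns both f(x) and df/dx from a single call.
value, grad = jax.value_and_grad(f)(2.0)
```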

In this example, loss_fn does exactly the same thing as in PyTorch: it first obtains the logits from the forward pass and then computes the familiar loss. grad_fn then computes the gradient of the loss with respect to the model’s parameters. In mathematical terms, the returned grads are the partial derivatives ∂L/∂θ of the loss L with respect to the parameters θ. This is exactly what PyTorch does under the hood; the difference is that PyTorch “stores” the gradients in each tensor’s .grad attribute when you call loss.backward(), whereas JAX and NNX Flax follow the functional approach of simply returning the gradients to you without mutating any state.

Optimiser step

In PyTorch, the optimiser step uses the stored gradients to update the weights in place. NNX also updates the weights in place, but it requires the grads computed in the backward pass to be passed in explicitly. Apart from being slightly more explicit, in keeping with JAX’s functional nature, this optimisation step is identical to the one in PyTorch.

Full training loop

Everything you need to build a complete training loop in JAX/NNX Flax is now in your possession, and it follows the same overall structure as the familiar PyTorch loop.

The main takeaway is that PyTorch and JAX/NNX Flax training loops are fairly similar, with most of the differences stemming from object-oriented versus functional programming. Functional programming and thinking in terms of function gradients have a small learning curve, but they enable many of the JAX advantages described earlier, such as JIT compilation and automatic parallelisation. For instance, with nnx.jit annotations added to the functions, training the model for 500 epochs on Kaggle takes only 1.8 minutes on a P100 GPU. The same code will produce comparable speedups on CPUs, TPUs, and even non-NVIDIA GPUs.

Flax Linen reference

As stated before, the JAX ecosystem is highly flexible and lets you use any framework you like. Although the Flax Linen API is still widely used today, especially in robust frameworks such as MaxText and MaxDiffusion, NNX is the recommended option for new users. Linen follows pure functional programming much more closely than NNX, which is considerably more Pythonic and hides much of the complexity of state management.

If you want to get involved in the JAX ecosystem, it is very helpful to be comfortable with both. To help, let’s use Linen to replicate much of the NNX code, with comments outlining the key differences.

Drakshi
Since June 2023, Drakshi has been writing articles on artificial intelligence for govindhtech. She was a postgraduate in business administration and an artificial intelligence enthusiast.