Friday, March 28, 2025

JAX: How to Use High-Performance Computing For ML

A Brief Introduction to JAX: Machine Learning with High-Performance Computing.

AI frameworks have revolutionized the development of deep learning and machine learning applications, allowing for quicker training and deployment. Because of their large feature sets, robust ecosystems, and supportive communities, PyTorch and TensorFlow have emerged as the most popular frameworks. Although these frameworks dominate the AI field, the Just After eXecution (JAX) framework was created specifically for array-oriented numerical computation and offers an interface akin to NumPy.

The framework accelerates deep learning computations on CPUs, GPUs, and TPUs, which makes it well suited to high-performance computing and machine learning research. By helping researchers tackle cutting-edge AI problems, JAX complements PyTorch and TensorFlow.

JAX Features

The following features are offered by JAX:

  • Unified NumPy-like interface: the same array code runs on a CPU, GPU, or TPU.
  • Integrated just-in-time (JIT) compilation: jit uses OpenXLA, an open-source machine learning compiler ecosystem, to optimize computations for hardware accelerators, which shortens deep learning model training times. jit can be used as a higher-order function or as the @jit decorator.
  • Automatic differentiation: many deep learning algorithms depend on computing gradients efficiently. grad, the most widely used of these transformations, computes reverse-mode gradients.
  • Automatic vectorization: vmap (vectorizing map) efficiently maps JAX functions over arrays that represent batches of inputs. It has the familiar semantics of mapping a function along an array axis, but instead of keeping the loop on the outside, it pushes the loop down into the function’s primitive operations for better efficiency.
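A minimal sketch of the three transformations working together. The model and loss (linear_predict, squared_error) and all the numbers are hypothetical illustrations, not part of the original sample: grad produces per-example gradients, vmap maps them over a batch axis, and jit compiles the result, once as a higher-order function and once as a decorator.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap


def linear_predict(w, b, x):
    # Hypothetical tiny linear model: a scalar prediction for one input vector x.
    return jnp.dot(x, w) + b


def squared_error(w, b, x, y):
    # Scalar loss for a single (x, y) example.
    return (linear_predict(w, b, x) - y) ** 2


# grad: reverse-mode gradient; argnums=(0, 1) differentiates w.r.t. w and b.
per_example_grads = grad(squared_error, argnums=(0, 1))

# vmap: map over the leading axis of x and y; in_axes=None shares w and b.
batched_grads = vmap(per_example_grads, in_axes=(None, None, 0, 0))

# jit as a higher-order function...
fast_batched_grads = jit(batched_grads)


# ...or as a decorator.
@jit
def mean_loss(w, b, xs, ys):
    per_example = vmap(squared_error, in_axes=(None, None, 0, 0))(w, b, xs, ys)
    return jnp.mean(per_example)


w = jnp.array([1.0, -2.0, 0.5])
b = 0.1
xs = jnp.ones((4, 3))  # batch of 4 inputs with 3 features each
ys = jnp.zeros(4)

dw, db = fast_batched_grads(w, b, xs, ys)
print(dw.shape, db.shape)  # (4, 3) (4,)
print(mean_loss(w, b, xs, ys))
```

Note that the per-example function is written as if there were no batch at all; vmap adds the batch dimension, so the same squared_error serves both the gradient and the loss.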

JAX’s NumPy-like interface makes it particularly well suited to scientific computing and optimization tasks compared with frameworks such as PyTorch and TensorFlow, and its JIT integration requires very little extra code from the user.
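To illustrate the NumPy-like interface, the sketch below writes the same function with numpy and with jax.numpy, then compiles the JAX version with a single jax.jit call. The softplus function is a hypothetical example chosen for illustration. The sketch also shows one difference worth knowing: JAX arrays are immutable, so element updates use the .at[...] syntax.

```python
import numpy as np
import jax
import jax.numpy as jnp


def softplus_np(x):
    # Numerically stable log(1 + exp(x)) in ordinary NumPy.
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)


def softplus_jnp(x):
    # The same code with `np` swapped for `jnp`.
    return jnp.log1p(jnp.exp(-jnp.abs(x))) + jnp.maximum(x, 0.0)


# Compiling with XLA is a one-line change.
softplus_fast = jax.jit(softplus_jnp)

x = np.linspace(-5.0, 5.0, 11, dtype=np.float32)
print(np.allclose(softplus_np(x), softplus_fast(x), atol=1e-6))  # True

# JAX arrays are immutable: .at[...].set(...) returns a new array
# instead of modifying y in place.
y = jnp.zeros(3)
y2 = y.at[1].set(5.0)
print(y, y2)
```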

Code Sample

The code sample shows how to use JAX for parallel computation across multiple CPU cores to build a basic neural network trained on the MNIST dataset. It uses JAX’s pmap to express and run single-program, multiple-data (SPMD) programs for data parallelism along a batch dimension, while keeping dependencies minimal by writing the layers and the optimizer by hand. The code implements the following steps.
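Before the full sample, it may help to see pmap and lax.psum in isolation. This sketch assumes a single machine on which several CPU devices are simulated with the XLA flag --xla_force_host_platform_device_count, which must be set before JAX is imported (the value 4 is arbitrary). The global_mean function is a hypothetical example: each device sums its own shard, and psum all-reduces the totals across the 'batch' axis so every device holds the same result.

```python
import os

# Assumption: simulate several CPU devices on one host. This only takes
# effect if it runs before JAX is imported; 4 is an arbitrary example value.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

from functools import partial

import jax
import jax.numpy as jnp
from jax import lax, pmap

num_devices = jax.local_device_count()


@partial(pmap, axis_name="batch")
def global_mean(x):
    # Runs once per device on that device's shard (here, a vector of 8 values).
    local_sum = jnp.sum(x)
    # psum all-reduces across the 'batch' axis; psum(1, ...) is the axis size.
    total = lax.psum(local_sum, "batch")
    count = x.shape[0] * lax.psum(1, "batch")
    return total / count


# Leading axis = number of devices; second axis = examples per device.
data = jnp.arange(num_devices * 8, dtype=jnp.float32).reshape(num_devices, 8)
print(num_devices, global_mean(data))
```

If only one device is available (for example, because JAX was imported before the flag was set), the same code still runs with num_devices equal to 1.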

  • Import the required libraries and packages.
from functools import partial
import time
 
import numpy as np
import numpy.random as npr
 
import jax
from jax import jit, grad, pmap
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map
from jax import lax
import jax.numpy as jnp
from examples import datasets  # MNIST loader from the examples/ directory of the JAX repository
  • Define init_random_params, which initializes the weights and biases of every layer in the neural network.
def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
  # One (weights, biases) pair per layer; weights have shape (fan_in, fan_out).
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
  • Define predict, which computes the network’s forward pass from the weights, biases, and inputs: hidden layers apply tanh activations, and the output logits are normalized into log-probabilities.
def predict(params, inputs):
  activations = inputs
  for w, b in params[:-1]:
    outputs = jnp.dot(activations, w) + b
    activations = jnp.tanh(outputs)
 
  final_w, final_b = params[-1]
  logits = jnp.dot(activations, final_w) + final_b
  return logits - logsumexp(logits, axis=1, keepdims=True)
  • Define loss, which computes the cross-entropy loss between the predictions and the one-hot target labels.
def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))
  • Define accuracy, which predicts the class of each input in the batch and compares it with the true target class. After finding the predicted class with jnp.argmax, it returns the mean of the correct predictions.
def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)
  • Define data_stream, a generator that yields batches of shuffled training data. Because each batch is split across devices for parallel processing, it reshapes the batch’s leading dimension into (number of devices, examples per device) and raises an error if the batch size is not divisible by the device count.
def data_stream():
  # Relies on num_train, num_batches, batch_size, num_devices, train_images,
  # and train_labels being defined by the surrounding training script.
  rng = npr.RandomState(0)
  while True:
    perm = rng.permutation(num_train)
    for i in range(num_batches):
      batch_idx = perm[i * batch_size:(i + 1) * batch_size]
      images, labels = train_images[batch_idx], train_labels[batch_idx]
      # For this SPMD example, we reshape the data batch dimension into two
      # batch dimensions, one of which is mapped over parallel devices.
      batch_size_per_device, ragged = divmod(images.shape[0], num_devices)
      if ragged:
        msg = "batch size must be divisible by device count, got {} and {}."
        raise ValueError(msg.format(images.shape[0], num_devices))
      shape_prefix = (num_devices, batch_size_per_device)
      images = images.reshape(shape_prefix + images.shape[1:])
      labels = labels.reshape(shape_prefix + labels.shape[1:])
      yield images, labels
  • Define spmd_update, which uses JAX’s pmap and lax.psum to carry out parallel gradient updates across several devices.
@partial(pmap, axis_name='batch')
def spmd_update(params, batch):
  # Each device computes gradients on its own shard of the batch.
  grads = grad(loss)(params, batch)
  # `lax.psum` SPMD primitive - does a fast all-reduce-sum.
  grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch')) for dw, db in grads]
  # Plain SGD step; step_size is defined by the surrounding training script.
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]
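Because lax.psum adds rather than averages the per-device gradients, and each device’s loss is a mean over its own shard, spmd_update takes a step that is num_devices times larger than a single-device step on the full batch (assuming equal-sized shards). The sketch below checks this with a hypothetical least-squares loss rather than the MNIST model: it shards a batch the way data_stream does, replicates the parameters, and compares the all-reduced gradient with the full-batch gradient.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, lax, pmap

num_devices = jax.local_device_count()


def loss(w, batch):
    # Hypothetical least-squares loss standing in for the MNIST cross-entropy.
    x, y = batch
    return jnp.mean((x @ w - y) ** 2)


@partial(pmap, axis_name="batch")
def spmd_grad(w, batch):
    # Same pattern as spmd_update: local gradient, then all-reduce-sum.
    return lax.psum(grad(loss)(w, batch), "batch")


rng = np.random.RandomState(0)
per_device = 16
x = rng.randn(num_devices * per_device, 3).astype(np.float32)
y = rng.randn(num_devices * per_device).astype(np.float32)
w = np.zeros(3, dtype=np.float32)

# Split the batch axis into (devices, examples per device) and give the
# parameters a leading device axis by replicating them.
sharded = (x.reshape(num_devices, per_device, 3),
           y.reshape(num_devices, per_device))
replicated_w = np.broadcast_to(w, (num_devices,) + w.shape)

g_spmd = spmd_grad(replicated_w, sharded)[0]  # identical on every device
g_full = grad(loss)(w, (x, y))

# psum adds the per-device mean-loss gradients, so with equal shard sizes
# the result is num_devices times the full-batch gradient.
print(np.allclose(g_spmd, num_devices * g_full, atol=1e-4))  # True
```

On a single device the two gradients coincide, so the scaling only matters when the same step_size is reused with more devices.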

Finally, write a training loop that runs for a number of epochs. Within each epoch, spmd_update updates the parameters, which are replicated across devices, in parallel on every batch; after each epoch, accuracy measures the model’s performance on both the training and test data.
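The training loop itself is described but not shown. Below is a minimal end-to-end sketch under stated assumptions: synthetic data replaces MNIST (so the datasets loader is not needed), the hyperparameters are hypothetical, and the model, loss, accuracy, and spmd_update mirror the functions above in compressed form. Because pmap maps over the leading axis of every argument, the parameters are replicated with tree_map and np.broadcast_to so each array gains a device axis, and the first replica is used for evaluation.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, lax, pmap
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map

# Hypothetical hyperparameters and synthetic data standing in for MNIST.
layer_sizes = [20, 32, 3]
param_scale, step_size, num_epochs = 0.1, 0.2, 10
num_devices = jax.local_device_count()
batch_size = 8 * num_devices  # must be divisible by the device count

rng = np.random.RandomState(0)
train_images = rng.randn(256, layer_sizes[0]).astype(np.float32)
true_w = rng.randn(layer_sizes[0], layer_sizes[-1])
train_labels = np.eye(layer_sizes[-1], dtype=np.float32)[
    np.argmax(train_images @ true_w, axis=1)]  # one-hot, learnable labels
num_batches = train_images.shape[0] // batch_size


def init_random_params(scale, sizes):
    return [(scale * rng.randn(m, n).astype(np.float32),
             scale * rng.randn(n).astype(np.float32))
            for m, n in zip(sizes[:-1], sizes[1:])]


def predict(params, inputs):
    activations = inputs
    for w, b in params[:-1]:
        activations = jnp.tanh(jnp.dot(activations, w) + b)
    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    return logits - logsumexp(logits, axis=1, keepdims=True)


def loss(params, batch):
    inputs, targets = batch
    return -jnp.mean(jnp.sum(predict(params, inputs) * targets, axis=1))


@jit
def accuracy(params, batch):
    inputs, targets = batch
    return jnp.mean(jnp.argmax(predict(params, inputs), axis=1)
                    == jnp.argmax(targets, axis=1))


@partial(pmap, axis_name="batch")
def spmd_update(params, batch):
    # psum sums gradients, so the effective step grows with num_devices.
    grads = grad(loss)(params, batch)
    grads = [(lax.psum(dw, "batch"), lax.psum(db, "batch")) for dw, db in grads]
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]


def batches():
    # Shuffle once per epoch, then split each batch into (devices, per-device).
    perm = rng.permutation(train_images.shape[0])
    shape = (num_devices, batch_size // num_devices)
    for i in range(num_batches):
        idx = perm[i * batch_size:(i + 1) * batch_size]
        yield (train_images[idx].reshape(shape + (-1,)),
               train_labels[idx].reshape(shape + (-1,)))


params = init_random_params(param_scale, layer_sizes)
full_batch = (train_images, train_labels)
initial_loss = float(loss(params, full_batch))
# Replicate parameters so every leaf gains a leading device axis for pmap.
replicated_params = tree_map(
    lambda x: np.broadcast_to(x, (num_devices,) + x.shape), params)

for epoch in range(num_epochs):
    start_time = time.time()
    for batch in batches():
        replicated_params = spmd_update(replicated_params, batch)
    epoch_time = time.time() - start_time
    # Replicas stay identical after each update, so evaluate the first one.
    params = tree_map(lambda x: x[0], replicated_params)
    train_acc = accuracy(params, full_batch)
    print(f"Epoch {epoch} in {epoch_time:0.2f} sec, "
          f"training accuracy {float(train_acc):.3f}")

final_loss = float(loss(params, full_batch))
print(f"Loss: {initial_loss:.3f} -> {final_loss:.3f}")
```

With the real MNIST arrays from datasets.mnist(), the same loop would also evaluate accuracy on the test split after each epoch.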

Try running the code example above to train and evaluate the network on MNIST images with JAX on a CPU. The network is trained for several epochs with stochastic gradient descent, and its accuracy is assessed as training proceeds.

What Comes Next

Next, try JAX on other datasets, such as Sentiment140 for sentiment analysis, and use it for complex numerical computations on high-performance hardware. Also take a look at Intel Extension for TensorFlow, which includes a PJRT plugin implementation (PJRT is OpenXLA’s pluggable device API) that lets JAX models run smoothly on Intel GPUs. Intel optimizes the open-source TensorFlow framework for Intel hardware and releases its latest optimizations and features in Intel Extension for TensorFlow.

To build more end-to-end AI applications, get access to and try out the AI Tools. To help you plan, develop, deploy, and scale your AI solutions, we also encourage you to review Intel’s other AI/ML framework optimizations and tools and integrate them into your AI workflow. You can also learn about the unified, open, standards-based oneAPI programming model that forms the foundation of Intel’s AI Software Portfolio.

Drakshi
Since June 2023, Drakshi has been writing articles on artificial intelligence for govindhtech. She holds a postgraduate degree in business administration and is an artificial intelligence enthusiast.