The PyTorch/XLA Python package is essential for developers who want to use the PyTorch deep learning framework with Cloud TPUs: it lets them run their PyTorch models on Cloud TPUs with only a few small code changes. It does this through Google’s OpenXLA, which lets developers define a model once and run it on a variety of machine learning accelerators (such as GPUs and TPUs).
The most recent PyTorch/XLA release includes a number of improvements to the developer experience:
- A new experimental scan operator that speeds up compilation for repetitive code blocks (for example, for loops)
- Host offloading, which moves TPU tensors to the host CPU’s memory so you can fit larger models on fewer TPUs
- Improved goodput for tracing-bound models through a new base Docker image compiled with the C++ 2011 Standard application binary interface (C++11 ABI) flags
Alongside these improvements, Google Cloud has also reorganised the documentation to make information easier to find.
Let’s examine each feature in depth.
Experimental scan operator
Have you ever run into long compilation times, for example when working with large language models in PyTorch/XLA, particularly with models that contain many decoder layers? During graph tracing, where we traverse the graph of all the operations the model performs, these iterative loops are fully “unrolled”: each loop iteration is copied and pasted for every cycle, producing very large computation graphs. These larger graphs lead directly to longer compilation times. There is now a solution: the experimental scan function, inspired by jax.lax.scan.
The scan operator works by changing how loops are handled during compilation. Instead of compiling every iteration of the loop separately, which produces duplicated blocks, scan compiles only the first iteration. Every subsequent iteration reuses the resulting high-level operation (HLO), so far less HLO, or intermediate code, is generated for the remaining loops. Because scan compiles only the first loop iteration, it takes a fraction of the compilation time of a for loop. This shortens the developer’s iteration time when working on models such as LLMs that have many homogeneous layers.

The torch_xla.experimental.scan_layers method builds on torch_xla.experimental.scan and provides a simplified interface for looping over sequences of nn.Modules. Think of it as a way of telling PyTorch/XLA: “These modules are all the same, just compile them once and reuse them!” For example:
import torch
import torch.nn as nn
import torch_xla
from torch_xla.experimental.scan_layers import scan_layers

class DecoderLayer(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linear = nn.Linear(size, size)

    def forward(self, x):
        return self.linear(x)

with torch_xla.device():
    layers = [DecoderLayer(1024) for _ in range(64)]
    x = torch.randn(1, 1024)
    # Instead of a for loop, we can scan_layers once:
    # for layer in layers:
    #     x = layer(x)
    x = scan_layers(layers, x)
Note that scan does not currently support custom Pallas kernels. For reference, here is a comprehensive example of using scan_layers in an LLM.
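If you need lower-level control than scan_layers provides, the underlying scan primitive can be used directly. Below is a minimal sketch, assuming scan follows the jax.lax.scan-style signature scan(fn, init, xs) and returns the final carry plus the stacked per-iteration outputs; the step function and tensor shapes are purely illustrative.

import torch
import torch_xla
from torch_xla.experimental.scan import scan

def step(carry, x):
    # One loop iteration: update the carried state and emit an output.
    new_carry = carry + x
    y = new_carry.sum()
    return new_carry, y

with torch_xla.device():
    init = torch.zeros(4)      # carried state
    xs = torch.randn(10, 4)    # scanned over the leading dimension: 10 iterations
    # Only the body of `step` is traced and compiled once; the resulting HLO
    # is reused for all 10 iterations.
    final_carry, ys = scan(step, init, xs)

In practice, scan_layers is usually the more convenient entry point when the loop body is a stack of identical nn.Modules, as in the decoder-layer example above.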
Host offloading
Host offloading is another effective memory optimisation technique in PyTorch/XLA. It lets you temporarily move tensors from the TPU to the host CPU’s memory, freeing up valuable device memory during training. This is especially helpful for large models where memory pressure is a concern. A tensor can be offloaded with torch_xla.experimental.stablehlo_custom_call.place_to_host and retrieved later with torch_xla.experimental.stablehlo_custom_call.place_to_device. A typical pattern is to offload intermediate activations during the forward pass and bring them back during the backward pass. For reference, here is an example of host offloading.
Used strategically, for example when you are memory-constrained and cannot keep everything resident on the accelerator, host offloading can greatly improve your ability to train large and complex models within your hardware’s memory limits.
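As a rough illustration, here is a minimal sketch of the offload-and-restore pattern. It assumes place_to_host and place_to_device each take a tensor and return it annotated for host or device placement; the tensor shapes and surrounding structure are illustrative only.

import torch
import torch_xla
from torch_xla.experimental.stablehlo_custom_call import (
    place_to_host,
    place_to_device,
)

with torch_xla.device():
    # An intermediate activation produced during the forward pass.
    activations = torch.randn(1024, 1024)

    # Offload it to host CPU memory to relieve TPU memory pressure.
    activations_host = place_to_host(activations)

    # ... other memory-hungry work runs on the TPU here ...

    # Bring it back onto the TPU when it is needed again, for example
    # during the backward pass.
    activations_device = place_to_device(activations_host)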
A new base Docker image
Have you ever traced your model’s execution graph for just-in-time compilation and noticed that your TPUs sit idle while your host CPU is heavily loaded? This suggests your model is “tracing-bound”: performance is limited by the speed at which tracing operations run.
The C++11 ABI image addresses this. With this release, PyTorch/XLA offers a choice of C++ ABI flavours for both Python wheels and Docker images, letting you pick which C++ version to use with PyTorch/XLA. Builds are now available for both the newer C++11 ABI and the older pre-C++11 ABI, which remains the default to match PyTorch upstream.
Switching to the C++11 ABI wheels or Docker images can significantly improve the situations described above. For example, when Google Cloud moved from the pre-C++11 ABI to the C++11 ABI, it saw a 20% relative improvement in goodput with the Mixtral 8x7B model on v5p-256 Cloud TPU (with a global batch size of 1024). Goodput helps us understand how efficiently a given model uses the hardware; if the same model achieves higher goodput on the same hardware, that indicates better model performance.
Here is an example of how to use a C++11 ABI Docker image in your Dockerfile:
# Use the C++11 ABI PyTorch/XLA image as the base
FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11
# Install any additional dependencies here
# RUN pip install my-other-package
# Copy your code into the container
COPY . /app
WORKDIR /app
# Run your training script
CMD ["python", "train.py"]
Alternatively, if you are experimenting locally and not using Docker images, you can install the C++11 ABI wheels for version 2.6 with the following command (example for Python 3.10):
pip install torch==2.6.0+cpu.cxx11.abi \
https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl \
'torch_xla[tpu]' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html \
-f https://download.pytorch.org/whl/torch
The above command works for Python 3.10; Google Cloud’s documentation has instructions for other versions.
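Once the wheels are installed, a quick sanity check such as the sketch below (assuming a TPU runtime is available on the machine) can confirm that both packages import and that the XLA device is reachable:

import torch
import torch_xla

print(torch.__version__)      # should report the 2.6.0 C++11 ABI build installed above
print(torch_xla.__version__)  # should report the matching 2.6.0 build

# Allocate a small tensor on the XLA device to confirm the runtime is reachable.
t = torch.randn(2, 2, device=torch_xla.device())
print(t.device)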
Being able to choose between C++ ABIs lets you pick the best build for your particular workload and hardware, ultimately improving the performance and efficiency of your PyTorch/XLA applications.
A note on GPU support
The PyTorch/XLA 2.6 release does not include a PyTorch/XLA:GPU wheel. Google Cloud recognises how important this is and aims to restore GPU support in the 2.7 release. PyTorch/XLA remains an open-source project, and community contributions are welcome to help maintain and improve it.