PyTorch/XLA 2.5
PyTorch/XLA 2.5: enhanced development experience and support for vLLM
PyTorch/XLA is a Python package that connects the PyTorch deep learning framework to Cloud TPUs through the XLA deep learning compiler, and machine learning engineers are enthusiastic about it. PyTorch/XLA 2.5 arrives with a number of enhancements that improve the developer experience and add support for vLLM. This release’s features include:
- A plan to deprecate the older torch_xla API in favor of the equivalent upstream PyTorch API, which simplifies the development process. The migration of the existing distributed API is the first example of this.
- Several improvements to torch_xla.compile that give developers a better debugging experience while working on a project.
- Experimental support for TPUs in vLLM, so you can expand your existing deployments and use the same vLLM interface across all of your TPUs.
Let’s examine each of these improvements.
Streamlining the torch_xla API
With PyTorch/XLA 2.5, Google Cloud is taking a big step toward making the API more consistent with upstream PyTorch. The goal is to make XLA devices easier to use by reducing the learning curve for developers who already know PyTorch. Where feasible, this means deprecating custom PyTorch/XLA API calls in favor of more mature upstream functionality and switching those calls to their PyTorch equivalents; until the migration is complete, some of this functionality remains in the existing torch_xla Python module.
To make developing on PyTorch/XLA easier, this release switches to the existing PyTorch distributed API functions when running models on top of PyTorch/XLA, moving the majority of the distributed API calls from the torch_xla module to torch.distributed.
With PyTorch/XLA 2.4
import torch_xla.core.xla_model as xm
xm.all_reduce()
Supported after PyTorch/XLA 2.5
torch.distributed.all_reduce()
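For context, here is a minimal runnable sketch of what that migration looks like at a call site. It assumes a TPU VM, the "xla" distributed backend registered by importing torch_xla.distributed.xla_backend, and a process group initialized through the xla:// init method; your launch setup (for example torchrun or torch_xla.launch) may differ.

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' backend for torch.distributed

# Assumes the process was started by a multiprocess launcher such as torchrun
dist.init_process_group("xla", init_method="xla://")

t = torch.ones(4, device=xm.xla_device())

# PyTorch/XLA 2.4 style:
# xm.all_reduce(xm.REDUCE_SUM, [t])

# PyTorch/XLA 2.5 style, using the upstream PyTorch API:
dist.all_reduce(t, op=dist.ReduceOp.SUM)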
A better version of “torch_xla.compile”
This release also includes a few new compilation features to help you debug or spot potential problems in your model code. For instance, “full_graph” mode generates an error message when there is more than one compilation graph, which helps you detect issues caused by multiple compilation graphs early, during compilation.
You can now also specify how many recompilations you expect for a compiled function. This can help you troubleshoot performance problems when a function is recompiled more often than necessary, for example because it exhibits unexpected dynamism.
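As a rough sketch of how these checks are enabled, the decorator call below turns on full_graph mode and sets an expected recompilation budget. The full_graph flag is described above, but the num_different_graphs_allowed keyword used here is an assumption based on this release's description, so check the torch_xla.compile API reference for the exact argument name.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Raise an error if tracing produces more than one compilation graph,
# and flag recompilations beyond the expected budget (keyword name assumed).
@torch_xla.compile(full_graph=True, num_different_graphs_allowed=1)
def train_step(tensor):
    return torch.cos(torch.sin(tensor))

out = train_step(torch.randn(8, device=xm.xla_device()))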
Additionally, you can now give compiled functions a meaningful name rather than an automatically generated one. Naming compiled targets gives debugging messages additional context, which makes it easier to track down the potential issue. Here’s an example of what that looks like in practice:
# named code
@torch_xla.compile
def dummy_cos_sin_decored(self, tensor):
    return torch.cos(torch.sin(tensor))

# target dumped HLO renamed with named code function name
…
module_0021.SyncTensorsGraph.4.hlo_module_config.txt
module_0021.SyncTensorsGraph.4.target_arguments.txt
module_0021.SyncTensorsGraph.4.tpu_comp_env.txt
module_0024.dummy_cos_sin_decored.5.before_optimizations.txt
module_0024.dummy_cos_sin_decored.5.execution_options.txt
module_0024.dummy_cos_sin_decored.5.flagfile
module_0024.dummy_cos_sin_decored.5.hlo_module_config.txt
module_0024.dummy_cos_sin_decored.5.target_arguments.txt
module_0024.dummy_cos_sin_decored.5.tpu_comp_env.txt
…
In the output above, you can see the difference between the default and the named output for the same file: “SyncTensorsGraph” is the automatically generated name, while the module_0024 files are renamed after the function in the small code example above.
vLLM on TPU (experimental)
If you serve models on GPUs using vLLM, you can now use TPU as a backend as well. vLLM is a memory-efficient, high-throughput inference and serving engine for LLMs. vLLM on TPU keeps the same vLLM interface that developers love, including direct integration into the Hugging Face Model Hub, which makes it easy to test models on TPU.
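As a sketch of that unchanged interface, the offline-inference snippet below is the same code you would run against a GPU backend; the model name is only a placeholder pulled from the Hugging Face Model Hub, and the example assumes a TPU VM with the TPU build of vLLM installed per the installation guide.

from vllm import LLM, SamplingParams

# Same vLLM API as on GPU; the backend is determined by the installed build.
llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", max_model_len=1024)

params = SamplingParams(temperature=0.7, max_tokens=64)
outputs = llm.generate(["What is a TPU?"], params)
print(outputs[0].outputs[0].text)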
Switching your vLLM endpoint to TPU requires only a few configuration changes. Everything other than the TPU image stays the same: the model source code, load balancing, autoscaling metrics, and the request payload. Refer to the installation guide for further details.
Other vLLM capabilities brought to TPU include Pallas kernels such as paged attention and flash attention, along with dynamo bridge speed optimizations, all of which are now part of the PyTorch/XLA repository (code). While vLLM is now available to PyTorch TPU users, this work is still in progress, and more functionality and improvements are anticipated in upcoming releases.
Use PyTorch/XLA 2.5
To start using these new features, download the most recent version through your Python package manager. If you’ve never used PyTorch/XLA before, see the project’s GitHub page for installation instructions and more detailed information.
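For example, an installation on a Cloud TPU VM typically looks like the line below; the exact package spec and wheel index for your environment may differ, so follow the README on the GitHub page.

pip install torch~=2.5.0 'torch_xla[tpu]~=2.5.0' -f https://storage.googleapis.com/libtpu-releases/index.html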