PyTorch/XLA 2.5: vLLM support and an improved developer experience
Machine learning engineers are bullish on PyTorch/XLA, a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. And now, PyTorch/XLA 2.5 is here, with a set of improvements that add support for vLLM and enhance the overall developer experience. Featured in this release are:
- A clarified proposal for deprecating the older torch_xla API in favor of the existing PyTorch API, providing a simplified developer experience. An example of this is the migration of the existing distributed API.
- A series of improvements to the torch_xla.compile function that improve the debugging experience during development.
- Experimental support in vLLM for TPUs, allowing you to extend your existing deployments while leveraging the same vLLM interface across your TPUs.
Let’s take a look at each of these enhancements.
Streamlining the torch_xla API
With PyTorch/XLA 2.5, we’re taking a significant step towards making the API more consistent with upstream PyTorch. Our north star is to minimize the learning curve for developers already familiar with PyTorch, making it easier to use XLA devices. Where functionality is mature enough, this means gradually deprecating custom PyTorch/XLA API calls and migrating them to their PyTorch counterparts. Features that have not yet been migrated remain in the existing Python module for now.
In the spirit of a simpler developer experience, this release moves to existing PyTorch distributed API functions when running models on top of PyTorch/XLA. Historically, the distributed API calls were located under the torch_xla module; in this update, we migrated most of them to torch.distributed.
    # With PyTorch/XLA 2.4
    import torch_xla.core.xla_model as xm
    xm.all_reduce()

    # Supported after PyTorch/XLA 2.5
    torch.distributed.all_reduce()
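To make the migrated call a little more concrete, here is a minimal sketch of torch.distributed.all_reduce with its arguments filled in. It assumes a single process and uses the CPU "gloo" backend with an in-memory store purely as a stand-in so that it runs without a TPU; the helper name all_reduce_sum is hypothetical. The comments note what we understand the TPU setup to look like, which should be checked against the PyTorch/XLA documentation.

```python
import torch
import torch.distributed as dist


def all_reduce_sum(values):
    """Sum a tensor across all ranks using the upstream PyTorch collective."""
    tensor = torch.tensor(values, dtype=torch.float32)
    # The collective reduces in place and leaves the sum on every rank.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor


# Stand-in setup (assumption): one process, "gloo" backend, in-memory store.
# On a TPU host the setup would instead look roughly like:
#   import torch_xla.distributed.xla_backend  # registers the "xla" backend
#   dist.init_process_group("xla", init_method="xla://")
# with the tensor placed on the XLA device before the collective.
dist.init_process_group(
    backend="gloo", store=dist.HashStore(), rank=0, world_size=1
)
try:
    result = all_reduce_sum([1.0, 2.0, 3.0])
finally:
    dist.destroy_process_group()

print(result.tolist())
```

With a single rank the sum equals the input, so this only demonstrates the call shape; the point is that the model code calls torch.distributed directly rather than a torch_xla-specific function.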
Improvements to torch_xla.compile
We’ve also added a few new compilation features to help you debug or notice potential issues within your model code. For example, a ‘full_graph’ mode emits an error message when there’s more than one compilation graph. This helps you discover potential issues caused by multiple compilation graphs early on (during compilation).
Additionally, you can now specify an expected number of recompilations for compiled functions. This can help you debug performance issues in which a function might be getting recompiled more times than expected, for example, when it has unexpected dynamism.
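To illustrate why a recompilation budget helps catch unexpected dynamism, here is a toy model in plain Python. It is not how torch_xla implements the feature: it simply treats every new input shape as a new compiled graph and raises once the budget is exceeded. All names (ShapeKeyedCache, max_graphs, RecompilationBudgetExceeded) are hypothetical.

```python
class RecompilationBudgetExceeded(RuntimeError):
    """Raised when a function is specialized for more shapes than expected."""


class ShapeKeyedCache:
    """Toy stand-in for a compiler that builds one graph per input shape."""

    def __init__(self, fn, max_graphs):
        self._fn = fn
        self._max_graphs = max_graphs
        self._seen_shapes = set()

    @property
    def graph_count(self):
        return len(self._seen_shapes)

    def __call__(self, rows):
        # A "shape" here is (row count, column count) of a list of lists.
        shape = (len(rows), len(rows[0]) if rows else 0)
        if shape not in self._seen_shapes:
            if len(self._seen_shapes) >= self._max_graphs:
                raise RecompilationBudgetExceeded(
                    f"shape {shape} would be graph "
                    f"{len(self._seen_shapes) + 1}, budget is {self._max_graphs}"
                )
            self._seen_shapes.add(shape)  # pretend to compile a new graph
        return self._fn(rows)


def row_sums(rows):
    return [sum(row) for row in rows]


step = ShapeKeyedCache(row_sums, max_graphs=2)
step([[1, 2], [3, 4]])   # first shape: one graph
step([[5, 6], [7, 8]])   # same shape: graph reused
step([[1, 2, 3]])        # second shape: second graph

try:
    step([[1]])          # third shape: exceeds the budget
    exceeded = False
except RecompilationBudgetExceeded:
    exceeded = True
```

Failing loudly at the third shape points directly at the call that introduced the extra dynamism, instead of leaving you to infer it from slow steps.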
You can now also give compiled functions an understandable name instead of an automatically created one. By naming compiled targets, you gain more context in debugging messages, making it easier to figure out where the problem may be. Here’s an example of what that looks like in practice:
    # named code
    @torch_xla.compile
    def dummy_cos_sin_decored(self, tensor):
        return torch.cos(torch.sin(tensor))

    # target dumped HLO renamed with named code function name
    …
    module_0021.SyncTensorsGraph.4.hlo_module_config.txt
    module_0021.SyncTensorsGraph.4.target_arguments.txt
    module_0021.SyncTensorsGraph.4.tpu_comp_env.txt
    module_0024.dummy_cos_sin_decored.5.before_optimizations.txt
    module_0024.dummy_cos_sin_decored.5.execution_options.txt
    module_0024.dummy_cos_sin_decored.5.flagfile
    module_0024.dummy_cos_sin_decored.5.hlo_module_config.txt
    module_0024.dummy_cos_sin_decored.5.target_arguments.txt
    module_0024.dummy_cos_sin_decored.5.tpu_comp_env.txt
    …
In the output above, you can see the original and the named output generated from the same file. ‘SyncTensorsGraph’ is the automatically generated name, while the files containing ‘dummy_cos_sin_decored’ are the renamed output for the small code example above.
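Once compiled functions carry readable names, the dump files can be grouped by function with a few lines of standard-library Python. The sketch below assumes the file name layout module_<id>.<graph name>.<index>.<artifact>, which is inferred only from the listing above and is not a documented format; the helper group_dumps_by_graph is hypothetical.

```python
import re
from collections import defaultdict

# Layout inferred from the listing above (an assumption, not a documented format).
DUMP_NAME = re.compile(
    r"^module_(?P<module_id>\d+)\.(?P<graph>[^.]+)\.(?P<index>\d+)\.(?P<artifact>.+)$"
)


def group_dumps_by_graph(filenames):
    """Map each graph name to the list of artifacts dumped for it."""
    groups = defaultdict(list)
    for name in filenames:
        match = DUMP_NAME.match(name)
        if match is None:
            continue  # skip files that do not follow the assumed layout
        groups[match["graph"]].append(match["artifact"])
    return dict(groups)


dump_files = [
    "module_0021.SyncTensorsGraph.4.hlo_module_config.txt",
    "module_0021.SyncTensorsGraph.4.target_arguments.txt",
    "module_0024.dummy_cos_sin_decored.5.before_optimizations.txt",
    "module_0024.dummy_cos_sin_decored.5.flagfile",
]

groups = group_dumps_by_graph(dump_files)
for graph, artifacts in sorted(groups.items()):
    print(graph, artifacts)
```

With automatically generated names, every entry would fall under ‘SyncTensorsGraph’; naming the compiled function is what makes this kind of grouping useful.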
vLLM on TPU (experimental)
If you use vLLM to serve models on GPUs, you can now switch to TPU as a backend. vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. vLLM on TPU retains the same vLLM interface that developers love, including direct integration with the Hugging Face Model Hub to simplify model experimentation on TPU.
Switching your vLLM endpoint to TPU is a matter of a few config changes. Aside from the TPU image, everything else remains the same: request payload, metrics used for autoscaling, load balancing, model source code, etc. For details, see the installation guide.
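As a rough illustration of the request payload staying the same, the sketch below builds a request for vLLM’s OpenAI-compatible completions route using only the standard library. It assumes the endpoint is served through that OpenAI-compatible server; the host names, the model name, and the helper build_completion_request are placeholders, and nothing is actually sent.

```python
import json
import urllib.request

MODEL = "facebook/opt-125m"        # placeholder model name
PROMPT = "Cloud TPUs are"


def build_completion_request(base_url, model, prompt, max_tokens=64):
    """Build (but do not send) a POST request for the completions route."""
    payload = {"model": model, "prompt": prompt, "max_tokens": max_tokens}
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Hypothetical endpoints: only the host differs between the two backends.
gpu_request = build_completion_request("http://gpu-endpoint.example:8000", MODEL, PROMPT)
tpu_request = build_completion_request("http://tpu-endpoint.example:8000", MODEL, PROMPT)

same_payload = gpu_request.data == tpu_request.data
print(same_payload)
```

Sending either request with urllib.request.urlopen would require a running server; the comparison only shows that client code does not need to know which accelerator sits behind the endpoint.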
Other vLLM features we’ve extended to TPU include Pallas kernels such as paged attention and flash attention, as well as performance optimizations in the dynamo bridge, all of which are now part of the PyTorch/XLA repository (code). While vLLM is available to PyTorch TPU users, this work is still ongoing, and we look forward to rolling out additional features and optimizations in future releases.
Start using PyTorch/XLA 2.5
You can start taking advantage of these latest features by downloading the latest release through your Python package manager. Or, if this is your first time hearing about PyTorch/XLA, check out the project’s GitHub page for installation instructions and more detailed information.
For a full list of changes, check out the release notes!