PJRT is an open, stable interface for device runtime and compiler that simplifies the integration of ML hardware and frameworks. With PJRT, ML frameworks become hardware-agnostic and ML hardware becomes pluggable. For ML developers, adopting new ML hardware becomes simpler and models become more portable. This addresses the fragmentation of ML infrastructure across frameworks, compilers, and runtimes, enhancing the industry’s ability to productionize ML-driven advancements with velocity and at scale.
This article provides an overview of what building a PJRT plugin entails, how frameworks (and models) can use this plugin, and some updates on the PJRT API. PJRT is now used by a growing spectrum of hardware: Apple silicon, Google Cloud TPU, NVIDIA GPU, and Intel Max GPU. We also share a spotlight on Apple’s adoption of PJRT with some details on the workflow and performance.
If you’re developing an ML hardware accelerator or developing your own compiler and runtime, check out the PJRT source code on GitHub and sign up for the PJRT mailing list to quickly bootstrap your work.
What’s in a PJRT Plugin
PJRT was introduced to simplify the growing complexity of ML workload execution across hardware and frameworks. PJRT (used in conjunction with StableHLO) is a stable interface for device runtime and compiler, which abstracts device-specific implementations away from frameworks.
An implementation of the PJRT API is called a PJRT plugin. It is usually packaged as a Python package so that ML model developers can install it seamlessly. To build a PJRT plugin for a hardware target, the following functionality needs to be implemented:
- Compile: compile (program) -> executable
- Runtime: execute (executable, arguments) -> results
- Memory management: transfer buffer from host to device, device to host, device to device, as well as buffer management such as buffer donation
- Topology information, such as the platform, the number of accelerators, and how they are attached
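The four responsibilities listed above can be sketched in Python. This is only an illustration: the real PJRT API is a C interface, and every class and method name below (`PluginSketch`, `ToyCpuPlugin`, `to_device`, and so on) is hypothetical. The toy backend treats plain Python lists as "device" buffers.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, List, Sequence


@dataclass
class Topology:
    """Hypothetical topology record: platform name and device count."""
    platform: str
    num_devices: int


class PluginSketch(ABC):
    """Hypothetical stand-in for the responsibilities of a plugin.

    The real PJRT API is a C interface; these names are illustrative only.
    """

    @abstractmethod
    def compile(self, program: str) -> Callable[..., list]: ...

    @abstractmethod
    def execute(self, executable, arguments: Sequence[list]) -> list: ...

    @abstractmethod
    def to_device(self, host_data: Sequence[float]) -> list: ...

    @abstractmethod
    def to_host(self, device_buffer: list) -> List[float]: ...

    @abstractmethod
    def topology(self) -> Topology: ...


class ToyCpuPlugin(PluginSketch):
    """Toy backend in which 'device' buffers are plain Python lists."""

    def compile(self, program):
        # Only two toy programs are understood by this sketch.
        ops = {
            "add": lambda a, b: [x + y for x, y in zip(a, b)],
            "mul": lambda a, b: [x * y for x, y in zip(a, b)],
        }
        return ops[program]

    def execute(self, executable, arguments):
        return executable(*arguments)

    def to_device(self, host_data):
        return list(host_data)  # the copy models a host-to-device transfer

    def to_host(self, device_buffer):
        return list(device_buffer)  # the copy models a device-to-host transfer

    def topology(self):
        return Topology(platform="toy_cpu", num_devices=1)


plugin = ToyCpuPlugin()
exe = plugin.compile("add")
a = plugin.to_device([1.0, 2.0])
b = plugin.to_device([3.0, 4.0])
result = plugin.to_host(plugin.execute(exe, [a, b]))
print(result)
```

A framework written against `PluginSketch` never needs to know which backend it is talking to, which is the property the real interface aims for.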
ML frameworks discover and load one or more PJRT plugins, then call the PJRT API to compile and execute the model. Depending on the discovery mechanism a framework uses, a PJRT plugin may need to register with that framework.
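One common discovery mechanism for Python packages is entry points. The sketch below lists the entry points registered under a given group using only the standard library. The helper name `discover_plugins` is hypothetical, and the group name `jax_plugins` is an assumption about how JAX discovers plugins; check the framework's own documentation before relying on it.

```python
from importlib import metadata


def discover_plugins(group: str):
    """Return the sorted names of installed entry points in `group`.

    Hypothetical helper: a real framework would also load each entry
    point and call an initialization hook on the resulting module.
    """
    eps = metadata.entry_points()
    if hasattr(eps, "select"):
        # Python 3.10+ exposes a selection API.
        selected = eps.select(group=group)
    else:
        # Older Pythons return a plain dict keyed by group name.
        selected = eps.get(group, [])
    return sorted(ep.name for ep in selected)


# "jax_plugins" is assumed here for illustration only.
print(discover_plugins("jax_plugins"))
```

On a machine with no matching packages installed, the function returns an empty list, so a framework can fall back to its built-in backends.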
API Updates
Versioning and ABI Compatibility
The PJRT API has a major version and a minor version. If the framework is newer than the plugin, the framework provides an N-week (N=6 today) forward compatibility window for minor version updates. Major version updates are coordinated, and frameworks will not support plugins with a lower major version. If the plugin is newer than the framework, plugins define their own backward compatibility policy.
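The policy above can be read as a small decision procedure. The following sketch is one interpretation, not the check that any framework actually performs: `ApiVersion`, `framework_accepts`, and the `weeks_behind` parameter (assumed to mean the time since the framework moved past the plugin's minor version) are all hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ApiVersion:
    major: int
    minor: int


def framework_accepts(framework: ApiVersion, plugin: ApiVersion,
                      weeks_behind: float, window_weeks: int = 6) -> str:
    """Classify a framework/plugin pairing under the stated policy.

    `weeks_behind` is an assumed input: how long ago the framework
    moved past the plugin's minor version.
    """
    if plugin.major < framework.major:
        return "rejected"  # lower major versions are not supported
    if (plugin.major, plugin.minor) > (framework.major, framework.minor):
        return "plugin-defined"  # the plugin's own policy applies
    if plugin.minor == framework.minor:
        return "supported"
    # Same major version, older minor version: the window decides.
    return "supported" if weeks_behind <= window_weeks else "outside-window"


print(framework_accepts(ApiVersion(0, 42), ApiVersion(0, 40), weeks_behind=3))
```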
Multi-Node
A PJRT client is per node, and the plugin may need some way to communicate among nodes in a distributed workload. The framework can pass in key-value store callbacks to the plugin. The plugin can use them to bootstrap multi-node initialization and other coordination needs. An example with the NVIDIA GPU CUDA plugin is as follows:
- JAX starts a distribution service and provides key-value store callbacks.
- The NVIDIA GPU CUDA plugin uses these callbacks to (1) generate a global PJRT device topology that includes PJRT device information from all nodes, and (2) generate NCCL IDs.
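The bootstrap pattern described above can be simulated in a single process. In this sketch an in-memory dictionary stands in for the distributed key-value service, and all function names and the `devices/<node>` key layout are hypothetical. A real `kv_get` would block until the key becomes available; here every node publishes before anyone reads.

```python
import json
from typing import Dict


def make_kv_callbacks(store: Dict[str, str]):
    """Build put/get callbacks over a dict standing in for the KV service."""
    def kv_put(key: str, value: str) -> None:
        store[key] = value

    def kv_get(key: str) -> str:
        return store[key]  # a real callback would wait for the key

    return kv_put, kv_get


def publish_local_devices(node_id: int, local_devices, kv_put) -> None:
    """Each node advertises its local device ids under its own key."""
    kv_put(f"devices/{node_id}", json.dumps(local_devices))


def build_global_topology(num_nodes: int, kv_get):
    """Merge every node's device list and assign global device ids."""
    topology = []
    for node_id in range(num_nodes):
        for local_id in json.loads(kv_get(f"devices/{node_id}")):
            topology.append({"global_id": len(topology),
                             "node": node_id,
                             "local_id": local_id})
    return topology


store: Dict[str, str] = {}
kv_put, kv_get = make_kv_callbacks(store)
for node_id in range(2):
    publish_local_devices(node_id, [0, 1], kv_put)

topology = build_global_topology(2, kv_get)
print(topology)
```

Because the plugin only sees the two callbacks, the framework is free to back them with whatever coordination service it already runs.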
DLPack
A few C APIs were added to PJRT to support DLPack.
- PJRT_Client_CreateViewOfDeviceBuffer supports receiving buffers from DLPack.
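- Exporting buffers to DLPack uses PJRT_Buffer_OpaqueDeviceMemoryDataPointer to obtain the device memory pointer, together with PJRT_Buffer_IncreaseExternalReferenceCount and PJRT_Buffer_DecreaseExternalReferenceCount to track the external reference for as long as the pointer is shared.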
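The export path can be pictured with a toy model of external reference counting. The class below is not the PJRT implementation: `ToyBuffer` and `export_to_consumer` are hypothetical, the method names merely echo the C API functions, and `id()` stands in for a raw device address. The behavior assumed here is that a buffer is not freed while an external consumer still holds its pointer.

```python
class ToyBuffer:
    """Toy model of a buffer whose memory can be shared externally."""

    def __init__(self, data: bytearray):
        self._data = data
        self._external_refs = 0
        self.deleted = False

    def increase_external_reference_count(self) -> None:
        self._external_refs += 1

    def decrease_external_reference_count(self) -> None:
        if self._external_refs == 0:
            raise RuntimeError("no external reference to release")
        self._external_refs -= 1

    def opaque_device_memory_data_pointer(self) -> int:
        return id(self._data)  # stand-in for a raw device address

    def delete(self) -> bool:
        """Free the buffer unless an external consumer still holds it."""
        if self._external_refs > 0:
            return False
        self._data = None
        self.deleted = True
        return True


def export_to_consumer(buf: ToyBuffer):
    """Pin the buffer, hand out its pointer, and return a release hook."""
    buf.increase_external_reference_count()
    pointer = buf.opaque_device_memory_data_pointer()
    return pointer, buf.decrease_external_reference_count


buf = ToyBuffer(bytearray(16))
pointer, release = export_to_consumer(buf)
freed_while_shared = buf.delete()   # refused: the consumer still holds it
release()                           # the consumer is done with the memory
freed_after_release = buf.delete()  # now the buffer can be freed
print(freed_while_shared, freed_after_release)
```

In DLPack terms, the release hook plays the role of the deleter that the consumer calls when it no longer needs the tensor.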
Extension
The PJRT API provides an extension mechanism through which a plugin can expose optional or experimental features. These extensions can have their own compatibility guarantees and do not need to follow the ABI compatibility rules of the PJRT API.
Industry Adoption
PJRT is the only interface for JAX, the primary interface for TensorFlow, and fully supported for PyTorch through PyTorch/XLA. PJRT is not tied to a specific compiler or runtime. Its toolchain-independent architecture and open-source availability as part of the OpenXLA Project allow it to be used by any hardware, framework, or compiler, with extensibility for unique features. This has allowed PJRT to be adopted by various industry partners through close collaboration. A brief account of Apple’s adoption of PJRT follows.
JAX on Apple Silicon
Apple’s PJRT plugin for the Metal training backend accelerates JAX models on Apple silicon and AMD GPUs. This lets ML developers use the full potential of the Apple silicon and AMD GPUs in their Apple hardware for faster experimentation with JAX models. The integration and user experience for accelerating JAX on Apple silicon GPUs are similar to the existing PyTorch and TensorFlow implementations.
The Metal plugin uses the OpenXLA compiler and PJRT runtime to optimize and accelerate JAX workloads on GPU. When a JAX program is executed, the JAX graph is lowered into StableHLO, which is then passed to PJRT for compilation and execution. The StableHLO is converted to MPSGraph executables and the Metal runtime APIs are invoked to dispatch to the GPU.
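The lowering step can be observed from Python with JAX's ahead-of-time API. The sketch below is backend-independent and runs on whichever device JAX selects, so it shows the StableHLO handed to the runtime rather than anything specific to Metal. The function `scaled_sine` is a made-up example, and the exact text printed depends on the installed JAX version.

```python
import jax
import jax.numpy as jnp


def scaled_sine(x):
    return jnp.sin(x) * 2.0


x = jnp.arange(4, dtype=jnp.float32)

# Trace and lower the function without executing it.
lowered = jax.jit(scaled_sine).lower(x)

# Textual form of the lowered module, expected to be in the StableHLO dialect.
stablehlo_text = lowered.as_text()
print(stablehlo_text)

# Compile for the selected backend, then execute the compiled program.
compiled = lowered.compile()
result = compiled(x)
print(result)
```

Printing the module before compiling is a convenient way to check what a plugin will actually receive for a given model function.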
Performance
The Metal backend with the PJRT plugin provides a substantial performance speedup for JAX. On an Apple MacBook Pro with M2 Max, training common networks in JAX sees speedups of up to 28x, with an average of 10x, over a CPU baseline.
Figure 1: Performance speedups of up to 28x on Apple MacBook Pro with M2 Max over CPU for JAX training.
Getting Started
Adding Metal support to JAX is as simple as a single pip install:
python -m pip install jax-metal
python -c 'import jax; print(jax.numpy.arange(10))'
For more details on environment setup and installation of JAX on Apple hardware, please refer to the Metal Developer Resources page.
Google Cloud TPU
PJRT is the default runtime for PyTorch 2.0 on Google Cloud TPU. The GitHub README has more details.
NVIDIA GPU
The NVIDIA GPU CUDA implementation in JAX has been extracted and packaged as a PJRT plugin. ML model developers can install the NVIDIA GPU CUDA plugin from PyPI. This plugin uses the newly added features such as multi-node, DLPack, and extensions.
Intel GPU
Intel is leveraging PJRT in Intel® Extension for TensorFlow to provide the Intel GPU backend for TensorFlow, JAX and PyTorch. The example of executing a JAX program on Intel GPU demonstrates how this greatly simplifies the framework and hardware integration.
PJRT Resources
PJRT is available on GitHub: source code for the API, integration guides and issues. If you develop ML frameworks, compilers, runtimes or are interested in improving portability of workloads across hardware, we want your feedback. We encourage you to contribute code, design ideas and feature suggestions. We also invite you to join the PJRT mailing list to stay updated with the latest product and community announcements and to help shape the future of an interoperable ML infrastructure.
Acknowledgements
By Aman Verma – Product Manager, Machine Learning Infrastructure, Google and Jieying Luo – Software Engineer, Machine Learning Infrastructure, Google