opensource.google.com

PJRT: Simplifying ML Hardware and Framework Integration

Monday, May 8, 2023

Infrastructure fragmentation in Machine Learning (ML) across frameworks, compilers, and runtimes makes developing new hardware and toolchains challenging. This inhibits the industry’s ability to quickly productionize ML-driven advancements. To simplify the growing complexity of ML workload execution across hardware and frameworks, we are excited to introduce PJRT and open source it as part of the recently available OpenXLA Project.

PJRT (used in conjunction with OpenXLA’s StableHLO) provides a hardware- and framework-independent interface for compilers and runtimes. It simplifies the integration of hardware with frameworks, accelerating framework coverage for the hardware, and thus hardware targetability for workload execution.

PJRT is the primary interface for TensorFlow and JAX, is fully supported for PyTorch, and is well integrated with the OpenXLA ecosystem to execute workloads on TPU, GPU, and CPU. It is also the default runtime execution path for most of Google’s internal production workloads. The toolchain-independent architecture of PJRT allows any hardware, framework, or compiler to use it, with extensibility for unique features. With this open-source release, anyone can begin using PJRT for their own devices.

If you’re developing an ML hardware accelerator or developing your own compiler and runtime, check out the PJRT source code on GitHub and sign up for the OpenXLA mailing list to quickly bootstrap your work.

Vision: Simplifying ML Hardware and Framework Integration

We are entering a world of ambient experiences where intelligent apps and devices surround us, from the edge to the cloud, in a range of environments and scales. ML workload execution currently supports a combinatorial matrix of hardware, frameworks, and workflows, mostly through tight vertical integrations. Examples of such vertical integrations include specific kernels for TPU versus GPU and specific toolchains to train and serve in TensorFlow versus PyTorch. These bespoke 1:1 integrations are perfectly valid solutions, but they promote lock-in, inhibit innovation, and are expensive to maintain. This problem of a fragmented software stack compounds over time as more kinds of computing hardware need to be supported.

A variety of ML hardware exists today, and hardware diversity is expected to increase in the future. ML users have options and they want to exercise them seamlessly: users want to train a large language model (LLM) on TPU in the Cloud, batch infer on GPU or even CPU, distill and quantize the model, and finally serve it on mobile processors. Our goal is to solve the challenge of making ML workloads portable across hardware by making it easy to integrate the hardware into the ML infrastructure (framework, compiler, runtime).

Portability: Seamless Execution

The workflow to enable this vision with PJRT is as follows (shown in Figure 1):

  1. The provider of the hardware-specific compiler and runtime implements the PJRT API, packages it as a plugin containing the compiler and runtime hooks, and registers it with the frameworks. The implementation can be opaque to the frameworks.
  2. The frameworks discover and load one or more PJRT plugins as dynamic libraries targeting the hardware on which to execute the workload.
  3. That’s it! Execute the workload from the framework onto the target hardware.
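The three steps above can be sketched in a few lines of plain Python. This is only an illustration of the register, discover, execute flow under simplifying assumptions: real PJRT plugins are dynamic libraries that implement the PJRT API, whereas here the "plugin" is an in-process Python object. The names `PjrtLikePlugin`, `ToyDoublerPlugin`, `Framework`, `register_plugin`, and the `"double"` program are all hypothetical and are not part of PJRT.

```python
from typing import Callable, Dict, List, Protocol


class PjrtLikePlugin(Protocol):
    """Hypothetical stand-in for the interface a hardware provider implements."""

    def compile(self, program: str) -> Callable[[List[float]], List[float]]: ...


class ToyDoublerPlugin:
    """Hypothetical vendor plugin; its internals stay opaque to the framework."""

    def compile(self, program: str) -> Callable[[List[float]], List[float]]:
        if program != "double":
            raise ValueError(f"unsupported program: {program}")
        # The returned callable plays the role of a compiled executable.
        return lambda xs: [2.0 * x for x in xs]


class Framework:
    """Hypothetical framework side: a registry of plugins keyed by device name."""

    def __init__(self) -> None:
        self._plugins: Dict[str, PjrtLikePlugin] = {}

    def register_plugin(self, device: str, plugin: PjrtLikePlugin) -> None:
        # Step 1: the provider registers its packaged compiler and runtime.
        self._plugins[device] = plugin

    def run(self, device: str, program: str, inputs: List[float]) -> List[float]:
        # Step 2: the framework looks up the plugin for the target hardware.
        if device not in self._plugins:
            raise KeyError(f"no plugin registered for device: {device}")
        plugin = self._plugins[device]
        # Step 3: compile and execute through the plugin's interface only.
        executable = plugin.compile(program)
        return executable(inputs)


framework = Framework()
framework.register_plugin("toy_accelerator", ToyDoublerPlugin())
result = framework.run("toy_accelerator", "double", [1.0, 2.0, 3.0])
print(result)  # → [2.0, 4.0, 6.0]
```

The point of the sketch is that `Framework` depends only on the shared interface, so supporting a second device means registering another plugin rather than changing framework code.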

The PJRT API will be backward compatible, so a plugin should not need to change often, and it can use version checks to determine which features are available.
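As a rough illustration of version-checking for features, the sketch below gates optional features on an API version. It assumes a (major, minor) versioning scheme in which a newer minor version within the same major version stays compatible. The feature names, the version numbers, and the `supports` helper are hypothetical and do not describe the actual PJRT API.

```python
from typing import Dict, Tuple

# Assumed scheme for illustration only: (major, minor).
Version = Tuple[int, int]

# Hypothetical table: the API version in which each feature first appeared.
FEATURE_INTRODUCED_IN: Dict[str, Version] = {
    "basic_execute": (0, 1),
    "async_transfer": (0, 5),
}


def supports(feature: str, api_version: Version) -> bool:
    """Return True if `api_version` is new enough to provide `feature`."""
    introduced = FEATURE_INTRODUCED_IN.get(feature)
    if introduced is None:
        # Unknown features are treated as unavailable.
        return False
    same_major = api_version[0] == introduced[0]
    # Within one major version, any minor version at or above the
    # introducing version is assumed to keep the feature.
    return same_major and api_version[1] >= introduced[1]


print(supports("async_transfer", (0, 4)))  # → False
print(supports("async_transfer", (0, 7)))  # → True
```

A caller can use such a check to fall back to a baseline code path when a feature is missing, instead of failing at load time.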

Diagram of PJRT architecture
Figure 1: To target specific hardware, provide an implementation of the PJRT API to package a compiler and runtime plugin that can be called by the framework.

Cohesive Ecosystem

As a foundational pillar of the OpenXLA Project, PJRT is well integrated with the project's other components, including StableHLO and the OpenXLA compilers (XLA, IREE). It is the primary interface for TensorFlow and JAX and is fully supported for PyTorch through PyTorch/XLA. It provides the hardware interface layer that addresses the combinatorial framework x hardware fragmentation of ML infrastructure (see Figure 2).

Diagram of PJRT hardware interface layer
Figure 2: PJRT, well integrated with OpenXLA, provides the hardware interface layer that addresses the combinatorial framework x hardware fragmentation of ML infrastructure.
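To make the "combinatorial" point concrete, the short calculation below compares the number of integrations needed with bespoke 1:1 pairings against the number needed with one shared interface layer. It uses the three frameworks and three hardware types named in this post. The counting model is a simplifying assumption (one integration per framework and hardware pair, versus one per framework plus one per hardware type) and not a measurement of real engineering effort.

```python
from itertools import product

frameworks = ["TensorFlow", "JAX", "PyTorch"]
hardware = ["TPU", "GPU", "CPU"]

# Bespoke approach: every framework is integrated with every hardware type.
pairwise_integrations = list(product(frameworks, hardware))

# Shared-interface approach (assumed model): each framework targets the
# interface once, and each hardware type implements the interface once.
shared_interface_integrations = len(frameworks) + len(hardware)

print(len(pairwise_integrations))     # → 9
print(shared_interface_integrations)  # → 6
```

Under this model, adding a fourth hardware type costs one more integration with a shared interface, compared with three more in the pairwise case, and the gap widens as either list grows.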

Toolchain Independent

PJRT is hardware- and framework-independent. With framework integration through the self-contained IR StableHLO, PJRT is not coupled to a specific compiler and can be used outside of the OpenXLA ecosystem, including with other proprietary compilers. Its public availability and toolchain-independent architecture allow it to be used by any hardware, framework, or compiler, with extensibility for unique features. If you are developing an ML hardware accelerator, compiler, or runtime targeting any hardware, or converging siloed toolchains to solve infrastructure fragmentation, PJRT can minimize bespoke hardware and framework integration, providing greater coverage and improving time-to-market at lower development cost.

Driving Impact with Collaboration

Industry partners such as Intel have already adopted PJRT.

Intel

Intel is leveraging PJRT in Intel® Extension for TensorFlow to provide the Intel GPU backend for TensorFlow and JAX. This implementation is based on the PJRT plugin mechanism (see RFC). Check out how this greatly simplifies the framework and hardware integration with this example of executing a JAX program on Intel GPU.

"At Intel, we share Google's vision of modular interfaces to make integration easier and enable faster, framework-independent development. Similar in design to the PluggableDevice mechanism, PJRT is a pluggable interface that allows us to easily compile and execute XLA's High Level Operations on Intel devices. Its simple design allowed us to quickly integrate it into our systems and start running JAX workloads on Intel® GPUs within just a few months. PJRT enables us to more efficiently deliver hardware acceleration and oneAPI-powered AI software optimizations to developers using a wide range of AI Frameworks." - Wei Li, VP and GM, Artificial Intelligence and Analytics, Intel.

Technology Leader

We’re also working with a technology leader to leverage PJRT to provide the backend targeting their proprietary processor for JAX. More details on this to follow soon.

Get Involved

PJRT is available on GitHub: source code for the API, a reference openxla-pjrt-plugin, and integration guides. If you develop ML frameworks, compilers, or runtimes, or are interested in improving portability of workloads across hardware, we want your feedback. We encourage you to contribute code, design ideas, and feature suggestions. We also invite you to join the OpenXLA mailing list to stay updated with the latest product and community announcements and to help shape the future of an interoperable ML infrastructure.

Acknowledgements

Allen Hutchison, Andrew Leaver, Chuanhao Zhuge, Jack Cao, Jacques Pienaar, Jieying Luo, Penporn Koanantakool, Peter Hawkins, Robert Hundt, Russell Power, Sagarika Chalasani, Skye Wanderman-Milne, Stella Laurenzo, Will Cromar, Xiao Yu.

By Aman Verma, Product Manager, Machine Learning Infrastructure