JAX now runs on AWS Trainium: Open Source Fuels AI Innovation
Open source software is the foundation of machine learning. It accelerates innovation through an ethos of flexibility and collaboration. This philosophy drives the open development of JAX, our high-performance array computing library, as well as OpenXLA, the compiler and runtime infrastructure it relies on.
Today we're excited to highlight how this commitment to openness, together with JAX and OpenXLA's modular designs, enables seamless integration of AWS Trainium and Trainium2 accelerators into the JAX ecosystem. Users get more portability, more choice, and faster progress.
JAX and OpenXLA, abstraction and modularity
JAX is a Python library for high-performance, large-scale numerical computing and machine learning. Its unique compiler-oriented design makes numerical computation familiar and portable while also accelerator-friendly and scalable. It combines a NumPy-like API with composable transformations for automatic differentiation, vectorization, parallelization, and more. Under the hood, JAX leverages the XLA compiler to optimize and scale computations over a broad set of backends.
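To make "composable transformations" concrete, here is a minimal sketch that stacks automatic differentiation, vectorization, and compilation on one small function. The linear model, the `loss` function, and the sample values are illustrative assumptions, not taken from any particular JAX workload.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: squared error of a linear model for one example.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad: differentiate loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# vmap: apply grad_loss across a batch of (x, y) pairs while sharing w.
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed function with XLA.
fast_per_example_grads = jax.jit(per_example_grads)

w = jnp.array([1.0, -2.0, 0.5])
xs = jnp.ones((4, 3))   # batch of 4 inputs, 3 features each
ys = jnp.zeros(4)       # batch of 4 targets

grads = fast_per_example_grads(w, xs, ys)
print(grads.shape)  # one gradient vector per example: (4, 3)
```

Each transformation takes a function and returns a function, which is why they can be nested in any order that makes sense for the computation.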
This abstraction layer is key to its portability: JAX presents a consistent interface while XLA optimizes performance, whether you're running on CPUs, GPUs, TPUs, or something new.
In fact, OpenXLA infrastructure is designed to be modular and extensible to new platforms. A hardware vendor that develops a PJRT plugin and reuses existing XLA compiler components can make its platform a target for JAX code, even when scaling from a single device to thousands.
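As a rough illustration of what this portability looks like from the user's side, the sketch below asks JAX which devices the installed backend exposes and runs a compiled function on the first one. Nothing in it names a specific platform; the `scaled_sum` function is a hypothetical example, and the device list printed depends entirely on the machine and on which backends or plugins are installed.

```python
import jax
import jax.numpy as jnp

# Ask JAX which devices the active backend (built-in or plugin) exposes.
devices = jax.devices()
print(jax.default_backend(), devices)

# Hypothetical example function; the same code runs on any backend.
@jax.jit
def scaled_sum(x):
    return jnp.sum(x * 2.0)

# Place the input on the first available device and run the computation.
x = jax.device_put(jnp.arange(8.0), devices[0])
result = scaled_sum(x)
print(result)
```

The user code does not change between platforms; only the reported backend name and device list differ.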
Enter AWS Trainium and Inferentia
We are excited to announce that AWS Trainium is the latest platform to embrace JAX and OpenXLA. With the JAX Neuron plugin, AWS Trainium and Inferentia can be used as native JAX devices.
This new backend demonstrates how abstraction and modularity make JAX and OpenXLA especially extensible and amenable to collaboration, even on new hardware. We're thrilled to have diverse hardware partners like AMD, Arm, Intel, NVIDIA, and AWS taking advantage of JAX's portability and performance. If you're interested in bringing new platforms into the JAX and OpenXLA ecosystem, please reach out!
A multi-platform ecosystem fosters open collaboration in advancing AI infrastructure. Our goal is to drive continuous development of open standards and to accelerate progress. And if you're a machine learning developer or numerical computing user, we're excited for you to try JAX on any platform you choose.
By Matthew Johnson - Principal Scientist, with additional contributors: Aditi Joshi, Fenghui Zhang, Roy Frostig, and Carlos Araya