
A Robust Open Ecosystem for All: Accelerating AI Infrastructure


JAX now runs on AWS Trainium: Open Source Fuels AI Innovation

Open source software is the foundation of machine learning. It accelerates innovation through an ethos of flexibility and collaboration. This philosophy drives the open development of JAX, our high-performance array computing library, as well as OpenXLA, the compiler and runtime infrastructure it relies on.

Today we're excited to highlight how this commitment to openness, together with JAX and OpenXLA's modular designs, enables seamless integration of AWS Trainium and Trainium2 accelerators into the JAX ecosystem. Users get more portability, more choice, and faster progress.


JAX and OpenXLA, abstraction and modularity

JAX is a Python library for high-performance, large-scale numerical computing and machine learning. Its unique compiler-oriented design makes numerical computation familiar and portable while also being accelerator-friendly and scalable. It combines a NumPy-like API with composable transformations for automatic differentiation, vectorization, parallelization, and more. Under the hood, JAX uses the XLA compiler to optimize and scale computations across a broad set of backends.
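Here is a minimal sketch of those composable transformations. It uses the public `jax.grad`, `jax.vmap`, and `jax.jit` APIs and assumes only that `jax` is installed. The function `sum_of_squares` is our own example, chosen because its arithmetic is easy to check by hand:

```python
import jax
import jax.numpy as jnp

# A plain NumPy-style function: the sum of squares of a vector.
def sum_of_squares(x):
    return jnp.sum(x ** 2)

# grad: derivative with respect to the first argument (here, 2 * x).
grad_fn = jax.grad(sum_of_squares)

# vmap: apply sum_of_squares to each row of a batch without a Python loop.
batched = jax.vmap(sum_of_squares)

# jit: compile the gradient function with XLA for the default backend.
fast_grad = jax.jit(grad_fn)

# Transformations compose: a compiled, per-row gradient.
per_row_grads = jax.jit(jax.vmap(jax.grad(sum_of_squares)))

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])

print(sum_of_squares(x))  # 14.0
print(fast_grad(x))       # [2. 4. 6.]
print(batched(batch))     # [ 5. 25.]
print(per_row_grads(batch))
```

Each transformation returns an ordinary Python function, so the transformations nest freely, as `per_row_grads` shows. The source code stays the same whichever backend XLA compiles it for.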

This abstraction layer is key to its portability: JAX presents a consistent interface while XLA optimizes performance, whether you're running on CPUs, GPUs, TPUs, or something new.

In fact, OpenXLA infrastructure is designed to be modular and extensible to new platforms. By developing a PJRT plugin and reusing existing XLA compiler components, hardware developers can let JAX code target a new platform, even when scaling from a single device to thousands.
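From the user's side, this abstraction looks like the sketch below. The code asks JAX which backend and devices it found, then runs the same jitted function there. It uses only standard JAX calls (`jax.default_backend`, `jax.devices`, `jax.device_put`). The reported platform name depends on which PJRT backends or plugins are installed, so the first printed line will vary by machine. `normalize` is only an example function.

```python
import jax
import jax.numpy as jnp

# The platform JAX picked by default, e.g. "cpu", "gpu", "tpu",
# or the name a PJRT plugin registers.
backend = jax.default_backend()

# Every device JAX can see on that backend.
devices = jax.devices()
print(backend, devices)

@jax.jit
def normalize(v):
    # The same source for every backend; XLA compiles it per platform.
    return v / jnp.linalg.norm(v)

# Place the input explicitly on the first visible device and run there.
v = jax.device_put(jnp.array([3.0, 4.0]), devices[0])
print(normalize(v))  # [0.6 0.8]
```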


Enter AWS Trainium and Inferentia

We are excited to announce that AWS Trainium is the latest platform to embrace JAX and OpenXLA. With the JAX Neuron plugin, AWS Trainium and Inferentia can be used as native JAX devices.

This new backend demonstrates how abstraction and modularity make JAX and OpenXLA especially extensible and amenable to collaboration, even on new hardware. We're thrilled to have diverse hardware partners like AMD, Arm, Intel, Nvidia, and AWS taking advantage of JAX's portability and performance. If you're interested in bringing new platforms into the JAX and OpenXLA ecosystem, please reach out!

A multi-platform ecosystem fosters open collaboration in advancing AI infrastructure. Our goal is to drive continuous development of open standards and to accelerate progress. And if you're a machine learning developer or numerical computing user, we're excited for you to try JAX on any platform you choose.

By Matthew Johnson - Principal Scientist, with additional contributors: Aditi Joshi, Fenghui Zhang, Roy Frostig, and Carlos Araya

PJRT: Simplifying ML Hardware and Framework Integration

Infrastructure fragmentation in Machine Learning (ML) across frameworks, compilers, and runtimes makes developing new hardware and toolchains challenging. This inhibits the industry’s ability to quickly productionize ML-driven advancements. To simplify the growing complexity of ML workload execution across hardware and frameworks, we are excited to introduce PJRT and open source it as part of the recently available OpenXLA Project.

PJRT (used in conjunction with OpenXLA’s StableHLO) provides a hardware- and framework-independent interface for compilers and runtimes. It simplifies the integration of hardware with frameworks. That speeds up framework coverage for new hardware and, in turn, broadens the range of hardware that workloads can target.
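JAX's ahead-of-time API makes this hand-off visible. The minimal sketch below assumes a recent JAX release, where lowering emits StableHLO by default. `lower` produces device-independent StableHLO, and `compile` passes it to the backend compiler that JAX reaches through its PJRT client. The function `affine` is only an example.

```python
import jax
import jax.numpy as jnp

def affine(w, x, b):
    return jnp.dot(w, x) + b

w = jnp.ones((2, 3))
x = jnp.arange(3.0)
b = jnp.zeros(2)

# Step 1: the framework traces the function and lowers it to StableHLO,
# an IR that does not name any particular device.
lowered = jax.jit(affine).lower(w, x, b)
stablehlo_text = lowered.as_text()
print(stablehlo_text.splitlines()[0])

# Step 2: the backend compiler turns the IR into an executable
# for the current platform.
compiled = lowered.compile()

# Step 3: the runtime executes the compiled program on the device.
print(compiled(w, x, b))  # [3. 3.]
```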

PJRT is the primary interface for TensorFlow and JAX, is fully supported for PyTorch, and is well integrated with the OpenXLA ecosystem to execute workloads on TPU, GPU, and CPU. It is also the default runtime execution path for most of Google’s internal production workloads. The toolchain-independent architecture of PJRT allows it to be leveraged by any hardware, framework, or compiler, with extensibility for unique features. With this open-source release, we're excited to allow anyone to begin leveraging PJRT for their own devices.

If you’re developing an ML hardware accelerator or developing your own compiler and runtime, check out the PJRT source code on GitHub and sign up for the OpenXLA mailing list to quickly bootstrap your work.

Vision: Simplifying ML Hardware and Framework Integration

We are entering a world of ambient experiences where intelligent apps and devices surround us, from the edge to the cloud, in a range of environments and scales. ML workload execution currently supports a combinatorial matrix of hardware, frameworks, and workflows, mostly through tight vertical integrations. Examples of such vertical integrations include specific kernels for TPU versus GPU and specific toolchains to train and serve in TensorFlow versus PyTorch. These bespoke 1:1 integrations are perfectly valid solutions, but they promote lock-in, inhibit innovation, and are expensive to maintain. The problem of a fragmented software stack compounds over time as each new kind of computing hardware needs to be supported.
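A back-of-the-envelope count shows why this matrix gets expensive. Assume, for simplicity, that every framework must be wired to every hardware target by hand. Then the number of integrations grows as the product of the two counts. With a shared interface in the middle, each framework and each hardware target integrates once, so the count grows as their sum. The framework and hardware counts below are hypothetical.

```python
def bespoke_integrations(num_frameworks: int, num_hardware: int) -> int:
    # Simplified model: every framework is wired to every hardware target.
    return num_frameworks * num_hardware

def interface_integrations(num_frameworks: int, num_hardware: int) -> int:
    # Each framework targets the shared interface once, and each
    # hardware vendor implements it once.
    return num_frameworks + num_hardware

# Hypothetical counts, for illustration only.
frameworks, hardware = 3, 5
print(bespoke_integrations(frameworks, hardware))    # 15
print(interface_integrations(frameworks, hardware))  # 8
```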

A variety of ML hardware exists today, and hardware diversity is expected to increase in the future. ML users have options and want to exercise them seamlessly. For example, a user may want to train a large language model (LLM) on TPU in the cloud, run batch inference on GPU or even CPU, distill and quantize the model, and finally serve it on mobile processors. Our goal is to solve the challenge of making ML workloads portable across hardware by making it easy to integrate the hardware into the ML infrastructure (framework, compiler, runtime).

Portability: Seamless Execution

The workflow to enable this vision with PJRT is as follows (shown in Figure 1):

  1. The hardware-specific compiler and runtime provider implements the PJRT API, packages it as a plugin containing the compiler and runtime hooks, and registers it with the frameworks. The implementation can be opaque to the frameworks.
  2. The frameworks discover and load one or multiple PJRT plugins as dynamic libraries targeting the hardware on which to execute the workload.
  3. That’s it! Execute the workload from the framework onto the target hardware.
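The three steps can be modeled in a few lines of plain Python. Everything here is an illustrative assumption, not the PJRT API: `HypotheticalPlugin`, `register_plugin`, `discover_platforms`, `run`, and the toy compiler. Real plugins implement a C API and ship as dynamic libraries. In JAX, for example, installed plugins are discovered through the `jax_plugins` namespace package or entry-point group.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for a plugin: it only exposes a compile hook that
# returns a callable "executable". Names are illustrative, not PJRT's.
class HypotheticalPlugin:
    def __init__(self, platform: str,
                 compile_fn: Callable[[str], Callable[..., object]]):
        self.platform = platform
        self._compile_fn = compile_fn

    def compile(self, program: str) -> Callable[..., object]:
        return self._compile_fn(program)

# Step 1: providers register plugins; the framework never sees internals.
_REGISTRY: Dict[str, HypotheticalPlugin] = {}

def register_plugin(plugin: HypotheticalPlugin) -> None:
    _REGISTRY[plugin.platform] = plugin

# Step 2: the framework discovers whichever plugins are available.
def discover_platforms() -> List[str]:
    return sorted(_REGISTRY)

# Step 3: the framework compiles and runs the workload on the chosen target.
def run(program: str, platform: str, *args):
    executable = _REGISTRY[platform].compile(program)
    return executable(*args)

# A toy "compiler" that understands a single program name.
def toy_cpu_compiler(program: str):
    if program == "add":
        return lambda a, b: a + b
    raise ValueError(f"unsupported program: {program}")

register_plugin(HypotheticalPlugin("toy_cpu", toy_cpu_compiler))
print(discover_platforms())         # ['toy_cpu']
print(run("add", "toy_cpu", 2, 3))  # 5
```

The framework side (`discover_platforms`, `run`) depends only on the shared interface. That is why a new backend can be added without changing framework code.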

The PJRT API will be backward compatible. Plugins should not need to change often, and they can use version checks to detect which features are available.
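The post does not spell out the versioning rules. The sketch below assumes a common major/minor convention, loosely modeled on the major and minor version numbers that the PJRT C API exposes. Under this assumption, a plugin with the same major version and an equal or newer minor version is accepted. A framework can also gate an optional feature on the minor version in which that feature appeared. `ApiVersion`, `is_compatible`, and all version numbers are hypothetical.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ApiVersion:
    # Hypothetical version record reported by a plugin.
    major: int
    minor: int

def is_compatible(plugin: ApiVersion, required: ApiVersion) -> bool:
    # Assumed policy: a different major version is a breaking change, while a
    # newer minor version only adds features, so minor >= required suffices.
    return plugin.major == required.major and plugin.minor >= required.minor

plugin = ApiVersion(major=0, minor=40)

# Load-time check: the framework needs at least API 0.35.
print(is_compatible(plugin, ApiVersion(0, 35)))  # True
# A plugin built against a different major version is rejected.
print(is_compatible(plugin, ApiVersion(1, 0)))   # False
# Feature check: an optional call added in minor version 45 must be skipped.
print(is_compatible(plugin, ApiVersion(0, 45)))  # False
```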

Diagram of PJRT architecture
Figure 1: To target specific hardware, provide an implementation of the PJRT API to package a compiler and runtime plugin that can be called by the framework.

Cohesive Ecosystem

As a foundational pillar of the OpenXLA Project, PJRT is well integrated with other projects within OpenXLA, including StableHLO and the OpenXLA compilers (XLA, IREE). It is the primary interface for TensorFlow and JAX and is fully supported for PyTorch through PyTorch/XLA. It provides the hardware interface layer that addresses the combinatorial framework × hardware fragmentation of ML infrastructure (see Figure 2).

Diagram of PJRT hardware interface layer
Figure 2: PJRT provides the hardware interface layer that addresses the combinatorial framework × hardware fragmentation of ML infrastructure, and it is well integrated with OpenXLA.

Toolchain Independent

PJRT is hardware and framework independent. Because frameworks integrate with it through StableHLO, a self-contained IR, PJRT is not coupled to a specific compiler. It can be used outside of the OpenXLA ecosystem, including with proprietary compilers. Its public availability and toolchain-independent architecture allow it to be used by any hardware, framework, or compiler, with extensibility for unique features. Perhaps you are developing an ML hardware accelerator, a compiler, or a runtime targeting any hardware, or converging siloed toolchains to solve infrastructure fragmentation. In each case, PJRT can minimize bespoke hardware and framework integration, providing greater coverage and a faster time-to-market at lower development cost.
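JAX's `jax.export` module, available in recent JAX releases, offers one way to see that a StableHLO program is a self-contained artifact. The minimal sketch below exports a jitted function for a fixed input shape and extracts its StableHLO module as text. In principle, any StableHLO-consuming toolchain could read that text. The sketch then round-trips JAX's own serialized form. `scaled_tanh` is only an example function.

```python
import jax
import jax.numpy as jnp
from jax import export

def scaled_tanh(x):
    return 2.0 * jnp.tanh(x)

# Export the jitted function for a given input shape and dtype. The result
# wraps a StableHLO module plus JAX calling-convention metadata.
exported = export.export(jax.jit(scaled_tanh))(
    jax.ShapeDtypeStruct((4,), jnp.float32))

# The StableHLO module as MLIR text: no Python or framework objects inside.
mlir_text = exported.mlir_module()

# JAX's serialized form (StableHLO plus metadata) as bytes...
blob = exported.serialize()

# ...which can later be deserialized and run on the current backend.
restored = export.deserialize(blob)
print(restored.call(jnp.zeros(4, jnp.float32)))  # [0. 0. 0. 0.]
```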

Driving Impact with Collaboration

Industry partners, including Intel, have already adopted PJRT.

Intel

Intel is leveraging PJRT in Intel® Extension for TensorFlow to provide the Intel GPU backend for TensorFlow and JAX. This implementation is based on the PJRT plugin mechanism (see RFC). Check out this example of executing a JAX program on an Intel GPU to see how much this simplifies framework and hardware integration.

"At Intel, we share Google's vision of modular interfaces to make integration easier and enable faster, framework-independent development. Similar in design to the PluggableDevice mechanism, PJRT is a pluggable interface that allows us to easily compile and execute XLA's High Level Operations on Intel devices. Its simple design allowed us to quickly integrate it into our systems and start running JAX workloads on Intel® GPUs within just a few months. PJRT enables us to more efficiently deliver hardware acceleration and oneAPI-powered AI software optimizations to developers using a wide range of AI Frameworks." - Wei Li, VP and GM, Artificial Intelligence and Analytics, Intel.

Technology Leader

We’re also working with a technology leader to leverage PJRT to provide the backend targeting their proprietary processor for JAX. More details on this to follow soon.

Get Involved

PJRT is available on GitHub: source code for the API, a reference openxla-pjrt-plugin, and integration guides. If you develop ML frameworks, compilers, or runtimes, or are interested in improving portability of workloads across hardware, we want your feedback. We encourage you to contribute code, design ideas, and feature suggestions. We also invite you to join the OpenXLA mailing list to stay updated with the latest product and community announcements and to help shape the future of an interoperable ML infrastructure.

Acknowledgements

Allen Hutchison, Andrew Leaver, Chuanhao Zhuge, Jack Cao, Jacques Pienaar, Jieying Luo, Penporn Koanantakool, Peter Hawkins, Robert Hundt, Russell Power, Sagarika Chalasani, Skye Wanderman-Milne, Stella Laurenzo, Will Cromar, Xiao Yu.

By Aman Verma, Product Manager, Machine Learning Infrastructure