Tag Archives: Machine Learning Framework

PJRT Plugin to Accelerate Machine Learning

PJRT is an open, stable interface for device runtime and compiler, which simplifies ML hardware and framework integration. With PJRT, ML frameworks become hardware-agnostic and ML hardware becomes pluggable. For ML developers, this simplifies the adoption of new ML hardware and makes models more portable. It addresses ML infrastructure fragmentation across frameworks, compilers, and runtimes, enhancing the industry’s ability to productionize ML-driven advancements with velocity and at scale.

This article provides an overview of what building a PJRT plugin entails, how frameworks (and models) can use this plugin, and some updates on the PJRT API. PJRT is now used by a growing spectrum of hardware: Apple silicon, Google Cloud TPU, NVIDIA GPU, and Intel Max GPU. We also share a spotlight on Apple’s adoption of PJRT with some details on the workflow and performance.

If you’re developing an ML hardware accelerator or developing your own compiler and runtime, check out the PJRT source code on GitHub and sign up for the PJRT mailing list to quickly bootstrap your work.

What’s in a PJRT Plugin

PJRT was introduced to simplify the growing complexity of ML workload execution across hardware and frameworks. PJRT (used in conjunction with StableHLO) is a stable interface for device runtime and compiler that abstracts device-specific implementations away from frameworks.

An implementation of the PJRT API is called a PJRT plugin. It is usually shipped as a Python package to give ML model developers a seamless experience. To build a PJRT plugin for a hardware target, the following methods need to be implemented:

  • Compile: compile (program) -> executable
  • Runtime: execute (executable, arguments) -> results
  • Memory management: transfer buffer from host to device, device to host, device to device, as well as buffer management such as buffer donation
  • Topology information: the platform, how many accelerators there are, and how they are attached.
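
As a rough illustration of how these four responsibilities fit together, here is a toy sketch in Python. The real PJRT API is a C API; the ToyPlugin class, its method names, and the string "programs" below are all hypothetical and exist only for this example. Buffer donation is not modeled.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence


@dataclass
class DeviceBuffer:
    """Stand-in for a device-resident buffer (simulated with a list)."""
    device_id: int
    data: List[float]


@dataclass
class Executable:
    """Stand-in for a compiled program."""
    fn: Callable[[List[float]], List[float]]


class ToyPlugin:
    """Hypothetical plugin; names are illustrative, not the PJRT C API."""

    # Toy "programs" that the fake compiler understands.
    _PROGRAMS: Dict[str, Callable[[List[float]], List[float]]] = {
        "double": lambda xs: [2 * x for x in xs],
        "negate": lambda xs: [-x for x in xs],
    }

    def __init__(self, num_devices: int = 2) -> None:
        self.num_devices = num_devices

    # Topology: platform name and the attached devices.
    def topology(self) -> dict:
        return {"platform": "toy", "device_ids": list(range(self.num_devices))}

    # Compile: program -> executable.
    def compile(self, program: str) -> Executable:
        return Executable(fn=self._PROGRAMS[program])

    # Memory management: host -> device, device -> host, device -> device.
    def to_device(self, host_data: Sequence[float], device_id: int) -> DeviceBuffer:
        return DeviceBuffer(device_id, list(host_data))

    def to_host(self, buf: DeviceBuffer) -> List[float]:
        return list(buf.data)

    def copy_to_device(self, buf: DeviceBuffer, device_id: int) -> DeviceBuffer:
        return DeviceBuffer(device_id, list(buf.data))

    # Runtime: (executable, arguments) -> results.
    def execute(self, exe: Executable, arg: DeviceBuffer) -> DeviceBuffer:
        return DeviceBuffer(arg.device_id, exe.fn(arg.data))


def run_on_plugin(plugin: ToyPlugin, program: str, host_data: Sequence[float]) -> List[float]:
    """What a framework does: compile, transfer in, execute, transfer out."""
    exe = plugin.compile(program)
    device_id = plugin.topology()["device_ids"][0]
    result = plugin.execute(exe, plugin.to_device(host_data, device_id))
    return plugin.to_host(result)
```

The framework-side helper run_on_plugin only touches the plugin through these methods, which is the sense in which the framework stays hardware-agnostic.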

ML frameworks discover and load one or more PJRT plugins, then call the PJRT API to compile and execute the model. Depending on the discovery mechanism a framework uses, a PJRT plugin may need to register itself with that framework.

API Updates

Versioning and ABI Compatibility

The PJRT API has a major version and a minor version. If the framework is newer than the plugin, the framework provides an N-week (N=6 today) forward compatibility window for minor version updates. Major version updates will be coordinated updates. Frameworks will not support plugins with a lower major version. If the plugin is newer than the framework, plugins will define their own backward compatibility policy.
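
One way to read these rules is as a small decision function. The sketch below is an interpretation, not the framework's actual check: the ApiVersion type, the function name, and the weeks_behind input (how many weeks older the plugin's API version is than the framework's) are assumptions made for illustration, since the text does not say how the window is measured.

```python
from dataclasses import dataclass

FORWARD_COMPAT_WEEKS = 6  # "N=6 today" in the text above


@dataclass(frozen=True)
class ApiVersion:
    major: int
    minor: int


def framework_accepts_plugin(framework: ApiVersion, plugin: ApiVersion,
                             weeks_behind: int) -> str:
    """Return 'accepted', 'rejected', or 'plugin-policy' (hypothetical labels)."""
    if plugin.major < framework.major:
        # Frameworks will not support plugins with a lower major version.
        return "rejected"
    if (plugin.major, plugin.minor) > (framework.major, framework.minor):
        # Plugin is newer: its own backward compatibility policy decides.
        return "plugin-policy"
    # Same major version, plugin minor version is the same or older:
    # accepted only inside the N-week window (assumed measurement).
    return "accepted" if weeks_behind <= FORWARD_COMPAT_WEEKS else "rejected"
```

For example, under these assumptions a plugin two minor versions and four weeks behind the framework would be accepted, while the same plugin eight weeks behind would not.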


Multi-node

A PJRT client is per node, and the plugin may need some way to communicate among nodes in a distributed workload. The framework can pass in key-value store callbacks to the plugin. The plugin can use them to bootstrap multi-node initialization and other coordination needs. An example with the NVIDIA GPU CUDA plugin is as follows:

  • JAX starts a distribution service and provides key-value store callbacks.
  • The NVIDIA GPU CUDA plugin uses these callbacks to (1) generate a global PJRT device topology that includes PJRT device information from all nodes, and (2) generate NCCL IDs.
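
The two steps above can be sketched with an in-memory dictionary standing in for the framework's distribution service. This is a simplified, single-process model: the callback names kv_put and kv_get, the key layout, and the helper functions are hypothetical, and a real service would block in kv_get until the key has been published.

```python
from typing import Callable, Dict, List, Tuple


def make_kv_callbacks(store: Dict[str, str]) -> Tuple[Callable, Callable]:
    """Framework side: build (put, get) callbacks over a shared store."""
    def kv_put(key: str, value: str) -> None:
        store[key] = value

    def kv_get(key: str) -> str:
        return store[key]  # a real service would wait for the key

    return kv_put, kv_get


def publish_local_topology(node_id: int, local_devices: List[str], kv_put) -> None:
    """Plugin side: each node publishes the devices it can see."""
    kv_put(f"topology/{node_id}", ",".join(local_devices))


def build_global_topology(num_nodes: int, kv_get) -> List[str]:
    """Plugin side: every node reads all entries to get the same global view."""
    devices = []
    for node_id in range(num_nodes):
        for name in kv_get(f"topology/{node_id}").split(","):
            devices.append(f"node{node_id}/{name}")
    return devices


def share_communicator_id(node_id: int, kv_put, kv_get, make_id: Callable[[], str]) -> str:
    """Node 0 creates an id (standing in for an NCCL id); all nodes read it."""
    if node_id == 0:
        kv_put("communicator_id", make_id())
    return kv_get("communicator_id")


# Simulate two nodes sharing one store.
store: Dict[str, str] = {}
kv_put, kv_get = make_kv_callbacks(store)
publish_local_topology(0, ["gpu0", "gpu1"], kv_put)
publish_local_topology(1, ["gpu0"], kv_put)
global_topology = build_global_topology(2, kv_get)
id_on_node0 = share_communicator_id(0, kv_put, kv_get, lambda: "id-123")
id_on_node1 = share_communicator_id(1, kv_put, kv_get, lambda: "unused")
```

The plugin needs nothing from the framework beyond the two callbacks, so it does not depend on how the framework implements its distribution service.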


DLPack

A few C APIs were added to PJRT to support DLPack.

  • PJRT_Client_CreateViewOfDeviceBuffer supports receiving buffers from DLPack.
  • Exporting buffers to DLPack uses PJRT_Buffer_OpaqueDeviceMemoryDataPointer to obtain the device memory pointer, together with PJRT_Buffer_IncreaseExternalReferenceCount and PJRT_Buffer_DecreaseExternalReferenceCount to track that the memory is shared externally.
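
The export path can be pictured as reference counting around a raw pointer. The toy model below is not the PJRT C API: ToyBuffer and its methods are hypothetical Python stand-ins, and a memoryview plays the role of the opaque device memory pointer. It only illustrates the idea that the runtime should not free memory while an external consumer still holds it.

```python
class ToyBuffer:
    """Hypothetical stand-in for a runtime-owned device buffer."""

    def __init__(self, data: bytearray) -> None:
        self._data = data
        self._external_refs = 0
        self.deleted = False

    def increase_external_reference_count(self) -> None:
        self._external_refs += 1

    def decrease_external_reference_count(self) -> None:
        if self._external_refs == 0:
            raise RuntimeError("no external reference to release")
        self._external_refs -= 1

    def opaque_device_memory_pointer(self) -> memoryview:
        # Stand-in for a raw device pointer.
        return memoryview(self._data)

    def try_delete(self) -> bool:
        """The runtime may free the buffer only when nothing external holds it."""
        if self._external_refs > 0:
            return False
        self.deleted = True
        return True


def export_buffer(buf: ToyBuffer) -> memoryview:
    """Pin the buffer, then hand its memory to an external consumer."""
    buf.increase_external_reference_count()
    return buf.opaque_device_memory_pointer()


def release_buffer(buf: ToyBuffer) -> None:
    """Called when the external consumer is done with the memory."""
    buf.decrease_external_reference_count()
```

In this model a consumer calls export_buffer, uses the memory, and then calls release_buffer; try_delete returns False in between.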


Extensions

The PJRT API provides an extension mechanism through which a plugin can offer optional or experimental features. These extensions can have their own compatibility guarantees and do not need to follow the ABI compatibility rules of the PJRT API.
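
A minimal sketch of how a framework might probe for an optional feature, assuming the plugin exposes its extensions as a chain of typed records. The Extension type, the kind strings, and find_extension are hypothetical names for illustration only.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Extension:
    """Hypothetical extension record: a type tag, a payload, and a link."""
    kind: str
    payload: dict = field(default_factory=dict)
    next: Optional["Extension"] = None


def find_extension(start: Optional[Extension], kind: str) -> Optional[Extension]:
    """Walk the chain and return the first extension of the given kind."""
    node = start
    while node is not None:
        if node.kind == kind:
            return node
        node = node.next
    return None


# A plugin that offers two optional features.
profiler = Extension(kind="profiler", payload={"version": 1})
chain = Extension(kind="custom_partitioner", next=profiler)

# The framework uses a feature only if the plugin provides it.
has_profiler = find_extension(chain, "profiler") is not None
has_ffi = find_extension(chain, "ffi") is not None
```

Because a missing extension simply yields None, the framework can fall back to default behavior, which is why extensions need not be covered by the core API's compatibility rules.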

Industry Adoption

PJRT is the only interface for JAX, the primary interface for TensorFlow, and fully supported for PyTorch through PyTorch/XLA. PJRT is not tied to a specific compiler and runtime. Its toolchain-independent architecture and open-source availability as part of the OpenXLA Project allow it to be leveraged by any hardware, framework, or compiler, with extensibility for unique features. This has allowed PJRT to be adopted by various industry partners through close collaboration. A brief account of Apple’s adoption of PJRT follows.

JAX on Apple Silicon

Apple’s PJRT plugin for the Metal training backend accelerates JAX models on Apple silicon and AMD GPUs. This empowers ML developers to leverage the full potential of Apple silicon and AMD GPUs on their Apple hardware to accelerate JAX models for faster experimentation. The integration and user experience for accelerating JAX on Apple silicon GPUs are similar to the existing PyTorch and TensorFlow implementations.

The Metal plug-in uses the OpenXLA compiler and PJRT runtime to optimize and accelerate JAX workloads on GPU. When a JAX program is executed, the JAX graph is lowered into StableHLO, which is then passed to PJRT for compilation and execution. The StableHLO is converted to MPSGraph executables and the Metal runtime APIs are invoked to dispatch to the GPU.


The Metal backend with the PJRT plugin provides an impressive performance speedup for JAX. On an Apple MacBook Pro with M2 Max, training common networks in JAX sees performance speedups of up to 28x, with an average of 10x, over a CPU baseline. This empowers any ML developer to leverage the full potential of Apple silicon on their Apple hardware to accelerate JAX models for faster experimentation.

Figure 1: Performance speedups of up to 28x on Apple MacBook Pro with M2 Max over CPU for JAX training.

Getting Started

Adding Metal support to JAX is as simple as a single pip install:

python -m pip install jax-metal
python -c 'import jax; print(jax.numpy.arange(10))'

For more details on environment setup and installation of JAX on Apple hardware, please refer to the Metal Developer Resources page.

Google Cloud TPU

PJRT is the default runtime for PyTorch 2.0 on Google Cloud TPU. The GitHub README has more details.


NVIDIA GPU

The NVIDIA GPU CUDA implementation in JAX has been extracted and packaged as a PJRT plugin. ML model developers can install the NVIDIA GPU CUDA plugin from PyPI. This plugin uses the newly added features such as multi-node, DLPack, and extensions.

Intel GPU

Intel is leveraging PJRT in Intel® Extension for TensorFlow to provide the Intel GPU backend for TensorFlow, JAX and PyTorch. The example of executing a JAX program on Intel GPU demonstrates how this greatly simplifies the framework and hardware integration.

PJRT Resources

PJRT is available on GitHub: source code for the API, integration guides and issues. If you develop ML frameworks, compilers, runtimes or are interested in improving portability of workloads across hardware, we want your feedback. We encourage you to contribute code, design ideas and feature suggestions. We also invite you to join the PJRT mailing list to stay updated with the latest product and community announcements and to help shape the future of an interoperable ML infrastructure.


Acknowledgements

Chalana Bezawada, Daniel Doctor, Kulin Seth, Shuhan Ding from Apple
Penporn Koanantakool, Peter Hawkins, Skye Wanderman-Milne, Xiao Yu from Google.

By Aman Verma – Product Manager, Machine Learning Infrastructure, Google and Jieying Luo – Software Engineer, Machine Learning Infrastructure, Google

Kubeflow joins the CNCF family

We are thrilled to announce a major milestone in the journey of the Kubeflow project. After a comprehensive review process and several months of meticulous preparation, Kubeflow has been accepted by the Cloud Native Computing Foundation (CNCF) as an incubating project. This momentous step marks a new chapter in our collaborative and open approach to accelerating machine learning (ML) in the cloud native ecosystem.

The acceptance of Kubeflow into the incubation stage by the CNCF reflects not just the project's maturity, but also its widespread adoption and expanding user base. It underscores the tremendous value of the diverse suite of components that Kubeflow provides, including Notebooks, Pipelines, Training Operators, Katib, Central Dashboard, Manifests, and many more. These tools have been instrumental in creating a cohesive, end-to-end ML platform that streamlines the development and deployment of ML workflows.

Furthermore, the alignment of Kubeflow with the CNCF acknowledges the project's foundational reliance on several existing CNCF projects such as Argo, Cert-Manager, and Istio. Kubeflow joining the CNCF will strengthen these existing relationships and foster greater collaboration among cloud native projects, leading to even more robust and innovative solutions for users.

Looking ahead, Google and the Kubeflow community are eager to collaborate with the CNCF on the transition process. Rest assured, our commitment to Kubeflow's ongoing development remains unwavering during this transition. We will continue to support new feature development, plan and execute upcoming releases, and strive to deliver further improvements to the Kubeflow project.

We extend our heartfelt thanks to the CNCF Technical Oversight Committee and the wider CNCF community for their support and recognition of the Kubeflow project. We look forward to this exciting new phase in our shared journey towards advancing machine learning in the cloud native landscape.

As Kubeflow continues to evolve, we invite developers, data scientists, ML engineers, and all other interested individuals to join us in shaping the future of cloud native machine learning. Let's innovate together, with Kubeflow and the CNCF, to make machine learning workflows more accessible, manageable, and scalable than ever before!

By James Liu – GCP Cloud AI

Kubeflow applies to become a CNCF incubating project

Google has pioneered AI and ML and has a history of innovative technology donations to the open source community (e.g. TensorFlow and JAX). Google is also the initial developer of and largest contributor to Kubernetes, and brings a wealth of experience to the project and its community. Building an ML Platform on our state-of-the-art Google Kubernetes Engine (GKE), we have learned best practices from our users, and in 2017, we used that experience to create and open source the Kubeflow project.

In May 2020, with the v1.0 release, Kubeflow reached maturity across a core set of its stable applications. During that year, we also graduated Kubeflow Serving as an independent project, KServe, which is now incubating in Linux Foundation AI & Data.

Today, Kubeflow has developed into an end-to-end, extendable ML platform, with multiple distinct components to address specific stages of the ML lifecycle: model development (Kubeflow Notebooks), model training (Kubeflow Pipelines and Kubeflow Training Operator), model serving (KServe), and automated machine learning (Katib).

The Kubeflow project now has close to 200 contributors from over 30 organizations, and the Kubeflow community has hosted several summits and contributor meetups across the world. The broader Kubeflow ecosystem includes a number of distributions across multiple cloud service providers and on-prem environments. Kubeflow’s powerful development experience helps data scientists build, train, and deploy their ML models, enabling enterprise ML operation teams to deploy and scale advanced workflows in a variety of infrastructures.

Google’s application for Kubeflow to become a CNCF incubating project is the next big milestone for the Kubeflow community, and we’re thrilled to see how developers will continue to build and innovate in ML using this project.

What's next? The pull request we’ve opened today to join the CNCF as an incubating project is only the first step. Google and the Kubeflow community will work with the CNCF and its Technical Oversight Committee (TOC) to meet the incubation stage requirements. While the due diligence and eventual TOC decision can take a few months, the Kubeflow project will continue developing and releasing throughout this process.

If Kubeflow is accepted into CNCF, the project’s assets will be transferred to the CNCF, including the source code, trademark, and website, and other collaboration and social media accounts. At Google, we believe that using open source comes with a responsibility to contribute, maintain, and improve those projects. In that spirit, we will continue supporting the Kubeflow project and work with the community towards the next level of innovation.

Thanks to everyone who has contributed to Kubeflow over the years! We are excited for what lies ahead for the Kubeflow community.

By Thea Lamkin, Senior Program Manager and Mark Chmarny, Senior Technical Program Manager – Google Open Source

Accelerate your models to production with Google Cloud and PyTorch

We believe in the power of choice for Machine Learning development, and continue to invest resources to make it easy for ML practitioners to train, deploy, and orchestrate models from a single unified data and AI cloud platform. We’re excited to announce our role as a founding member of the newly formed PyTorch Foundation, which will better position Google Cloud to make meaningful contributions to the PyTorch community. As a member of the board, we will deepen our open source investment to deliver on the Foundation’s mission to drive adoption of AI tooling by building an ecosystem of open source projects with PyTorch. We strongly believe in choice and will continue to invest in frameworks such as JAX and TensorFlow and support integrations with other OSS Projects including Spark, Airflow, XGBoost, and others.

In this blog, we provide an overview of existing resources to help you get started with PyTorch on Google Cloud. We also talk about how ML practitioners can leverage our end-to-end ML platform to train, tune, and deploy PyTorch models.

PyTorch on Google Cloud

Open source in the cloud is important because it gives you flexibility and control over where you train and deploy your ML workloads. PyTorch is extensively used in the research space and in recent years it has gained immense traction in the industry due to its ease of use and deployment. In fact, according to a survey of Kaggle users, PyTorch is the fastest growing ML framework today.

ML practitioners using PyTorch tell us that it can be challenging to advance their ML project past experimentation. This is why Google Cloud has built integrations with PyTorch that make it easier to train, deploy, and orchestrate models in production. Some examples are:

  • PyTorch integrates directly with Vertex AI, a fully managed ML platform that provides the tools you need to take a model from PyTorch to production, like the PyTorch DL containers or the Vertex AI Workbench PyTorch one-click JupyterLab environment.
  • PyTorch/XLA, an open source library, uses the XLA deep learning compiler to enable PyTorch to run on Cloud TPUs. Cloud TPUs are custom accelerators designed by Google, optimized for perf/TCO on large-scale ML workloads. PyTorch/XLA also enables XLA-driven optimizations on GPUs.
  • TorchX provides an adapter to run and orchestrate TorchX components as part of Kubeflow Pipelines that you can easily scale on Vertex AI Pipelines.
  • With our OSS contributions to Apache Beam, we have made PyTorch models easy to deploy in batch or streaming data processing pipelines. Running on Google Dataflow, these pipelines will scale to very large workloads in a fully managed and simple-to-maintain environment.

To learn more and start using PyTorch on Google Cloud, check out the resources below:

PyTorch on Vertex AI Resources

  1. How to train and tune PyTorch models on Vertex AI: Learn how to use Vertex AI Training to build and train a sentiment text classification model using PyTorch and Vertex AI Hyperparameter Tuning to tune hyperparameters of PyTorch models.
  2. How to deploy PyTorch models on Vertex AI: Walk through the deployment of a PyTorch model using TorchServe as a custom container, by deploying the model artifacts to a Vertex Prediction service.
  3. Orchestrating PyTorch ML Workflows on Vertex AI Pipelines: See how to build and orchestrate ML pipelines for training and deploying PyTorch models on Google Cloud Vertex AI using Vertex AI Pipelines.
  4. Scalable ML Workflows using PyTorch on Kubeflow Pipelines and Vertex Pipelines: Take a look at examples of PyTorch-based ML workflows on two pipelines frameworks: OSS Kubeflow Pipelines, part of the Kubeflow project, and Vertex AI Pipelines. We share new PyTorch built-in components added to the Kubeflow Pipelines.

PyTorch/XLA and Cloud TPU/GPU

  1. Scaling deep learning workloads with PyTorch / XLA and Cloud TPU VM: Describes the challenges of scaling deep learning jobs to distributed training settings using the Cloud TPU VM, and shows how to stream training data from Google Cloud Storage (GCS) to PyTorch / XLA models running on Cloud TPU Pod slices.
  2. PyTorch/XLA: Performance debugging on Cloud TPU VM: Part I: In the first part of the performance debugging series on Cloud TPU, we lay out the conceptual framework for PyTorch/XLA in the context of training performance. We introduce a case study to make sense of preliminary profiler logs and identify corrective actions.
  3. PyTorch/XLA: Performance debugging on Cloud TPU VM: Part II: In the second part, we deep dive into further analysis of the performance debugging to discover more performance improvement opportunities.
  4. PyTorch/XLA: Performance debugging on Cloud TPU VM: Part III: In the final part of the performance debugging series, we introduce user defined code annotation and visualize these annotations in the form of a trace.
  5. Train ML models with PyTorch Lightning on TPUs: Learn how easy it is to start training models with PyTorch Lightning on TPUs with its built-in TPU support.

Other resources

  1. Increase your productivity using PyTorch Lightning: Learn how to use PyTorch Lightning on Vertex AI Workbench (previously Notebooks).

By Erwin Huizing and Grace Reed – Cloud AI and ML