Tag Archives: intermediate

Get ready for Google I/O: Program lineup revealed

Posted by Timothy Jordan – Director, Developer Relations and Open Source

Developers, get ready! Google I/O is just around the corner, kicking off live from Mountain View with the Google keynote on Tuesday, May 14 at 10 am PT, followed by the Developer keynote at 1:30 pm PT.

But the learning doesn’t stop there. Mark your calendars for May 16 at 8 am PT when we’ll be releasing over 150 technical deep dives, demos, codelabs, and more on-demand. If you register online, you can start building your 'My I/O' agenda today.

Here's a sneak peek at some of the exciting highlights from the I/O program preview:

Unlocking the power of AI: The Gemini era unlocks a new frontier for developers. We'll showcase the newest features in the Gemini API, Google AI Studio, and Gemma. Discover cutting-edge pre-trained models from Kaggle, and delve into Google's open-source libraries like Keras and JAX.

Android: A developer's playground: Get the latest updates on everything Android! We'll cover groundbreaking advancements in generative AI, the highly anticipated Android 15, innovative form factors, and the latest tools and libraries in the Jetpack and Compose ecosystem. Plus, discover how to optimize performance and streamline your development workflow.

Building beautiful and functional web experiences: We’ll cover Baseline updates, a revolutionary tool that empowers developers with a clear understanding of web features and API interoperability. With Baseline, you'll have access to real-time information on popular developer resource sites like MDN, Can I Use, and web.dev.

The future of ChromeOS: Get a glimpse into the exciting future of ChromeOS. We'll discuss the developer-centric investments we're making in distribution, app capabilities, and operating system integrations. Discover how our partners are shaping the future of Chromebooks and delivering world-class user experiences.

This is just a taste of what's in store at Google I/O. Stay tuned for more updates, and get ready to be a part of the future.

Don't forget to mark your calendars and register for Google I/O today!

PJRT Plugin to Accelerate Machine Learning

PJRT is an open, stable interface for device runtime and compiler, which simplifies ML hardware and framework integration. With PJRT, ML frameworks become hardware-agnostic and ML hardware becomes pluggable. For the ML developer, it simplifies the adoption of new ML hardware and models become more portable. This addresses ML infrastructure fragmentation across frameworks, compilers and runtimes enhancing the industry’s ability to productionize ML-driven advancements with velocity and at scale.

This article provides an overview of what building a PJRT plugin entails, how frameworks (and models) can use this plugin, and some updates on the PJRT API. PJRT is now used by a growing spectrum of hardware: Apple silicon, Google Cloud TPU, NVIDIA GPU, and Intel Max GPU. We also share a spotlight on Apple’s adoption of PJRT with some details on the workflow and performance.

If you’re developing an ML hardware accelerator or developing your own compiler and runtime, check out the PJRT source code on GitHub and sign up for the PJRT mailing list to quickly bootstrap your work.

What’s in a PJRT Plugin

PJRT was introduced to simplify the growing complexity of ML workload execution across hardware and frameworks. PJRT (used in conjunction with StableHLO) is a stable interface for device runtime and compiler, which abstracts away device specific implementations from frameworks.

An implementation of the PJRT API is called a PJRT plugin, which is usually a Python package for seamless ML model developer experience. To build a PJRT plugin for a hardware target, the following methods need to be implemented:

  • Compile: compile (program) -> executable
  • Runtime: execute (executable, arguments) -> results
  • Memory management: transfer buffer from host to device, device to host, device to device, as well as buffer management such as buffer donation
  • Topology information such as the platform, how many accelerators and how are they attached.

ML frameworks will discover and load one or multiple PJRT plugins, and call the PJRT API to compile and execute the model. The PJRT plugins may be required to register to the ML frameworks depending on the specific discovery mechanism the framework uses.

API Updates

Versioning and ABI Compatibility

PJRT API has a major version and a minor version. If the framework is newer than the plugin, the framework provides a N-week (N=6 today) forwards compatibility window for minor version updates. The major version updates will be a coordinated update. Frameworks will not support plugins with a lower major version. If the plugin is newer than the framework, plugins will define their own backward compatibility policy.

Multi-Node

A PJRT client is per node, and the plugin may need some way to communicate among nodes in a distributed workload. The framework can pass in key-value store callbacks to the plugin. The plugin can use them to bootstrap multi-node initialization and other coordination needs. An example with the NVIDIA GPU CUDA plugin is as follows:

  • JAX starts a distribution service and provides key-value store callbacks.
  • NVIDIA GPU CUDA plugin uses these callbacks to (1) generate global PJRT device topology that includes PJRT device information from all nodes, and (2) generate NCCL ids.

DLPack

A few C APIs were added to PJRT to support DLPack.

  • PJRT_Client_CreateViewOfDeviceBuffer supports receiving buffers from DLPack.
  • Exporting buffers to DLPack requires: PJRT_Buffer_IncreaseExternalReferenceCount, PJRT_Buffer_DecreaseExternalReferenceCount to get a PJRT_Buffer_OpaqueDeviceMemoryDataPointer.

Extension

PJRT API provides an extension mechanism that the plugin can provide extensions which are optional or experimental features. These extensions can have their own compatibility guarantee and do not need to support the ABI compatibility of PJRT API.

Industry Adoption

PJRT is the only interface for JAX, the primary interface for TensorFlow and fully supported for PyTorch through PyTorch/XLA. PJRT is not tied to a specific compiler and runtime. The toolchain-independent architecture and open-source availability as part of the OpenXLA Project allows it to be leveraged by any hardware, framework or compiler, with extensibility for unique features. This has allowed PJRT to be adopted by various industry partners through close collaboration. A brief account of Apple’s adoption of PJRT follows.

JAX on Apple Silicon

Apple’s PJRT plugin for the Metal training backend accelerates JAX models on Apple silicon and AMD GPUs. This empowers any ML developers to leverage the full potential of Apple silicon and AMD GPUs on their Apple hardware to accelerate JAX models for faster experimentation. The integration and user experience to accelerate JAX on Apple silicon GPUs is similar to the existing PyTorch and TensorFlow implementations.

The Metal plug-in uses the OpenXLA compiler and PJRT runtime to optimize and accelerate JAX workloads on GPU. When a JAX program is executed, the JAX graph is lowered into StableHLO, which is then passed to PJRT for compilation and execution. The StableHLO is converted to MPSGraph executables and the Metal runtime APIs are invoked to dispatch to the GPU.

Performance

The Metal backend with PJRT plugin provides impressive performance speedup for JAX. On an Apple MacBook Pro with M2 Max, training common networks in JAX see performance speedups of up to 28x, with an average of 10x over a CPU baseline. This empowers any ML developer to leverage the full potential of Apple Silicon on their Apple hardware to accelerate JAX models for faster experimentation.

graph of performance speedups of up to 28x on Apple MacBook Pro with M2 Max over CPU for JAX training.
Figure 1: Performance speedups of up to 28x on Apple MacBook Pro with M2 Max over CPU for JAX training.

Getting Started

Adding Metal support to JAX is as simple as a single pip install:

python -m pip install jax-metal
python -c 'import jax; print(jax.numpy.arange(10))'

For more details on environment setup and installation of JAX on Apple hardware, please refer to the Metal Developer Resources page.

Google Cloud TPU

PJRT is the default runtime for PyTorch 2.0 on Google Cloud TPU. GitHub Readme has more details.

NVIDIA GPU

The NVIDIA GPU CUDA implementation in JAX is extracted and packaged as a PJRT plugin. The ML model developers can install the NVIDIA GPU CUDA plugin from pypi. This plugin uses the newly added features such as multi-node, DLPack, and extensions.

Intel GPU

Intel is leveraging PJRT in Intel® Extension for TensorFlow to provide the Intel GPU backend for TensorFlow, JAX and PyTorch. The example of executing a JAX program on Intel GPU demonstrates how this greatly simplifies the framework and hardware integration.

PJRT Resources

PJRT is available on GitHub: source code for the API, integration guides and issues. If you develop ML frameworks, compilers, runtimes or are interested in improving portability of workloads across hardware, we want your feedback. We encourage you to contribute code, design ideas and feature suggestions. We also invite you to join the PJRT mailing list to stay updated with the latest product and community announcements and to help shape the future of an interoperable ML infrastructure.

Acknowledgements

Chalana Bezawada, Daniel Doctor, Kulin Seth, Shuhan Ding from Apple 
Penporn Koanantakool, Peter Hawkins, Skye Wanderman-Milne, Xiao Yu from Google.

By Aman Verma – Product Manager, Machine Learning Infrastructure, Google and Jieying Luo – Software Engineer, Machine Learning Infrastructure, Google

Search Off The Record podcast – behind the scenes with Search Relations

Search Off The Record podcast – behind the scenes with Search Relations