We believe in the power of choice for machine learning development, and we continue to invest resources to make it easy for ML practitioners to train, deploy, and orchestrate models from a single unified data and AI cloud platform. We’re excited to announce our role as a founding member of the newly formed PyTorch Foundation, which will better position Google Cloud to make meaningful contributions to the PyTorch community. As a member of the board, we will deepen our open source investment to deliver on the Foundation’s mission to drive adoption of AI tooling by building an ecosystem of open source projects with PyTorch. We will also continue to invest in frameworks such as JAX and TensorFlow and to support integrations with other OSS projects, including Spark, Airflow, and XGBoost.
In this blog, we provide an overview of existing resources to help you get started with PyTorch on Google Cloud. We also talk about how ML practitioners can leverage our end-to-end ML platform to train, tune, and deploy PyTorch models.
PyTorch on Google Cloud
Open source in the cloud is important because it gives you flexibility and control over where you train and deploy your ML workloads. PyTorch is extensively used in the research space and in recent years it has gained immense traction in the industry due to its ease of use and deployment. In fact, according to a survey of Kaggle users, PyTorch is the fastest growing ML framework today.
ML practitioners using PyTorch tell us that it can be challenging to advance their ML project past experimentation. This is why Google Cloud has built integrations with PyTorch that make it easier to train, deploy, and orchestrate models in production. Some examples are:
PyTorch integrates directly with Vertex AI, a fully managed ML platform that provides the tools you need to take a model from PyTorch to production, like the PyTorch DL containers or the Vertex AI Workbench PyTorch one-click JupyterLab environment.
PyTorch/XLA, an open source library, uses the XLA deep learning compiler to enable PyTorch to run on Cloud TPUs. Cloud TPUs are custom accelerators designed by Google, optimized for perf/TCO with large-scale ML workloads. PyTorch/XLA also enables XLA-driven optimizations on GPUs.
TorchX provides an adapter to run and orchestrate TorchX components as part of Kubeflow Pipelines that you can easily scale on Vertex AI Pipelines.
With our OSS contributions to Apache Beam, we have made PyTorch models easy to deploy in batch or streaming data processing pipelines. Running on Google Dataflow, these pipelines scale to very large workloads in a fully managed, simple-to-maintain environment.
To learn more and start using PyTorch on Google Cloud, check out the resources below:
PyTorch/XLA: Performance debugging on Cloud TPU VM: Part I: In the first part of the performance debugging series on Cloud TPU, we lay out the conceptual framework for PyTorch/XLA in the context of training performance. We introduce a case study to make sense of preliminary profiler logs and identify the corrective actions.
Posted by Zhengzhong Tu and Yinxiao Li, Software Engineers, Google Research
Convolutional neural networks have been the dominant machine learning architecture for computer vision since the introduction of AlexNet in 2012. Recently, inspired by the evolution of Transformers in natural language processing, attention mechanisms have been prominently incorporated into vision models. These attention methods boost some parts of the input data while minimizing other parts so that the network can focus on small but important parts of the data. The Vision Transformer (ViT) has created a new landscape of model designs for computer vision that is completely free of convolution. ViT regards image patches as a sequence of words, and applies a Transformer encoder on top. When trained on sufficiently large datasets, ViT demonstrates compelling performance on image recognition.
While convolutions and attention are both sufficient for good performance, neither of them is necessary. For example, MLP-Mixer adopts a simple multi-layer perceptron (MLP) to mix image patches across all the spatial locations, resulting in an all-MLP architecture. It is a competitive alternative to existing state-of-the-art vision models in terms of the trade-off between accuracy and computation required for training and inference. However, both ViT and the MLP models struggle to scale to higher input resolution because the computational complexity increases quadratically with respect to the image size.
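To make the quadratic growth concrete, here is a minimal counting sketch. It assumes a square patch size of 16 pixels (an illustrative choice, not a value stated in this post) and counts only query-key pairs in one full self-attention layer, ignoring heads and channel widths. The function name is hypothetical.

```python
def full_attention_pairs(height: int, width: int, patch: int = 16) -> int:
    """Count query-key pairs when every image patch attends to every patch.

    Assumption: the image is split into non-overlapping patch x patch tiles,
    and height and width are multiples of `patch`.
    """
    tokens = (height // patch) * (width // patch)
    return tokens * tokens  # every token attends to every token


if __name__ == "__main__":
    for side in (224, 448, 896):
        print(side, full_attention_pairs(side, side))
```

Under these assumptions, doubling each side of the image multiplies the number of patches by 4 and the number of attended pairs by 16, which is why full attention becomes expensive at high resolution.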
Today we present a new multi-axis approach that is simple and effective, improves on the original ViT and MLP models, can better adapt to high-resolution, dense prediction tasks, and can naturally adapt to different input sizes with high flexibility and low complexity. Based on this approach, we have built two backbone models for high-level and low-level vision tasks. We describe the first in “MaxViT: Multi-Axis Vision Transformer”, to be presented in ECCV 2022, and show it significantly improves the state of the art for high-level tasks, such as image classification, object detection, segmentation, quality assessment, and generation. The second, presented in “MAXIM: Multi-Axis MLP for Image Processing” at CVPR 2022, is based on a UNet-like architecture and achieves competitive performance on low-level imaging tasks including denoising, deblurring, dehazing, deraining, and low-light enhancement. To facilitate further research on efficient Transformer and MLP models, we have open-sourced the code and models for both MaxViT and MAXIM.
A demo of image deblurring using MAXIM frame by frame.
Overview
Our new approach is based on multi-axis attention, which decomposes the full-size attention (each pixel attends to all the pixels) used in ViT into two sparse forms: local and (sparse) global. As shown in the figure below, the multi-axis attention contains a sequential stack of block attention and grid attention. The block attention works within non-overlapping windows (small patches in intermediate feature maps) to capture local patterns, while the grid attention works on a sparsely sampled uniform grid for long-range (global) interactions. The window sizes of grid and block attentions can be fully controlled as hyperparameters to ensure computational complexity that is linear in the input size.
The proposed multi-axis attention conducts blocked local and dilated global attention sequentially, followed by an FFN, with only linear complexity. Pixels of the same color are attended together.
Such low-complexity attention is applicable to a much wider range of vision tasks, especially high-resolution visual predictions, and demonstrates greater generality than the original attention used in ViT. We build two backbone instantiations out of this multi-axis attention approach, MaxViT and MAXIM, for high-level and low-level tasks, respectively.
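The block and grid partitions described above can be sketched by listing which pixel coordinates attend together. This is a minimal illustration in plain Python, not the released implementation: the function names are hypothetical, and it assumes the feature map height and width are divisible by the window size and the grid size.

```python
from collections import defaultdict


def block_groups(height, width, window):
    """Local axis: pixels in the same non-overlapping window attend together."""
    groups = defaultdict(list)
    for i in range(height):
        for j in range(width):
            groups[(i // window, j // window)].append((i, j))
    return dict(groups)


def grid_groups(height, width, grid):
    """Global axis: pixels on the same sparse, uniformly strided grid attend together.

    Each group holds grid x grid pixels spaced (height // grid, width // grid) apart.
    """
    stride_h, stride_w = height // grid, width // grid
    groups = defaultdict(list)
    for i in range(height):
        for j in range(width):
            groups[(i % stride_h, j % stride_w)].append((i, j))
    return dict(groups)


def attended_pairs(groups):
    """Total query-key pairs when attention is restricted to each group."""
    return sum(len(members) ** 2 for members in groups.values())


if __name__ == "__main__":
    for side in (8, 16, 32):
        local = attended_pairs(block_groups(side, side, window=4))
        sparse_global = attended_pairs(grid_groups(side, side, grid=4))
        full = (side * side) ** 2
        print(side, local, sparse_global, full)
```

With the window and grid sizes held fixed, each group always contains the same number of pixels, so the pair count for either axis grows in proportion to the number of pixels, while the full-attention count grows with its square.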
MaxViT
In MaxViT, we first build a single MaxViT block (shown below) by concatenating MBConv (proposed by EfficientNet, V2) with the multi-axis attention. This single block can encode local and global visual information regardless of input resolution. We then simply stack repeated blocks composed of attention and convolutions in a hierarchical architecture (similar to ResNet, CoAtNet), yielding our homogeneous MaxViT architecture. Notably, MaxViT is distinguished from previous hierarchical approaches as it can “see” globally throughout the entire network, even in earlier, high-resolution stages, demonstrating stronger model capacity on various tasks.
The meta-architecture of MaxViT.
MAXIM
Our second backbone, MAXIM, is a generic UNet-like architecture tailored for low-level image-to-image prediction tasks. MAXIM explores parallel designs of the local and global approaches using the gated multi-layer perceptron (gMLP) network (patch-mixing MLP with a gating mechanism). Another contribution of MAXIM is the cross-gating block that can be used to apply interactions between two different input signals. This block can serve as an efficient alternative to the cross-attention module, as it employs only the cheap gated MLP operators to interact with various inputs without relying on the computationally heavy cross-attention. Moreover, all the proposed components in MAXIM, including the gated MLP and cross-gating blocks, have complexity that is linear in image size, making the model even more efficient when processing high-resolution pictures.
Results
We demonstrate the effectiveness of MaxViT on a broad range of vision tasks. On image classification, MaxViT achieves state-of-the-art results under various settings: with only ImageNet-1K training, MaxViT attains 86.5% top-1 accuracy; with ImageNet-21K (14M images, 21k classes) pre-training, MaxViT achieves 88.7% top-1 accuracy; and with JFT (300M images, 18k classes) pre-training, our largest model MaxViT-XL achieves a high accuracy of 89.5% with 475M parameters.
Performance comparison of MaxViT with state-of-the-art models on ImageNet-1K. Top: Accuracy vs. FLOPs performance scaling with 224x224 image resolution. Bottom: Accuracy vs. parameters scaling curve under ImageNet-1K fine-tuning setting.
For downstream tasks, MaxViT as a backbone delivers favorable performance on a broad spectrum of tasks. For object detection and segmentation on the COCO dataset, the MaxViT backbone achieves 53.4 AP, outperforming other base-level models while requiring only about 60% of the computational cost. For image aesthetics assessment, the MaxViT model advances the state-of-the-art MUSIQ model by 3.5% in terms of linear correlation with human opinion scores. The standalone MaxViT building block also demonstrates effective performance on image generation, achieving better FID and IS scores on the ImageNet-1K unconditional generation task with a significantly lower number of parameters than the state-of-the-art model, HiT.
The UNet-like MAXIM backbone, customized for image processing tasks, has also demonstrated state-of-the-art results on 15 out of 20 tested datasets, including denoising, deblurring, deraining, dehazing, and low-light enhancement, while requiring a smaller or comparable number of parameters and FLOPs compared with competitive models. Images restored by MAXIM show more recovered details with fewer visual artifacts.
Visual results of MAXIM for image deblurring, deraining, and low-light enhancement.
Summary
Recent work over the last two or so years has shown that ConvNets and Vision Transformers can achieve similar performance. Our work presents a unified design that takes advantage of the best of both worlds (efficient convolution and sparse attention) and demonstrates that a model built on top, namely MaxViT, can achieve state-of-the-art performance on a variety of vision tasks. More importantly, MaxViT scales well to very large data sizes. We also show that an alternative multi-axis design using MLP operators, MAXIM, achieves state-of-the-art performance on a broad range of low-level vision tasks.
Even though we present our models in the context of vision tasks, the proposed multi-axis approach can easily extend to language modeling to capture both local and global dependencies in linear time. Motivated by the work here, we expect that it is worthwhile to study other forms of sparse attention in higher-dimensional or multimodal signals such as videos, point clouds, and vision-language models.
We have open-sourced the code and models of MAXIM and MaxViT to facilitate future research on efficient attention and MLP models.
Acknowledgments
We would like to thank our co-authors: Hossein Talebi, Han Zhang, Feng Yang, Peyman Milanfar, and Alan Bovik. We would also like to acknowledge the valuable discussion and support from Xianzhi Du, Long Zhao, Wuyang Chen, Hanxiao Liu, Zihang Dai, Anurag Arnab, Sungjoon Choi, Junjie Ke, Mauricio Delbracio, Irene Zhu, Innfarn Yoo, Huiwen Chang, and Ce Liu.
Welcome to the 13th Issue: ‘Google Dev Library letters’ is a technology newsletter curated to bring you some of the best projects developed with Google tech and submitted to the Google Dev Library platform. We are back with another boost of inspiration for your next project!
Hero Content of the month
Check out shortlisted content from the Google technologies of your choice.
Contact Store is a modern API that makes access to contacts on Android devices simple to use. It solves for the most frequent use cases and makes developing enjoyable.
The CustomProgressIndicator library is a simple, customizable progress indicator that gives Android applications a nice feel. It saves developers time by creating a unique, customizable loading view.
Get access to over 1,335 icons in one centralized place - the Cupertino Icons Gallery is an open source, cross-platform space to find all the icons used in Flutter.
Learn how to build a system by considering two MLOps scenarios - if the model needs to be replaced later and if the model itself has to evolve with the data.
Learn how to create a stream in 6 simple steps now that Google Cloud recently made Datastream CDC generally available.
Curators Corner
Meet our curators who have been working behind the scenes to bring you the best content submissions
Android
"Android development changes fast and it's great to see developers write blogs to help others learn.
It's a pleasure to be part of the Android community. I enjoy seeing the Android community flourish by collaborating with each other and sharing their learnings."
Andres Sandoval
Sr. Strategist, Google
Machine Learning
"We are loving the TensorFlow.js submissions we have seen so far, and have no doubt future ones will continue to push the boundaries of what's possible in this space, and because it is web powered anyone anywhere can try the demos typically with the click of a link!"
Jason Mayes
Web ML Developer Relations Lead, Google
Liked what you read? Check out the latest projects and community-authored content by visiting our home page or subscribing to our newsletter.
Posted by Tingbo Hou and Juhyun Lee, Software Engineers, Google
In recent years, video conferencing has played an increasingly important role in both work and personal communication for many users. Over the past two years, we have enhanced this experience in Google Meet by introducing privacy-preserving machine learning (ML) powered background features, also known as “virtual green screen”, which allow users to blur their backgrounds or replace them with other images. What is unique about this solution is that it runs directly in the browser without the need to install additional software.
So far, these ML-powered features have relied on CPU inference made possible by leveraging neural network sparsity, a common solution that works across devices, from entry level computers to high-end workstations. This enables our features to reach the widest audience. However, mid-tier and high-end devices often have powerful GPUs that remain untapped for ML inference, and existing functionality allows web browsers to access GPUs via shaders (WebGL).
With the latest update to Google Meet, we are now harnessing the power of GPUs to significantly improve the fidelity and performance of these background effects. As we detail in “Efficient Heterogeneous Video Segmentation at the Edge”, these advances are powered by two major components: 1) a novel real-time video segmentation model and 2) a new, highly efficient approach for in-browser ML acceleration using WebGL. We leverage this capability to develop fast ML inference via fragment shaders. This combination results in substantial gains in accuracy and latency, leading to crisper foreground boundaries.
CPU segmentation vs. HD segmentation in Meet.
Moving Towards Higher Quality Video Segmentation Models
To predict finer details, our new segmentation model now operates on high definition (HD) input images, rather than lower-resolution images, effectively doubling the resolution over the previous model. To accommodate this, the model must be of higher capacity to extract features with sufficient detail. Roughly speaking, doubling the input resolution quadruples the computation cost during inference.
Inference of high-resolution models using the CPU is not feasible for many devices. The CPU may have a few high-performance cores that enable it to execute arbitrarily complex code efficiently, but it is limited in its ability to perform the parallel computation required for HD segmentation. In contrast, GPUs have many relatively low-performance cores coupled with a wide memory interface, making them uniquely suitable for high-resolution convolutional models. Therefore, for mid-tier and high-end devices, we adopt a significantly faster pure GPU pipeline, which is integrated using WebGL.
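The "doubling quadruples" rule of thumb follows from how convolution cost scales with spatial size. The sketch below counts multiply-accumulate operations for a single 2D convolution layer, assuming stride 1 and an output the same size as the input. The layer dimensions in the example are illustrative and are not taken from the Meet model.

```python
def conv2d_macs(height, width, c_in, c_out, kernel=3):
    """Multiply-accumulates for one stride-1 convolution with same-size output.

    Each of the height * width output positions computes c_out values,
    and each value sums over a kernel x kernel x c_in neighborhood.
    """
    return height * width * c_out * kernel * kernel * c_in


if __name__ == "__main__":
    low = conv2d_macs(256, 256, c_in=16, c_out=32)
    high = conv2d_macs(512, 512, c_in=16, c_out=32)
    print(high / low)
```

Because the cost is proportional to height times width, doubling both dimensions multiplies the count by four while leaving the per-position work unchanged.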
This change inspired us to revisit some of the prior design decisions for the model architecture.
Backbone: We compared several widely-used backbones for on-device networks and found EfficientNet-Lite to be a better fit for the GPU because it removes the squeeze-and-excitation block, a component that is inefficient on WebGL (more below).
Decoder: We switched to a multi-layer perceptron (MLP) decoder consisting of 1x1 convolutions instead of using simple bilinear upsampling or the more expensive squeeze-and-excitation blocks. MLP has been successfully adopted in other segmentation architectures, like DeepLab and PointRend, and is efficient to compute on both CPU and GPU.
Model size: With our new WebGL inference and the GPU-friendly model architecture, we were able to afford a larger model without sacrificing the real-time frame rate necessary for smooth video segmentation. We explored the width and the depth parameters using a neural architecture search.
HD segmentation model architecture.
In aggregate, these changes substantially improve the mean Intersection over Union (IoU) metric by 3%, resulting in less uncertainty and crisper boundaries around hair and fingers.
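For readers unfamiliar with the metric, the following is a minimal sketch of Intersection over Union for binary segmentation masks, with the mean taken over a list of image pairs. Averaging per image is an assumption made for this illustration; the post does not state how the mean is aggregated. The function names and the tiny masks are hypothetical.

```python
def iou(pred, target):
    """Intersection over Union of two binary masks given as nested lists of 0/1."""
    intersection = 0
    union = 0
    for pred_row, target_row in zip(pred, target):
        for p, t in zip(pred_row, target_row):
            intersection += 1 if (p and t) else 0
            union += 1 if (p or t) else 0
    # Two empty masks agree perfectly; treat that case as IoU = 1.
    return intersection / union if union else 1.0


def mean_iou(pairs):
    """Average IoU over (prediction, ground truth) mask pairs."""
    return sum(iou(p, t) for p, t in pairs) / len(pairs)


if __name__ == "__main__":
    predicted = [[1, 1, 0],
                 [0, 1, 0]]
    ground_truth = [[1, 0, 0],
                    [0, 1, 1]]
    print(iou(predicted, ground_truth))  # 2 overlapping pixels out of 4 in the union
```

A higher IoU means the predicted foreground overlaps the true foreground more tightly, which is what shows up visually as crisper boundaries.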
We have also released the accompanying model card for this segmentation model, which details our fairness evaluations. Our analysis shows that the model is consistent in its performance across the various regions, skin-tones, and genders, with only small deviations in IoU metrics.
Comparison of the previous segmentation model vs. the new HD segmentation model on a MacBook Pro (2018).
Accelerating Web ML with WebGL
One common challenge for web-based inference is that web technologies can incur a performance penalty when compared to apps running natively on-device. For GPUs, this penalty is substantial, only achieving around 25% of native OpenGL performance. This is because WebGL, the current GPU standard for Web-based inference, was primarily designed for image rendering, not arbitrary ML workloads. In particular, WebGL does not include compute shaders, which allow for general purpose computation and enable ML workloads in mobile and native apps.
To overcome this challenge, we accelerated low-level neural network kernels with fragment shaders that typically compute the output properties of a pixel like color and depth, and then applied novel optimizations inspired by the graphics community. As ML workloads on GPUs are often bound by memory bandwidth rather than compute, we focused on rendering techniques that would improve the memory access, such as Multiple Render Targets (MRT).
MRT is a feature in modern GPUs that allows rendering images to multiple output textures (OpenGL objects that represent images) at once. While MRT was originally designed to support advanced graphics rendering such as deferred shading, we found that we could leverage this feature to drastically reduce the memory bandwidth usage of our fragment shader implementations for critical operations, like convolutions and fully connected layers. We do so by treating intermediate tensors as multiple OpenGL textures.
In the figure below, we show an example of intermediate tensors having four underlying GL textures each. With MRT, the number of GPU threads, and thus effectively the number of memory requests for weights, is reduced by a factor of four, which saves memory bandwidth. Although this introduces considerable complexity in the code, it helps us reach over 90% of native OpenGL performance, closing the gap with native applications.
Left: A classic implementation of Conv2D with 1-to-1 correspondence of tensor and an OpenGL texture. Red, yellow, green, and blue boxes denote different locations in a single texture each for intermediate tensor A and B. Right: Our implementation of Conv2D with MRT where intermediate tensors A and B are realized with a set of 4 GL textures each, depicted as red, yellow, green, and blue boxes. Note that this reduces the request count for weights by 4x.
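The factor-of-four saving can be reproduced with a toy counting model. This is only a back-of-the-envelope sketch, not the shader implementation: it assumes one GPU thread per output texel in the classic layout, one thread per group of texels (one per render target) with MRT, and that each thread fetches its weights once no matter how many targets it writes. The function name and the numbers are hypothetical.

```python
def weight_requests(output_texels, weights_per_thread, render_targets=1):
    """Toy count of weight fetches for one convolution pass.

    Assumptions (illustrative only):
      * each thread writes one texel to each of `render_targets` textures,
      * each thread fetches `weights_per_thread` weights exactly once and
        reuses them for every texel it writes.
    """
    threads = output_texels // render_targets
    return threads * weights_per_thread


if __name__ == "__main__":
    texels = 64 * 64
    classic = weight_requests(texels, weights_per_thread=144, render_targets=1)
    with_mrt = weight_requests(texels, weights_per_thread=144, render_targets=4)
    print(classic // with_mrt)
```

In this model the number of threads, and therefore the number of weight fetches, drops in proportion to the number of render targets, matching the 4x reduction described for four GL textures per tensor.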
Conclusion
We have made rapid strides in improving the quality of real-time segmentation models by leveraging the GPU on mid-tier and high-end devices for use with Google Meet. We look forward to the possibilities that will be enabled by upcoming technologies like WebGPU, which bring compute shaders to the web. Beyond GPU inference, we're also working on improving the segmentation quality for lower powered devices with quantized inference via XNNPACK WebAssembly.
Acknowledgements
Special thanks to those on the Meet team and others who worked on this project, in particular Sebastian Jansson, Sami Kalliomäki, Rikard Lundmark, Stephan Reiter, Fabian Bergmark, Ben Wagner, Stefan Holmer, Dan Gunnarsson, Stéphane Hulaud, and to all our team members who made this possible: Siargey Pisarchyk, Raman Sarokin, Artsiom Ablavatski, Jamie Lin, Tyler Mullen, Gregory Karpiak, Andrei Kulik, Karthik Raveendran, Trent Tolley, and Matthias Grundmann.
Posted by Brian Ichter and Karol Hausman, Research Scientists, Google Research, Brain Team
Over the last several years, we have seen significant progress in applying machine learning to robotics. However, robotic systems today are capable of executing only very short, hard-coded commands, such as “Pick up an apple,” because they tend to perform best with clear tasks and rewards. They struggle with learning to perform long-horizon tasks and reasoning about abstract goals, such as a user prompt like “I just worked out, can you get me a healthy snack?”
Meanwhile, recent progress in training language models (LMs) has led to systems that can perform a wide range of language understanding and generation tasks with impressive results. However, these language models are inherently not grounded in the physical world due to the nature of their training process: a language model generally does not interact with its environment nor observe the outcome of its responses. This can result in it generating instructions that may be illogical, impractical or unsafe for a robot to complete in a physical context. For example, when prompted with “I spilled my drink, can you help?” the language model GPT-3 responds with “You could try using a vacuum cleaner,” a suggestion that may be unsafe or impossible for the robot to execute. When asking the FLAN language model the same question, it apologizes for the spill with "I'm sorry, I didn't mean to spill it,” which is not a very useful response. Therefore, we asked ourselves, is there an effective way to combine advanced language models with robot learning algorithms to leverage the benefits of both?
In “Do As I Can, Not As I Say: Grounding Language in Robotic Affordances”, we present a novel approach, developed in partnership with Everyday Robots, that leverages advanced language model knowledge to enable a physical agent, such as a robot, to follow high-level textual instructions for physically-grounded tasks, while grounding the language model in tasks that are feasible within a specific real-world context. We evaluate our method, which we call PaLM-SayCan, by placing robots in a real kitchen setting and giving them tasks expressed in natural language. We observe highly interpretable results for temporally-extended complex and abstract tasks, like “I just worked out, please bring me a snack and a drink to recover.” Specifically, we demonstrate that grounding the language model in the real world nearly halves errors over non-grounded baselines. We are also excited to release a robot simulation setup where the research community can test this approach.
With PaLM-SayCan, the robot acts as the language model’s “hands and eyes,” while the language model supplies high-level semantic knowledge about the task.
A Dialog Between User and Robot, Facilitated by the Language Model
Our approach uses the knowledge contained in language models (Say) to determine and score actions that are useful towards high-level instructions. It also uses an affordance function (Can) that enables real-world-grounding and determines which actions are possible to execute in a given environment. Because we use the PaLM language model, we call this PaLM-SayCan.
Our approach selects skills based on what the language model scores as useful to the high level instruction and what the affordance model scores as possible.
Our system can be seen as a dialog between the user and robot, facilitated by the language model. The user starts by giving an instruction that the language model turns into a sequence of steps for the robot to execute. This sequence is filtered using the robot’s skillset to determine the most feasible plan given its current state and environment. The model determines the probability of a specific skill successfully making progress toward completing the instruction by multiplying two probabilities: (1) task-grounding (i.e., a skill language description) and (2) world-grounding (i.e., skill feasibility in the current state).
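The multiplication of the two probabilities can be illustrated with a minimal sketch. The skill names and probability values below are made up for illustration, and `select_skill` is a hypothetical helper, not part of any released SayCan code.

```python
def select_skill(language_scores, affordance_scores):
    """Pick the skill with the highest product of the two probabilities.

    language_scores:   task-grounding, how useful the LM thinks each skill is.
    affordance_scores: world-grounding, how feasible each skill is right now.
    """
    combined = {
        skill: language_scores[skill] * affordance_scores[skill]
        for skill in language_scores
    }
    best = max(combined, key=combined.get)
    return best, combined


# Hypothetical scores for the instruction "I spilled my drink, can you help?"
language_scores = {
    "find a sponge": 0.50,
    "find a vacuum cleaner": 0.30,
    "go to the table": 0.20,
}
affordance_scores = {
    "find a sponge": 0.80,
    "find a vacuum cleaner": 0.05,  # assumed infeasible in this scene
    "go to the table": 0.90,
}

best, combined = select_skill(language_scores, affordance_scores)
print(best, combined)
```

In this toy setting, a skill that the language model finds plausible but that the robot cannot execute receives a low combined score, which is the effect the grounding is meant to have.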
There are additional benefits of our approach in terms of its safety and interpretability. First, by allowing the LM to score different options rather than generate the most likely output, we effectively constrain the LM to only output one of the pre-selected responses. In addition, the user can easily understand the decision making process by looking at the separate language and affordance scores, rather than a single output.
PaLM-SayCan is also interpretable: at each step, we can see the top options it considers based on their language score (blue), affordance score (red), and combined score (green).
Training Policies and Value Functions
Each skill in the agent’s skillset is defined as a policy with a short language description (e.g., “pick up the can”), represented as embeddings, and an affordance function that indicates the probability of completing the skill from the robot’s current state. To learn the affordance functions, we use sparse reward functions set to 1.0 for a successful execution, and 0.0 otherwise.
We use image-based behavioral cloning (BC) to train the language-conditioned policies and temporal-difference-based (TD) reinforcement learning (RL) to train the value functions. To train the policies, we collected data from 68,000 demos performed by 10 robots over 11 months and added 12,000 successful episodes, filtered from a set of autonomous episodes of learned policies. We then learned the language conditioned value functions using MT-Opt in the Everyday Robots simulator. The simulator complements our real robot fleet with a simulated version of the skills and environment, which is transformed using RetinaGAN to reduce the simulation-to-real gap. We bootstrapped simulation policies’ performance by using demonstrations to provide initial successes, and then continuously improved RL performance with online data collection in simulation.
Given a high-level instruction, our approach combines the probabilities from the language model with the probabilities from the value function (VF) to select the next skill to perform. This process is repeated until the high-level instruction is successfully completed.
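The repeated selection step can be sketched as a simple loop. This is an illustration only: the two scoring functions are toy stand-ins for the language model and the value function, and the explicit "done" skill used to stop the loop is an assumption of this sketch rather than something stated above.

```python
def plan(instruction, skills, language_score, value_score, max_steps=10):
    """Greedily pick one skill per step until the termination skill wins."""
    steps = []
    for _ in range(max_steps):
        scores = {
            skill: language_score(instruction, steps, skill) * value_score(steps, skill)
            for skill in skills
        }
        best = max(scores, key=scores.get)
        if best == "done":  # assumed termination skill
            break
        steps.append(best)
    return steps


# Toy stand-ins (hypothetical): the "language model" prefers the next step of a
# fixed ideal plan, and the "value function" says picking up is only feasible
# once the apple has been found.
IDEAL = ["find an apple", "pick up the apple", "bring it to the user", "done"]


def toy_language_score(instruction, steps, skill):
    target = IDEAL[min(len(steps), len(IDEAL) - 1)]
    return 0.7 if skill == target else 0.1


def toy_value_score(steps, skill):
    if skill == "pick up the apple" and "find an apple" not in steps:
        return 0.05
    return 0.9


steps = plan("Bring me an apple", IDEAL, toy_language_score, toy_value_score)
print(steps)
```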
Performance on Temporally-Extended, Complex, and Abstract Instructions
To test our approach, we use robots from Everyday Robots paired with PaLM. We place the robots in a kitchen environment containing common objects and evaluate them on 101 instructions to test their performance across various robot and environment states, instruction language complexity and time horizon. Specifically, these instructions were designed to showcase the ambiguity and complexity of language rather than to provide simple, imperative queries, enabling queries such as “I just worked out, how would you bring me a snack and a drink to recover?” instead of “Can you bring me water and an apple?”
We use two metrics to evaluate the system’s performance: (1) the plan success rate, indicating whether the robot chose the right skills for the instruction, and (2) the execution success rate, indicating whether it performed the instruction successfully. We compare two language models, PaLM and FLAN (a smaller language model fine-tuned on instruction answering) with and without the affordance grounding as well as the underlying policies running directly with natural language (Behavioral Cloning in the table below). The results show that the system using PaLM with affordance grounding (PaLM-SayCan) chooses the correct sequence of skills 84% of the time and executes them successfully 74% of the time, reducing errors by 50% compared to FLAN and compared to PaLM without robotic grounding. This is particularly exciting because it represents the first time we can see how an improvement in language models translates to a similar improvement in robotics. This result indicates a potential future where robotics is able to ride the wave of progress that we have been observing in language models, bringing these subfields of research closer together.
Algorithm | Plan | Execute
PaLM-SayCan | 84% | 74%
PaLM | 67% | -
FLAN-SayCan | 70% | 61%
FLAN | 38% | -
Behavioral Cloning | 0% | 0%
PaLM-SayCan halves errors compared to PaLM without affordances and compared to FLAN over 101 tasks.
SayCan demonstrated successful planning for 84% of the 101 test instructions when combined with PaLM.
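The "halves errors" claim can be checked against the reported plan success rates with a few lines of arithmetic. The sketch below treats the error rate as one minus the plan success rate, which is an assumption about how the comparison was made.

```python
# Plan success rates as reported in the results above.
plan_success = {
    "PaLM-SayCan": 0.84,
    "PaLM": 0.67,
    "FLAN-SayCan": 0.70,
    "FLAN": 0.38,
}


def error_reduction(baseline, improved):
    """Relative reduction in error rate, with error = 1 - success."""
    base_err = 1.0 - baseline
    new_err = 1.0 - improved
    return (base_err - new_err) / base_err


vs_palm = error_reduction(plan_success["PaLM"], plan_success["PaLM-SayCan"])
vs_flan_saycan = error_reduction(plan_success["FLAN-SayCan"], plan_success["PaLM-SayCan"])

print(f"vs PaLM without affordances: {vs_palm:.0%}")
print(f"vs FLAN-SayCan: {vs_flan_saycan:.0%}")
```

Under this reading, planning errors drop from 33% to 16% relative to PaLM alone and from 30% to 16% relative to FLAN-SayCan, both close to a halving.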
Conclusion and Future Work
We’re excited about the progress that we’ve seen with PaLM-SayCan, an interpretable and general approach to leveraging knowledge from language models that enables a robot to follow high-level textual instructions to perform physically-grounded tasks. Our experiments on a number of real-world robotic tasks demonstrate the ability to plan and complete long-horizon, abstract, natural language instructions at a high success rate. We believe that PaLM-SayCan’s interpretability allows for safe real-world user interaction with robots. As we explore future directions for this work, we hope to better understand how information gained via the robot’s real-world experience could be leveraged to improve the language model and to what extent natural language is the right ontology for programming robots. We have open-sourced a robot simulation setup, which we hope will provide researchers with a valuable resource for future research that combines robotic learning with advanced language models. The research community can visit the project’s GitHub page and website to learn more.
Acknowledgements
We’d like to thank our coauthors Michael Ahn, Anthony Brohan, Noah Brown, Yevgen Chebotar, Omar Cortes, Byron David, Chelsea Finn, Kelly Fu, Keerthana Gopalakrishnan, Alex Herzog, Daniel Ho, Jasmine Hsu, Julian Ibarz, Alex Irpan, Eric Jang, Rosario Jauregui Ruano, Kyle Jeffrey, Sally Jesmonth, Nikhil J Joshi, Ryan Julian, Dmitry Kalashnikov, Yuheng Kuang, Kuang-Huei Lee, Sergey Levine, Yao Lu, Linda Luu, Carolina Parada, Peter Pastor, Jornell Quiambao, Kanishka Rao, Jarek Rettinghouse, Diego Reyes, Pierre Sermanet, Nicolas Sievers, Clayton Tan, Alexander Toshev, Vincent Vanhoucke, Fei Xia, Ted Xiao, Peng Xu, Sichun Xu, Mengyuan Yan, and Andy Zeng. We’d also like to thank Yunfei Bai, Matt Bennice, Maarten Bosma, Justin Boyd, Bill Byrne, Kendra Byrne, Noah Constant, Pete Florence, Laura Graesser, Rico Jonschkowski, Daniel Kappler, Hugo Larochelle, Benjamin Lee, Adrian Li, Suraj Nair, Krista Reymann, Jeff Seto, Dhruv Shah, Ian Storz, Razvan Surdulescu, and Vincent Zhao for their help and support in various aspects of the project. And we’d like to thank Tom Small for creating many of the animations in this post.
Posted by Rolf Jagerman and Honglei Zhuang, Software Engineers, Google Research
Ranking is a core problem across a variety of domains, such as search engines, recommendation systems, or question answering. As such, researchers often utilize learning-to-rank (LTR), a set of supervised machine learning techniques that optimize for the utility of an entire list of items (rather than a single item at a time). A noticeable recent focus is on combining LTR with deep learning. Existing libraries, most notably TF-Ranking, offer researchers and practitioners the necessary tools to use LTR in their work. However, none of the existing LTR libraries work natively with JAX, a new machine learning framework that provides an extensible system of function transformations that compose: automatic differentiation, JIT-compilation to GPU/TPU devices and more.
Today, we are excited to introduce Rax, a library for LTR in the JAX ecosystem. Rax brings decades of LTR research to the JAX ecosystem, making it possible to apply JAX to a variety of ranking problems and combine ranking techniques with recent advances in deep learning built upon JAX (e.g., T5X). Rax provides state-of-the-art ranking losses, a number of standard ranking metrics, and a set of function transformations to enable ranking metric optimization. All this functionality is provided with a well-documented, easy-to-use API that will look and feel familiar to JAX users. Please check out our paper for more technical details.
Learning-to-Rank Using Rax
Rax is designed to solve LTR problems. To this end, Rax provides loss and metric functions that operate on batches of lists, not batches of individual data points as is common in other machine learning problems. An example of such a list is the multiple potential results from a search engine query. The figure below illustrates how tools from Rax can be used to train neural networks on ranking tasks. In this example, the green items (B, F) are very relevant, the yellow items (C, E) are somewhat relevant and the red items (A, D) are not relevant. A neural network is used to predict a relevancy score for each item, then these items are sorted by these scores to produce a ranking. A Rax ranking loss incorporates the entire list of scores to optimize the neural network, improving the overall ranking of the items. After several iterations of stochastic gradient descent, the neural network learns to score the items such that the resulting ranking is optimal: relevant items are placed at the top of the list and non-relevant items at the bottom.
Using Rax to optimize a neural network for a ranking task. The green items (B, F) are very relevant, the yellow items (C, E) are somewhat relevant and the red items (A, D) are not relevant.
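To make the idea of a loss over an entire list concrete, here is a minimal sketch of a listwise softmax cross-entropy written in plain Python. It is not the Rax implementation or API; the label values (2 for very relevant, 1 for somewhat relevant, 0 for not relevant) and the scores are assumptions chosen to mirror the six-item example above.

```python
import math


def softmax_cross_entropy(scores, labels):
    """Listwise loss for one list: -sum_i labels[i] * log_softmax(scores)[i]."""
    m = max(scores)  # subtract the max for numerical stability
    log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
    return -sum(y * (s - log_norm) for s, y in zip(scores, labels))


# Items A..F with assumed graded relevance: B, F very relevant; C, E somewhat.
labels = [0.0, 2.0, 1.0, 0.0, 1.0, 2.0]

bad_scores = [2.0, 0.0, 1.0, 2.0, 1.0, 0.0]   # ranks the irrelevant items first
good_scores = [0.0, 2.0, 1.0, 0.0, 1.0, 2.0]  # ranks the relevant items first

bad_loss = softmax_cross_entropy(bad_scores, labels)
good_loss = softmax_cross_entropy(good_scores, labels)
print(bad_loss, good_loss)
```

Because every score enters the normalization term, changing one item's score changes the loss contribution of all the others, which is what distinguishes a listwise loss from a pointwise one.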
Approximate Metric Optimization
The quality of a ranking is commonly evaluated using ranking metrics, e.g., the normalized discounted cumulative gain (NDCG). An important objective of LTR is to optimize a neural network so that it scores highly on ranking metrics. However, ranking metrics like NDCG can present challenges because they are often discontinuous and flat, so stochastic gradient descent cannot directly be applied to these metrics. Rax provides state-of-the-art approximation techniques that make it possible to produce differentiable surrogates to ranking metrics that permit optimization via gradient descent. The figure below illustrates the use of rax.approx_t12n, a function transformation unique to Rax, which allows for the NDCG metric to be transformed into an approximate and differentiable form.
Using an approximation technique from Rax to transform the NDCG ranking metric into a differentiable and optimizable ranking loss (approx_t12n and gumbel_t12n).
First, notice how the NDCG metric (in green) is flat and discontinuous, making it hard to optimize using stochastic gradient descent. By applying the rax.approx_t12n transformation to the metric, we obtain ApproxNDCG, an approximate metric that is now differentiable with well-defined gradients (in red). However, it potentially has many local optima — points where the loss is locally optimal, but not globally optimal — in which the training process can get stuck. When the loss encounters such a local optimum, training procedures like stochastic gradient descent will have difficulty improving the neural network further.
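The difference between the flat metric and its smooth surrogate can be reproduced in a small plain-Python sketch. This is not how Rax implements rax.approx_t12n; it assumes one common construction, in which each item's rank is replaced by a sum of sigmoids over score differences, and it assumes the gain function 2^label - 1.

```python
import math


def dcg(labels_in_rank_order):
    return sum(
        (2 ** y - 1) / math.log2(rank + 1)
        for rank, y in enumerate(labels_in_rank_order, start=1)
    )


def ndcg(scores, labels):
    """Exact NDCG: depends on the scores only through their sort order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ideal = dcg(sorted(labels, reverse=True))
    return dcg([labels[i] for i in order]) / ideal if ideal > 0 else 0.0


def approx_ndcg(scores, labels, temperature=1.0):
    """Smooth surrogate: rank_i ~ 1 + sum_{j != i} sigmoid((s_j - s_i) / T)."""
    def sigmoid(x):
        return 1.0 / (1.0 + math.exp(-x))

    ideal = dcg(sorted(labels, reverse=True))
    total = 0.0
    for i, (s_i, y) in enumerate(zip(scores, labels)):
        soft_rank = 1.0 + sum(
            sigmoid((s_j - s_i) / temperature)
            for j, s_j in enumerate(scores)
            if j != i
        )
        total += (2 ** y - 1) / math.log2(soft_rank + 1)
    return total / ideal if ideal > 0 else 0.0


labels = [2, 1, 0]
base = [1.0, 2.0, 3.0]    # the most relevant item is scored lowest
nudged = [1.1, 2.0, 3.0]  # small improvement that does not change the order

print(ndcg(base, labels), ndcg(nudged, labels))
print(approx_ndcg(base, labels), approx_ndcg(nudged, labels))
```

The exact metric is identical for both score vectors because the sort order is unchanged, so it offers no gradient signal, while the surrogate increases slightly and can therefore guide gradient descent.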
To overcome this, we can obtain the Gumbel version of ApproxNDCG by using the rax.gumbel_t12n transformation. This Gumbel version introduces noise in the ranking scores which causes the loss to sample many different rankings that may incur a non-zero cost (in blue). This stochastic treatment may help the loss escape local optima and often is a better choice when training a neural network on a ranking metric. Rax, by design, allows the approximate and Gumbel transformations to be freely used with all metrics that are offered by the library, including metrics with a top-k cutoff value, like recall or precision. In fact, it is even possible to implement your own metrics and transform them to obtain Gumbel-approximate versions that permit optimization without any extra effort.
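The effect of adding noise to the scores can be illustrated with a short sketch that perturbs a score vector with Gumbel(0, 1) samples and records the resulting rankings. It only shows the sampling idea, not the rax.gumbel_t12n implementation; the function names, seed, and scores are assumptions.

```python
import math
import random


def gumbel_noise(rng):
    """Draw one Gumbel(0, 1) sample via the inverse CDF: -log(-log(u))."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep u strictly inside (0, 1)
    return -math.log(-math.log(u))


def sample_rankings(scores, n_samples, seed=0):
    """Sort items by noisy scores; each sample may yield a different ranking."""
    rng = random.Random(seed)
    rankings = []
    for _ in range(n_samples):
        noisy = [s + gumbel_noise(rng) for s in scores]
        order = sorted(range(len(scores)), key=lambda i: -noisy[i])
        rankings.append(tuple(order))
    return rankings


rankings = sample_rankings([1.0, 1.2, 0.8], n_samples=200)
distinct = set(rankings)
print(len(distinct), "distinct rankings sampled")
```

Averaging a loss over such sampled rankings exposes the optimizer to orderings other than the current one, which is the intuition behind escaping local optima.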
Ranking in the JAX Ecosystem
Rax is designed to integrate well in the JAX ecosystem and we prioritize interoperability with other JAX-based libraries. For example, a common workflow for researchers that use JAX is to use TensorFlow Datasets to load a dataset, Flax to build a neural network, and Optax to optimize the parameters of the network. Each of these libraries composes well with the others and the composition of these tools is what makes working with JAX both flexible and powerful. For researchers and practitioners of ranking systems, the JAX ecosystem was previously missing LTR functionality, and Rax fills this gap by providing a collection of ranking losses and metrics. We have carefully constructed Rax to function natively with standard JAX transformations such as jax.jit and jax.grad and various libraries like Flax and Optax. This means that users can freely use their favorite JAX and Rax tools together.
Ranking with T5
While giant language models such as T5 have shown great performance on natural language tasks, how to leverage ranking losses to improve their performance on ranking tasks, such as search or question answering, is under-explored. With Rax, it is possible to fully tap this potential. Rax is written as a JAX-first library, thus it is easy to integrate it with other JAX libraries. Since T5X is an implementation of T5 in the JAX ecosystem, Rax can work with it seamlessly.
To this end, we have an example that demonstrates how Rax can be used in T5X. By incorporating ranking losses and metrics, it is now possible to fine-tune T5 for ranking problems, and our results indicate that enhancing T5 with ranking losses can offer significant performance improvements. For example, on the MS-MARCO QNA v2.1 benchmark we are able to achieve a +1.2% NDCG and +1.7% MRR by fine-tuning a T5-Base model using the Rax listwise softmax cross-entropy loss instead of a pointwise sigmoid cross-entropy loss.
Fine-tuning a T5-Base model on MS-MARCO QNA v2.1 with a ranking loss (softmax, in blue) versus a non-ranking loss (pointwise sigmoid, in red).
Conclusion
Overall, Rax is a new addition to the growing ecosystem of JAX libraries. Rax is entirely open source and available to everyone at github.com/google/rax. More technical details can also be found in our paper. We encourage everyone to explore the examples included in the GitHub repository: (1) optimizing a neural network with Flax and Optax, (2) comparing different approximate metric optimization techniques, and (3) how to integrate Rax with T5X.
‘Google Dev Library Letters’ is curated to bring you some of the latest projects developed with Google tech submitted to Google Dev Library Platform. We hope this brings you the inspiration you need for your next project!
Posted by Arun Kandoor, Software Engineer, Google Research
The increasing demand for machine learning (ML) model inference on-device (for mobile devices, tablets, etc.) is driven by the rise of compute-intensive applications, the need to keep certain data on device for privacy and security reasons, and the desire to provide services when a network connection may not be available. However, on-device inference introduces a myriad of challenges, ranging from modeling to platform support requirements. These challenges relate to how different architectures are designed to optimize memory and computation, while still trying to maintain the quality of the model. From a platform perspective, the issue is identifying operations and building on top of them in a way that can generalize well across different product use cases.
In previous research, we combined a novel technique for generating embeddings (called projection-based embeddings) with efficient architectures like QRNN (pQRNN) and proved them to be competent for a number of classification problems. Augmenting these with distillation techniques provides an additional bump in end-to-end quality. Although this is an effective approach, it is not scalable to bigger and more extensive vocabularies (i.e., all possible Unicode or word tokens that can be fed to the model). Additionally, the output from the projection operation itself doesn’t contain trainable weights to take advantage of pre-training the model.
Token-free models presented in ByT5 are a good starting point for on-device modeling that can address pre-training and scalability issues without the need to increase the size of the model. This is possible because these approaches treat text inputs as a stream of bytes (each byte has a value that ranges from 0 to 255) that can reduce the vocabulary size for the embedding tables from ~30,000 to 256. Although ByT5 presents a compelling alternative for on-device modeling, going from word-level representation to byte stream representation increases the sequence lengths linearly; with an average word length of four characters and a single character having up to four bytes, the byte sequence length increases proportionally to the word length. This can lead to a significant increase in inference latency and computational costs.
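The trade-off between vocabulary size and sequence length is easy to see with Python's built-in UTF-8 encoding. This is a small illustration, not code from ByT5 or SeqFlowLite; `to_byte_ids` is a hypothetical helper.

```python
def to_byte_ids(text):
    """Encode text as UTF-8 and return one integer in [0, 255] per byte."""
    return list(text.encode("utf-8"))


sentence = "I just worked out"
words = sentence.split()
byte_ids = to_byte_ids(sentence)

print(len(words), "words ->", len(byte_ids), "bytes")

# Non-ASCII characters take more than one byte each in UTF-8.
print(len(to_byte_ids("\u00e9")))      # e with acute accent
print(len(to_byte_ids("\U0001F600")))  # an emoji
```

Every ID fits in a 256-entry embedding table, but the model now has to process one position per byte instead of one per word.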
We address this problem by developing and releasing three novel byte-stream sequence models for the SeqFlowLite library (ByteQRNN, ByteTransformer and ByteFunnelTransformer), all of which can be pre-trained on unsupervised data and can be fine-tuned for specific tasks. These models leverage recent innovations introduced by Charformer, including a fast character Transformer-based model that uses a gradient-based subword tokenization (GBST) approach to operate directly at the byte level, as well as a “soft” tokenization approach, which allows us to learn token boundaries and reduce sequence lengths. In this post, we focus on ByteQRNN and demonstrate that the performance of a pre-trained ByteQRNN model is comparable to BERT, despite being 300x smaller.
Sequence Model Architecture
We leverage pQRNN, ByT5 and Charformer along with platform optimizations, such as in-training quantization (which tracks minimum and maximum float values for model activations and weights for quantizing the inference model) that reduces model sizes by one-fourth, to develop an end-to-end model called ByteQRNN (shown below). First, we use a ByteSplitter operation to split the input string into a byte stream and feed it to a smaller embedding table that has a vocabulary size of 259 (256 + 3 additional meta tokens).
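The role of the tracked minimum and maximum values can be sketched with a simple 8-bit affine quantizer. This is a generic illustration, not the SeqFlowLite implementation: here the range is read directly from the values, whereas in-training quantization tracks it during training, and the function names are hypothetical.

```python
def quantize(values, lo, hi, levels=256):
    """Map floats in [lo, hi] to integers in [0, levels - 1]."""
    scale = (hi - lo) / (levels - 1) if hi > lo else 1.0
    quantized = [
        min(levels - 1, max(0, round((v - lo) / scale)))
        for v in values
    ]
    return quantized, scale


def dequantize(quantized, lo, scale):
    """Recover approximate float values from the integer codes."""
    return [lo + q * scale for q in quantized]


weights = [-0.8, -0.1, 0.0, 0.35, 1.2]
lo, hi = min(weights), max(weights)  # stand-in for the tracked min/max

q, scale = quantize(weights, lo, hi)
restored = dequantize(q, lo, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(q, max_error)
```

Each value is stored in one byte instead of the four bytes of a float32, at the cost of a rounding error bounded by half a quantization step.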
The output from the embedding layer is fed to the GBST layer, which is equipped with in-training quantization and combines byte-level representations with the efficiency of subword tokenization while enabling end-to-end learning of latent subwords. We “soft” tokenize the byte stream sequences by enumerating and combining each subword block length with scores (computed with a quantized dense layer) at each strided token position (i.e., at token positions that are selected at regular intervals). Next, we downsample the byte stream to manageable sequence length and feed it to the encoder layer.
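The soft tokenization step can be sketched in plain Python under simplifying assumptions. For each position, the sketch mean-pools blocks of several candidate sizes, scores each candidate with a linear layer (here a single hypothetical weight vector, unquantized), mixes the candidates with softmax weights, and then downsamples by mean-pooling. It computes scores at every position rather than only at strided positions, so it is a simplification of the layer described above, not its implementation.

```python
import math


def mean_vec(vectors):
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def soft_tokenize(embeddings, score_weights, max_block=3, stride=2):
    length = len(embeddings)
    dim = len(score_weights)

    # candidates[b - 1][i]: mean-pooled size-b block that contains position i.
    candidates = []
    for block_size in range(1, max_block + 1):
        per_position = []
        for start in range(0, length, block_size):
            block = embeddings[start:start + block_size]
            per_position.extend([mean_vec(block)] * len(block))
        candidates.append(per_position)

    # Mix the candidate block representations with learned (here fixed) scores.
    mixed = []
    for i in range(length):
        options = [candidates[b][i] for b in range(max_block)]
        scores = [sum(w * x for w, x in zip(score_weights, v)) for v in options]
        weights = softmax(scores)
        mixed.append(
            [sum(wt * v[d] for wt, v in zip(weights, options)) for d in range(dim)]
        )

    # Downsample to a shorter sequence by mean-pooling windows of `stride`.
    return [mean_vec(mixed[s:s + stride]) for s in range(0, length, stride)]


byte_embeddings = [[float(i), float(i % 2)] for i in range(8)]  # toy 2-d embeddings
tokens = soft_tokenize(byte_embeddings, score_weights=[0.5, -0.5])
print(len(byte_embeddings), "byte positions ->", len(tokens), "soft tokens")
```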
The output from the GBST layer can be downsampled to a lower sequence length for efficient encoder computation or can be used by an encoder, like Funnel Transformer, which pools the query length and reduces the self-attention computation to create the ByteFunnelTransformer model. The encoder in the end-to-end model can be replaced with any other encoder layer, such as the Transformer from the SeqFlowLite library, to create a ByteTransformer model.
A diagram of a generic end-to-end sequence model using byte stream input. The ByteQRNN model uses a QRNN encoder from the SeqFlowLite library.
In addition to the input embeddings (i.e., the output from the embedding layer described above), we go a step further to build an effective sequence-to-sequence (seq2seq) model. We do so by taking ByteQRNN and adding a Transformer-based decoder model along with a quantized beam search (or tree exploration) to go with it. The quantized beam search module reduces the inference latency when generating decoder outputs by computing the most likely beams (i.e., possible output sequences) using the logarithmic sum of previous and current probabilities and returns the resulting top beams. Here the system uses a more efficient 8-bit integer (uint8) format, compared to a typical single-precision floating-point format (float32) model.
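The scoring rule used by beam search, adding the log-probability of each new token to the running score of its beam, can be sketched as follows. The sketch uses ordinary Python floats and a toy next-token distribution; it does not model the uint8 quantization, and `toy_model` and the token IDs are assumptions.

```python
import math


def beam_search(step_log_probs, beam_size=2, max_len=3, eos=0):
    """Keep the `beam_size` best prefixes, scored by summed log-probabilities."""
    beams = [((), 0.0)]
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            if tokens and tokens[-1] == eos:
                candidates.append((tokens, score))  # finished beam, carry over
                continue
            for token, logp in step_log_probs(tokens).items():
                candidates.append((tokens + (token,), score + logp))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
    return beams


def toy_model(prefix):
    """Hypothetical decoder: next-token log-probabilities given the prefix."""
    if not prefix:
        probs = {1: 0.6, 2: 0.4}
    elif prefix[-1] == 1:
        probs = {1: 0.1, 2: 0.2, 0: 0.7}
    else:
        probs = {1: 0.3, 0: 0.7}
    return {token: math.log(p) for token, p in probs.items()}


beams = beam_search(toy_model)
for tokens, score in beams:
    print(tokens, round(math.exp(score), 2))
```

Working in log space turns the product of token probabilities into a sum, which is cheaper and avoids underflow on long sequences.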
The decoder Transformer model uses a merged attention sublayer (MAtt) to reduce the complexity of the decoder self-attention from quadratic to linear, thereby lowering the end-to-end latency. For each decoding step, MAtt uses a fixed-size cache for decoder self-attention compared to the increasing cache size of a traditional transformer decoder. The following figure illustrates how the beam search module interacts with the decoder layer to generate output tokens on-device using an edge device (e.g., mobile phones, tablets, etc.).
A comparison of cloud server decoding and on-device (edge device) implementation. Left: Cloud server beam search employs a Transformer-based decoder model with quadratic time self-attention in float32, which has an increasing cache size for each decoding step. Right: The edge device implementation employs a quantized beam search module along with a fixed-size cache and a linear time self-attention computation.
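To show why a fixed-size cache gives constant work per decoding step, the sketch below keeps a running average of past decoder inputs instead of storing all of them. This is an assumption about one way such a cache can be realized; it is not taken from the MAtt implementation, and `RunningAverageCache` is a hypothetical name.

```python
class RunningAverageCache:
    """Constant-size state summarizing all previous decoder inputs."""

    def __init__(self, dim):
        self.total = [0.0] * dim  # running sum, size never grows
        self.count = 0

    def update(self, vector):
        """Add the current step and return the average over all steps so far."""
        self.count += 1
        self.total = [t + v for t, v in zip(self.total, vector)]
        return [t / self.count for t in self.total]


cache = RunningAverageCache(dim=2)
first = cache.update([2.0, 0.0])
second = cache.update([0.0, 4.0])
print(first, second, len(cache.total))
```

A standard decoder would instead append keys and values for every generated token, so both memory and per-step attention cost grow with the output length.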
Evaluation
After developing ByteQRNN, we evaluate its performance on the civil_comments dataset using the area under the curve (AUC) metric and compare it to a pre-trained ByteQRNN and BERT (shown below). We demonstrate that the fine-tuned ByteQRNN improves the overall quality and brings its performance closer to the BERT models, despite being 300x smaller. Since SeqFlowLite models support in-training quantization that reduces model sizes by one-fourth, the resulting models scale well to low-compute devices. We chose multilingual data sources that related to the task for pre-training both BERT and byte stream models to achieve the best possible performance.
Comparison of ByteQRNN with fine-tuned ByteQRNN and BERT on the civil_comments dataset.
Conclusion
Following up on our previous work with pQRNN, we evaluate byte stream models for on-device use to enable pre-training and thereby improve model performance for on-device deployment. We present an evaluation for ByteQRNN with and without pre-training and demonstrate that the performance of the pre-trained ByteQRNN is comparable to BERT, despite being 300x smaller. In addition to ByteQRNN, we are also releasing ByteTransformer and ByteFunnelTransformer, two models which use different encoders, along with the merged attention decoder model and the beam search driver to run the inference through the SeqFlowLite library. We hope these models will provide researchers and product developers with valuable resources for future on-device deployments.
Acknowledgements
We would like to thank Khoa Trinh, Jeongwoo Ko, Peter Young and Yicheng Fan for helping with open-sourcing and evaluating the model. Thanks to Prabhu Kaliamoorthi for all the brainstorming and ideation. Thanks to Vinh Tran, Jai Gupta and Yi Tay for their help with pre-training byte stream models. Thanks to Ruoxin Sang, Haoyu Zhang, Ce Zheng, Chuanhao Zhuge and Jieying Luo for helping with the TPU training. Many thanks to Erik Vee, Ravi Kumar and the Learn2Compress leadership for sponsoring the project and their support and encouragement. Finally, we would like to thank Tom Small for the animated figure used in this post.