Nested Hierarchical Transformer: Towards Accurate, Data-Efficient, and Interpretable Visual Understanding

In visual understanding, the Vision Transformer (ViT) and its variants have received significant attention recently due to their superior performance on many core visual applications, such as image classification, object detection, and video understanding. The core idea of ViT is to utilize the power of self-attention layers to learn global relationships between small patches of images. However, the number of connections between patches increases quadratically with image size. Such a design has been observed to be data inefficient: although the original ViT can perform better than convolutional networks when pre-trained on hundreds of millions of images, such a data requirement is not always practical, and it still underperforms convolutional networks when given less data. Many researchers are exploring architectural redesigns that can learn visual representations more effectively, such as adding convolutional layers or building hierarchical structures with local self-attention.

The principle of hierarchical structure is one of the core ideas in vision models, where bottom layers learn more local object structures in the high-dimensional pixel space and top layers learn more abstract, high-level knowledge in a low-dimensional feature space. Existing ViT-based methods focus on designing a variety of modifications inside self-attention layers to achieve such a hierarchy. While these offer promising performance improvements, they often require substantial architectural redesigns. Moreover, these approaches lack an interpretable design, so it is difficult to explain the inner workings of trained models.

To address these challenges, in “Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and Interpretable Visual Understanding”, we rethink existing hierarchical structure-driven designs and provide a novel and orthogonal approach that significantly simplifies them. The central idea of this work is to decouple the feature learning and feature abstraction (pooling) components: nested transformer layers encode visual knowledge of image patches separately, and then the processed information is aggregated. This process is repeated in a hierarchical manner, resulting in a pyramid network structure. The resulting architecture achieves competitive results on ImageNet and outperforms prior methods on data-efficient benchmarks. We show that such a design can meaningfully improve data efficiency with faster convergence and provide valuable interpretability benefits. Moreover, we introduce GradCAT, a new technique for interpreting the decision process of a trained model at inference time.

Architecture Design
The overall architecture is simple to implement by adding just a few lines of Python code to the source code of the original ViT. The original ViT architecture divides an input image into small patches, projects the pixels of each patch to a vector with a predefined dimension, and then feeds the sequence of all vectors to the overall ViT architecture, which contains multiple stacked identical transformer layers. While every layer in ViT processes information from the whole image, in this new method stacked transformer layers process only a region (i.e., a block) of the image containing a few spatially adjacent image patches. This step is independent for each block and is also where feature learning occurs. Next, a new computational layer called block aggregation combines spatially adjacent blocks. After each block aggregation, the features corresponding to four spatially adjacent blocks are fed to another module with a stack of transformer layers, which then processes those four blocks jointly. This design naturally builds a pyramid hierarchical structure, where bottom layers can focus on local features (such as textures) and upper layers focus on global features (such as object shape) at reduced dimensionality thanks to the block aggregation.

A visualization of the network processing an image: Given an input image, the network first partitions it into blocks, where each block contains 4 image patches. Image patches in every block are linearly projected to vectors and processed by a stack of identical transformer layers. Then the proposed block aggregation layer aggregates information from each block and reduces its spatial size by a factor of 4. The number of blocks is reduced to 1 at the top of the hierarchy, and classification is performed on its output.
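To make the nesting concrete, here is a minimal sketch in plain Python. It is not the released implementation. It assumes a square grid of scalar patch features whose side is the block size times a power of two, uses a placeholder `process` function in place of the stack of transformer layers, and uses 2x2 max pooling as a stand-in for the learned block aggregation. The names `partition`, `unpartition`, `max_pool_2x2`, and `nest_forward` are illustrative only.

```python
def partition(grid, block):
    """Split an H x W grid (list of rows) into non-overlapping block x block tiles."""
    h, w = len(grid), len(grid[0])
    return [[[row[c:c + block] for row in grid[r:r + block]]
             for c in range(0, w, block)]
            for r in range(0, h, block)]


def unpartition(tiles):
    """Inverse of partition: stitch a 2D arrangement of tiles back into one grid."""
    grid = []
    for tile_row in tiles:
        block = len(tile_row[0])
        for r in range(block):
            grid.append([value for tile in tile_row for value in tile[r]])
    return grid


def max_pool_2x2(grid):
    """Stand-in for block aggregation: halve each spatial side (4x fewer positions)."""
    h, w = len(grid), len(grid[0])
    return [[max(grid[r][c], grid[r][c + 1], grid[r + 1][c], grid[r + 1][c + 1])
             for c in range(0, w, 2)]
            for r in range(0, h, 2)]


def nest_forward(grid, block=2, process=lambda tile: tile):
    """Alternate independent per-block processing and aggregation until one block is left."""
    levels = []  # number of blocks per side at each hierarchy level
    while True:
        tiles = partition(grid, block)
        levels.append((len(tiles), len(tiles[0])))
        # Feature learning: every block is processed on its own (assumed placeholder).
        tiles = [[process(tile) for tile in tile_row] for tile_row in tiles]
        if len(tiles) == 1 and len(tiles[0]) == 1:
            return tiles[0][0], levels
        # Feature abstraction: merge neighbouring blocks and shrink the grid.
        grid = max_pool_2x2(unpartition(tiles))


image = [[r * 8 + c for c in range(8)] for r in range(8)]  # 8x8 grid of patch features
top, levels = nest_forward(image)
print(levels)  # [(4, 4), (2, 2), (1, 1)]
```

With an 8x8 grid of patches and 2x2 patches per block, the hierarchy goes from 16 blocks to 4 blocks to a single block, which mirrors the pyramid described above.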

Interpretability
This architecture has a non-overlapping information processing mechanism, independent at every node. This design resembles a decision tree-like structure, which manifests unique interpretability capabilities because every tree node contains independent information of an image block that is being received by its parent nodes. We can trace the information flow through the nodes to understand the importance of each feature. In addition, our hierarchical structure retains the spatial structure of images throughout the network, leading to learned spatial feature maps that are effective for interpretation. Below we showcase two kinds of visual interpretability.

First, we present a method to interpret the trained model on test images, called gradient-based class-aware tree-traversal (GradCAT). GradCAT traces the feature importance of each block (a tree node) from top to bottom of the hierarchy structure. The main idea is to find the most valuable traversal from the root node at the top layer to a child node at the bottom layer that contributes the most to the classification outcomes. Since each node processes information from a certain region of the image, such traversal can be easily mapped to the image space for interpretation (as shown by the overlaid dots and lines in the image below).
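The traversal can be sketched in a few lines of plain Python. This sketch assumes that the per-node importance scores have already been computed (in GradCAT they are derived from gradients with respect to the target class; the numbers below are made up), and that each node has four children laid out as a 2x2 window one level down. The function name `trace_path` is hypothetical.

```python
def trace_path(score_maps):
    """Follow the highest-scoring child from the top of the hierarchy to the bottom.

    score_maps: list of square grids ordered from coarse to fine. The first grid
    is 2x2 (the four children of the root); each later grid doubles the side.
    Returns the (row, col) chosen at every level.
    """
    path = []
    row0 = col0 = 0  # top-left corner of the 2x2 window of candidate children
    for scores in score_maps:
        candidates = [(scores[r][c], r, c)
                      for r in (row0, row0 + 1)
                      for c in (col0, col0 + 1)]
        _, r, c = max(candidates)
        path.append((r, c))
        row0, col0 = 2 * r, 2 * c  # children of (r, c) at the next, finer level
    return path


coarse = [[0.1, 0.7],
          [0.1, 0.1]]
fine = [[0.0, 0.0, 0.2, 0.9],
        [0.0, 0.0, 0.3, 0.1],
        [5.0, 5.0, 0.0, 0.0],
        [5.0, 5.0, 0.0, 0.0]]
print(trace_path([coarse, fine]))  # [(0, 1), (0, 3)]
```

Note that the large values in the bottom-left of `fine` are never visited: once the top-right node wins at the coarse level, only its own children are considered, which is what makes the path easy to map back to an image region.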

The following is an example of the model's top-4 predictions and corresponding interpretability results on the left input image (containing 4 animals). As shown below, GradCAT highlights the decision path along the hierarchical structure as well as the corresponding visual cues in local image regions on the images.

Given the left input image (containing four objects), the figure visualizes the interpretability results of the top-4 prediction classes. The traversal locates the model decision path along the tree and simultaneously locates the corresponding image patch (shown by the dotted line on images) that has the highest impact on the predicted target class.

Moreover, the following figures visualize results on the ImageNet validation set and show how this approach enables some intuitive observations. For instance, the example of the lighter below (upper left panel) is particularly interesting because the ground truth class — lighter/matchstick — actually defines the bottom-right matchstick object, while the most salient visual features (with the highest node values) are actually from the upper-left red light, which conceptually shares visual cues with a lighter. This can also be seen from the overlaid red lines, which indicate the image patches with the highest impact on the prediction. Thus, although the visual cue is a mistake, the output prediction is correct. In addition, the four child nodes of the wooden spoon below have similar feature importance values (see numbers visualized in the nodes; higher indicates more importance), which is because the wooden texture of the table is similar to that of the spoon.

Visualization of the results obtained by the proposed GradCAT. Images are from the ImageNet validation dataset.

Second, unlike the original ViT, our hierarchical architecture retains spatial relationships in learned representations. The top layers output low-resolution feature maps of input images, enabling the model to easily perform attention-based interpretation by applying Class Attention Map (CAM) on the learned representations at the top hierarchical level. This enables high-quality weakly-supervised object localization with just image-level labels. See the following figure for examples.
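As a rough illustration of why spatial feature maps make this kind of interpretation easy, the sketch below computes a CAM-style map as a weighted sum of feature channels, using the classifier weights of one target class. This is the generic formulation and is only assumed to match what is applied here; the function name `class_map` and the toy numbers are hypothetical.

```python
def class_map(features, class_weights):
    """Weighted sum over channels of a C x H x W feature map.

    features: list of C channels, each an H x W grid (list of rows).
    class_weights: C classifier weights for the class being explained.
    Returns an H x W grid; larger values mark regions that support the class.
    """
    channels = len(features)
    height, width = len(features[0]), len(features[0][0])
    return [[sum(class_weights[k] * features[k][r][c] for k in range(channels))
             for c in range(width)]
            for r in range(height)]


features = [
    [[1.0, 0.0], [0.0, 0.0]],  # channel 0 fires in the top-left
    [[0.0, 0.0], [0.0, 2.0]],  # channel 1 fires in the bottom-right
]
print(class_map(features, class_weights=[0.5, 2.0]))  # [[0.5, 0.0], [0.0, 4.0]]
```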

Visualization of CAM-based attention results. Warmer colors indicate higher attention. Images are from the ImageNet validation dataset.

Convergence Advantages
With this design, feature learning only happens at local regions independently, and feature abstraction happens inside the aggregation function. The design and its simple implementation are general enough for other types of visual understanding tasks beyond classification. The design also greatly improves model convergence speed, significantly reducing the training time needed to reach the desired maximum accuracy.

We validate this advantage in two ways. First, we compare ImageNet accuracy against the ViT structure for different numbers of total training epochs. The results, shown on the left side of the figure below, demonstrate much faster convergence than the original ViT, e.g., around 20% improvement in accuracy over ViT with 30 total training epochs.

Second, we modify the architecture to conduct unconditional image generation tasks, since training ViT-based models for image generation tasks is challenging due to convergence and speed issues. Creating such a generator is straightforward by transposing the proposed architecture: the input is an embedding vector, the output is a full image in RGB channels, and the block aggregation is replaced by a block de-aggregation component supported by Pixel Shuffling. Surprisingly, we find our generator is easy to train and demonstrates faster convergence speed, as well as better FID score (which measures how similar generated images are to real ones), than the capacity-comparable SAGAN.
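For readers unfamiliar with pixel shuffling, the following plain-Python sketch shows the rearrangement that trades channels for spatial resolution. It follows the common sub-pixel layout, in which `C * r * r` channels of size `H x W` become `C` channels of size `(H * r) x (W * r)`; whether the de-aggregation component uses exactly this channel ordering is an assumption, and the function name `pixel_shuffle` is illustrative.

```python
def pixel_shuffle(x, r):
    """Rearrange C*r*r channels of H x W into C channels of (H*r) x (W*r).

    x: list of channels, each an H x W grid (list of rows).
    r: upscaling factor per spatial side.
    """
    channels_in = len(x)
    height, width = len(x[0]), len(x[0][0])
    assert channels_in % (r * r) == 0, "channel count must be divisible by r*r"
    channels_out = channels_in // (r * r)
    out = [[[0] * (width * r) for _ in range(height * r)]
           for _ in range(channels_out)]
    for c in range(channels_out):
        for i in range(r):
            for j in range(r):
                # Each input channel fills one offset (i, j) inside every r x r cell.
                source = x[c * r * r + i * r + j]
                for row in range(height):
                    for col in range(width):
                        out[c][row * r + i][col * r + j] = source[row][col]
    return out


# Four 1x1 channels become a single 2x2 channel.
print(pixel_shuffle([[[1]], [[2]], [[3]], [[4]]], r=2))  # [[[1, 2], [3, 4]]]
```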

Left: ImageNet accuracy for different numbers of total training epochs, compared with the standard ViT architecture. Right: ImageNet 64x64 image generation FID scores (lower is better) with a single 1000-epoch training run. On both tasks, our method shows better convergence speed.

Conclusion
In this work we demonstrate the simple idea that decoupling feature learning from feature information extraction in this nested hierarchy design leads to better feature interpretability through a new gradient-based class-aware tree traversal method. Moreover, the architecture improves convergence not only on classification tasks but also on image generation tasks. The proposed idea focuses on the aggregation function and is therefore orthogonal to advanced architecture designs for self-attention. We hope this new research encourages future architecture designers to explore more interpretable and data-efficient ViT-based models for visual understanding, such as adopting this work for high-resolution image generation. We have also released the source code for the image classification portion of this work.

Acknowledgements
We gratefully acknowledge the contributions of other co-authors, including Han Zhang, Long Zhao, Ting Chen, Sercan Arik, and Tomas Pfister. We also thank Xiaohua Zhai, Jeremy Kubica, Kihyuk Sohn, and Madeleine Udell for their valuable feedback on the work.

Source: Google AI Blog


Scaling Vision with Sparse Mixture of Experts

Advances in deep learning over the last few decades have been driven by a few key elements. With a small number of simple but flexible mechanisms (i.e., inductive biases such as convolutions or sequence attention), increasingly large datasets, and more specialized hardware, neural networks can now achieve impressive results on a wide range of tasks, such as image classification, machine translation, and protein folding prediction.

However, the use of large models and datasets comes at the expense of significant computational requirements. Yet recent works suggest that large model sizes might be necessary for strong generalization and robustness, so training large models while limiting resource requirements is becoming increasingly important. One promising approach involves the use of conditional computation: rather than activating the whole network for every single input, different parts of the model are activated for different inputs. This paradigm has been featured in the Pathways vision and in recent works on large language models, but it has not been well explored in the context of computer vision.

In “Scaling Vision with Sparse Mixture of Experts”, we present V-MoE, a new vision architecture based on a sparse mixture of experts, which we then use to train the largest vision model to date. We transfer V-MoE to ImageNet and demonstrate matching state-of-the-art accuracy while using about 50% fewer resources than models of comparable performance. We have also open-sourced the code to train sparse models and provided several pre-trained models.

Vision Mixture of Experts (V-MoEs)
Vision Transformers (ViT) have emerged as one of the best architectures for vision tasks. ViT first partitions an image into equally-sized square patches. These are called tokens, a term inherited from language models. Still, compared to the largest language models, ViT models are several orders of magnitude smaller in terms of number of parameters and compute.

To massively scale vision models, we replace some dense feedforward layers (FFN) in the ViT architecture with a sparse mixture of independent FFNs (which we call experts). A learnable router layer selects which experts are chosen (and how they are weighted) for every individual token. That is, different tokens from the same image may be routed to different experts. Each token is only routed to at most K (typically 1 or 2) experts, among a total of E experts (in our experiments, E is typically 32). This allows scaling the model’s size while keeping its computation per token roughly constant. The figure below shows the structure of the encoder blocks in more detail.
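The routing step can be sketched in plain Python as follows. This is a simplified illustration under stated assumptions, not the open-sourced code: the router logits are passed in directly instead of being produced by a learned layer, the weights are taken from a softmax over all experts before selecting the top K, and experts are ordinary Python functions. The names `softmax`, `route`, and `moe_layer` are hypothetical.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def route(token_logits, k):
    """Pick the k experts with the largest routing weight for one token."""
    weights = softmax(token_logits)
    ranked = sorted(range(len(weights)), key=lambda e: weights[e], reverse=True)
    return [(expert, weights[expert]) for expert in ranked[:k]]


def moe_layer(tokens, router_logits, experts, k=2):
    """Apply only the selected experts to each token and mix their outputs."""
    outputs = []
    for token, logits in zip(tokens, router_logits):
        mixed = [0.0] * len(token)
        for expert, weight in route(logits, k):
            expert_out = experts[expert](token)  # only k of E experts run per token
            mixed = [m + weight * y for m, y in zip(mixed, expert_out)]
        outputs.append(mixed)
    return outputs


# Toy experts: identity, doubling, negation, zeroing (assumed for illustration).
experts = [
    lambda t: list(t),
    lambda t: [2.0 * x for x in t],
    lambda t: [-x for x in t],
    lambda t: [0.0 for _ in t],
]
print([expert for expert, _ in route([0.0, 2.0, 1.0, -1.0], k=2)])  # [1, 2]
```

Because each token runs through at most K experts, adding more experts grows the parameter count while the per-token compute stays roughly constant.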

V-MoE Transformer Encoder block.

Experimental Results
We first pre-train the model once on JFT-300M, a large dataset of images. The left plot below shows our pre-training results for models of all sizes: from the small S/32 to the huge H/14.

We then transfer the model to new downstream tasks (such as ImageNet), by using a new head (the last layer in a model). We explore two transfer setups: either fine-tuning the entire model on all available examples of the new task, or freezing the pre-trained network and tuning only the new head using a few examples (known as few-shot transfer). The right plot in the figure below summarizes our transfer results to ImageNet, training on only 5 images per class (called 5-shot transfer).

JFT-300M Precision@1 and ImageNet 5-shot accuracy. Colors represent different ViT variants and markers represent either standard ViT (●), or V-MoEs (▸) with expert layers on the last n even blocks. We set n=2 for all models, except V-MoE-H where n=5. Higher indicates better performance, with more efficient models being to the left.

In both cases, the sparse model strongly outperforms its dense counterpart at a given amount of training compute (shown by the V-MoE line being above the ViT line), or achieves similar performance much faster (shown by the V-MoE line being to the left of the ViT line).

To explore the limits of vision models, we trained a 15-billion parameter model with 24 MoE layers (out of 48 blocks) on an extended version of JFT-300M. This massive model — the largest to date in vision as far as we know — achieved 90.35% test accuracy on ImageNet after fine-tuning, near the current state-of-the-art.

Priority Routing
In practice, due to hardware constraints, it is not efficient to use buffers with a dynamic size, so models typically use a pre-defined buffer capacity for each expert. Assigned tokens beyond this capacity are dropped and not processed once the expert becomes "full". As a consequence, higher capacities yield higher accuracy, but they are also more computationally expensive.

We leverage this implementation constraint to make V-MoEs faster at inference time. By decreasing the total combined buffer capacity below the number of tokens to be processed, the network is forced to skip processing some tokens in the expert layers. Instead of choosing the tokens to skip in some arbitrary fashion (as previous works did), the model learns to sort tokens according to an importance score. This maintains high quality predictions while saving a lot of compute. We refer to this approach as Batch Priority Routing (BPR), illustrated below.

Under high capacity, both vanilla and priority routing work well, as all patches are processed. However, when the buffer size is reduced to save compute, vanilla routing selects arbitrary patches to process, often leading to poor predictions. BPR prioritizes important patches, resulting in better predictions at lower computational cost.
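The difference between the two strategies can be illustrated with a small plain-Python sketch. It assumes each token has already been assigned to one expert with a routing weight, and it uses that weight as the importance score, which is a simplification of the learned score described above. The function name `fill_buffers` and the toy tokens are hypothetical.

```python
def fill_buffers(tokens, capacity, prioritize):
    """Return the ids of tokens that fit into fixed-size expert buffers.

    tokens: list of (token_id, expert, score) tuples in their original order.
    capacity: maximum number of tokens each expert may process.
    prioritize: if True, fill buffers in order of decreasing score (BPR-style);
                if False, fill them in the original token order (vanilla).
    """
    if prioritize:
        order = sorted(tokens, key=lambda t: t[2], reverse=True)
    else:
        order = tokens
    load = {}  # tokens accepted so far, per expert
    kept = []
    for token_id, expert, _score in order:
        if load.get(expert, 0) < capacity:
            load[expert] = load.get(expert, 0) + 1
            kept.append(token_id)
        # Otherwise the expert is full and the token is skipped in this layer.
    return sorted(kept)


tokens = [(0, 0, 0.2), (1, 0, 0.9), (2, 1, 0.5), (3, 0, 0.8)]
print(fill_buffers(tokens, capacity=1, prioritize=False))  # [0, 2]
print(fill_buffers(tokens, capacity=1, prioritize=True))   # [1, 2]
```

With a capacity of one token per expert, vanilla routing keeps token 0 simply because it arrives first, whereas the prioritized version keeps token 1, the highest-scoring token for that expert.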

Dropping the right tokens turns out to be essential to deliver high-quality and more efficient inference predictions. When the expert capacity decreases, performance quickly decreases with the vanilla routing mechanism. Conversely, BPR is much more robust to low capacities.

Performance versus inference capacity buffer size (or ratio) C for a V-MoE-H/14 model with K=2. Even for large C’s, BPR improves performance; at low C the difference is quite significant. BPR is competitive with dense models (ViT-H/14) by processing only 15-30% of the tokens.

Overall, we observed that V-MoEs are highly flexible at inference time: for instance, one can decrease the number of selected experts per token to save time and compute, without any further training on the model weights.

Exploring V-MoEs
Because much is yet to be discovered about the internal workings of sparse networks, we also explored the routing patterns of the V-MoE.

One hypothesis is that routers would learn to discriminate and assign tokens to experts based on some semantic grounds (the “car” expert, the “animal” experts, and so on). To test this, below we show plots for two different MoE layers (a very early-on one, and another closer to the head). The x-axis corresponds to each of the 32 experts, and the y-axis shows the ID of the image classes (from 1 to 1000). Each entry in the plot shows how often an expert was selected for tokens corresponding to the specific image class, with darker colors indicating higher frequency. While in the early layers there is little correlation, later in the network, each expert receives and processes tokens from only a handful of classes. Therefore, we can conclude that some semantic clustering of the patches emerges in the deeper layers of the network.

Higher routing decisions correlate with image classes. We show two MoE layers of a V-MoE-H/14. The x-axis corresponds to the 32 experts in a layer. The y-axis shows the 1000 ImageNet classes; orderings for both axes are different across plots (to highlight correlations). For each pair (expert e, class c) we show the average routing weight for the tokens corresponding to all images with class c for that particular expert e.

Final Thoughts
We train very large vision models using conditional computation, delivering significant improvements in representation and transfer learning for relatively little training cost. Alongside V-MoE, we introduced BPR, which requires the model to process only the most useful tokens in the expert layers.

We believe this is just the beginning of conditional computation at scale for computer vision; extensions include multi-modal and multi-task models, scaling up the expert count, and improving transfer of the representations produced by sparse models. Heterogeneous expert architectures and conditional variable-length routes are also promising directions. Sparse models can especially help in data rich domains such as large-scale video modeling. We hope our open-source code and models help attract and engage researchers new to this field.

Acknowledgments
We thank our co-authors: Basil Mustafa, Maxim Neumann, Rodolphe Jenatton, André Susano Pinto, Daniel Keysers, and Neil Houlsby. We thank Alex Kolesnikov, Lucas Beyer, and Xiaohua Zhai for providing continuous help and details about scaling ViT models. We are also grateful to Josip Djolonga, Ilya Tolstikhin, Liam Fedus, and Barret Zoph for feedback on the paper; James Bradbury, Roy Frostig, Blake Hechtman, Dmitry Lepikhin, Anselm Levskaya, and Parker Schuh for invaluable support helping us run our JAX models efficiently on TPUs; and many others from the Brain team for their support. Finally, we would also like to thank and acknowledge Tom Small for the awesome animated figure used in this post.

Source: Google AI Blog


Prediction Framework, a time saver for Data Science prediction projects

Posted by Álvaro Lamas, Héctor Parra, Jaime Martínez, Julia Hernández, Miguel Fernandes, Pablo Gil

Acquiring high-value customers using predicted lifetime value, taking specific actions on users with a high propensity to churn, generating and activating audiences based on machine-learning-processed signals: all of these marketing scenarios require analyzing first-party data, running predictions on that data, and activating the results in marketing platforms like Google Ads as frequently as possible to keep the data fresh.

Feeding marketing platforms like Google Ads on a regular and frequent basis requires a robust, report-oriented, and low-cost ETL and prediction pipeline. These pipelines are very similar regardless of the use case, and it is very easy to end up reinventing the wheel every time, or manually copying and pasting structural code, which increases the risk of introducing errors.

Wouldn't it be great to have a common reusable structure and just add the specific code for each of the stages?

Here is where Prediction Framework plays a key role in helping you implement and accelerate your first-party data prediction projects by providing the backbone elements of the predictive process.

Prediction Framework is a fully customizable pipeline that simplifies the implementation of prediction projects. You only need the input data source, the logic to extract and process the data, and a Vertex AutoML model ready to use along with the right feature list; the framework takes charge of creating and deploying the required artifacts. With a simple configuration, all the common artifacts for the different stages of this type of project are created and deployed for you: data extraction, data preparation (a.k.a. feature engineering), filtering, prediction, and post-processing, plus operational functionality including backfilling, throttling (for API limits), synchronization, storage, and reporting.

The Prediction Framework was built to be hosted on Google Cloud Platform. It uses Cloud Functions for all the data processing (extraction, preparation, filtering, and post-prediction processing); Firestore, Pub/Sub, and Schedulers for the throttling system and to coordinate the different phases of the predictive process; Vertex AutoML to host your machine learning model; and BigQuery as the final storage for your predictions.

Prediction Framework Architecture

To start using the Prediction Framework, prepare a configuration file with environment variables describing the Google Cloud project to be used, the data sources, the ML model that makes the predictions, and the scheduler for the throttling system. In addition, add custom queries for data extraction, preparation, filtering, and post-processing to the deploy customization files. The deployment is then done automatically by a deployment script provided with the tool.

Once deployed, all the stages will be executed one after the other, storing the intermediate and final data in the BigQuery tables:

  • Extract: on a regular schedule, this step queries the data source for the transactions that correspond to the run date (scheduler or backfill run date) and stores them in a new table in the local project's BigQuery.
  • Prepare: as soon as the extract of the transactions for a specific date is available, the data is picked up from the local BigQuery and processed according to the specs of the model. Once processed, it is stored in a new table in the local project's BigQuery.
  • Filter: this step queries the data stored by the prepare process, filters the required data, and stores it in the local project's BigQuery (e.g., taking into consideration only new customers' transactions; what counts as a new customer is defined when the framework is instantiated for the specific use case).
  • Predict: once the new customers are stored, this step reads them from BigQuery and requests predictions through the Vertex API. A formula based on the result of the prediction could be applied to tune the value or to apply thresholds. Once the data is ready, it is stored in BigQuery within the target project.
  • Post_process: a formula could be applied to the AutoML batch results to tune the value or to apply thresholds. Once the data is ready, it is stored in BigQuery within the target project.
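
The stage ordering above can be sketched as a chain of functions, each consuming the previous stage's output. This is only an illustration under loud assumptions: the function names, the in-memory rows, the `is_new_customer` field, the stand-in scoring rule in `predict`, and the threshold value are all hypothetical. The framework itself runs each stage as a Cloud Function against BigQuery tables and calls a Vertex AutoML model.

```python
from typing import Any, Dict, List

Row = Dict[str, Any]


def extract(run_date: str, source: List[Row]) -> List[Row]:
    """Keep only the transactions that belong to the run date."""
    return [row for row in source if row["date"] == run_date]


def prepare(rows: List[Row]) -> List[Row]:
    """Hypothetical feature engineering: derive one numeric feature."""
    return [{**row, "amount_feature": float(row["amount"])} for row in rows]


def filter_rows(rows: List[Row]) -> List[Row]:
    """Hypothetical filter: keep new customers only."""
    return [row for row in rows if row["is_new_customer"]]


def predict(rows: List[Row]) -> List[Row]:
    """Stand-in for the model call (assumption: a fixed scoring rule)."""
    return [{**row, "prediction": 2.0 * row["amount_feature"]} for row in rows]


def post_process(rows: List[Row], threshold: float = 50.0) -> List[Row]:
    """Apply a threshold to the raw prediction."""
    return [{**row, "high_value": row["prediction"] >= threshold} for row in rows]


def run_pipeline(run_date: str, source: List[Row]) -> List[Row]:
    """Run every stage in order for a single run date."""
    rows = extract(run_date, source)
    for stage in (prepare, filter_rows, predict, post_process):
        rows = stage(rows)
    return rows


transactions = [
    {"date": "2021-12-01", "customer": "a", "amount": 40, "is_new_customer": True},
    {"date": "2021-12-01", "customer": "b", "amount": 10, "is_new_customer": True},
    {"date": "2021-12-01", "customer": "c", "amount": 99, "is_new_customer": False},
    {"date": "2021-11-30", "customer": "d", "amount": 70, "is_new_customer": True},
]
results = run_pipeline("2021-12-01", transactions)
```

In this toy version, a backfill amounts to calling `run_pipeline` once per date in the period to reprocess.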

One of the powerful features of the Prediction Framework is that it allows backfilling directly from the BigQuery user interface, so if you need to reprocess a whole period of time, it can be done in literally 4 clicks.

In summary: Prediction Framework simplifies the implementation of first-party data prediction projects, saving time and minimizing errors of manual deployments of recurrent architectures.

For additional information and to start experimenting, you can visit the Prediction Framework repository on GitHub.

Training Machine Learning Models More Efficiently with Dataset Distillation

For a machine learning (ML) algorithm to be effective, useful features must be extracted from (often) large amounts of training data. However, this process can be made challenging due to the costs associated with training on such large datasets, both in terms of compute requirements and wall clock time. The idea of distillation plays an important role in these situations by reducing the resources required for the model to be effective. The most widely known form of distillation is model distillation (a.k.a. knowledge distillation), where the predictions of large, complex teacher models are distilled into smaller models.

An alternative option to this model-space approach is dataset distillation [1, 2], in which a large dataset is distilled into a synthetic, smaller dataset. Training a model with such a distilled dataset can reduce the required memory and compute. For example, instead of using all 50,000 images and labels of the CIFAR-10 dataset, one could use a distilled dataset consisting of only 10 synthesized data points (1 image per class) to train an ML model that can still achieve good performance on the unseen test set.

Top: Natural (i.e., unmodified) CIFAR-10 images. Bottom: Distilled dataset (1 image per class) on CIFAR-10 classification task. Using only these 10 synthetic images as training data, a model can achieve test set accuracy of ~51%.

In “Dataset Meta-Learning from Kernel Ridge Regression”, published at ICLR 2021, and “Dataset Distillation with Infinitely Wide Convolutional Networks”, presented at NeurIPS 2021, we introduce two novel dataset distillation algorithms, Kernel Inducing Points (KIP) and Label Solve (LS), which optimize datasets using the loss function arising from kernel regression (a classical machine learning algorithm that fits a linear model to features defined through a kernel). Applying the KIP and LS algorithms, we obtain very efficient distilled datasets for image classification, reducing the datasets to 1, 10, or 50 data points per class while still obtaining state-of-the-art results on a number of benchmark image classification datasets. We are also excited to release our distilled datasets to benefit the wider research community.

Methodology
One of the key theoretical insights into deep neural networks (DNNs) in recent years has been that increasing the width of DNNs results in more regular behavior that makes them easier to understand. As the width is taken to infinity, DNNs trained by gradient descent converge to the familiar and simpler class of models arising from kernel regression with respect to the neural tangent kernel (NTK), a kernel that measures input similarity by computing dot products of gradients of the neural network. Thanks to the Neural Tangents library, neural kernels for various DNN architectures can be computed in a scalable manner.

We utilized the above infinite-width limit theory of neural networks to tackle dataset distillation. Dataset distillation can be formulated as a two-stage optimization process: an “inner loop” that trains a model on learned data, and an “outer loop” that optimizes the learned data for performance on natural (i.e., unmodified) data. The infinite-width limit replaces the inner loop of training a finite-width neural network with a simple kernel regression. With the addition of a regularizing term, the kernel regression becomes a kernel ridge-regression (KRR) problem. This is a highly valuable outcome because the kernel ridge regressor (i.e., the predictor from the algorithm) has an explicit formula in terms of its training data (unlike a neural network predictor), which means that one can easily optimize the KRR loss function during the outer loop.

The original data labels can be represented by one-hot vectors, i.e., the true label is given a value of 1 and all other labels are given values of 0. Thus, an image of a cat would have the label “cat” assigned a 1 value, while the labels for “dog” and “horse” would be 0. The labels we use involve a subsequent mean-centering step, where we subtract the reciprocal of the number of classes from each component (so 0.1 for 10-way classification) so that the expected value of each label component across the dataset is normalized to zero.
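
The mean-centering step can be illustrated in a few lines. This is a minimal sketch, not the authors' code; the function name `centered_one_hot` is hypothetical, and the zero-mean check assumes a class-balanced dataset.

```python
from typing import List


def centered_one_hot(label: int, num_classes: int) -> List[float]:
    """One-hot vector with 1/num_classes subtracted from every component."""
    offset = 1.0 / num_classes
    return [(1.0 if k == label else 0.0) - offset for k in range(num_classes)]


# For 10-way classification the true class gets 0.9 and every other class -0.1.
label_vec = centered_one_hot(3, 10)

# On a class-balanced toy dataset (5 examples per class), each label
# component averages to zero, up to floating-point rounding.
labels = [k for k in range(10) for _ in range(5)]
component_means = [
    sum(centered_one_hot(lab, 10)[c] for lab in labels) / len(labels)
    for c in range(10)
]
```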

While the labels for natural images appear in this standard form, the labels for our learned distilled datasets are free to be optimized for performance. Having obtained the kernel ridge regressor from the inner loop, the KRR loss function in the outer loop computes the mean-square error between the original labels of natural images and the labels predicted by the kernel ridge regressor. KIP optimizes the support data (images and possibly labels) by minimizing the KRR loss function through gradient-based methods. The Label Solve algorithm directly solves for the set of support labels that minimizes the KRR loss function, generating a unique dense label vector for each (natural) support image.
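
To make the closed-form predictor concrete, the sketch below computes a KRR loss for a tiny support set in plain Python. It is an illustration under stated assumptions, not the KIP implementation: an RBF kernel stands in for the neural kernels used in the papers, the helper names (`rbf_kernel`, `solve`, `krr_predict`, `krr_loss`) and the regularization value are hypothetical, and the loss is written as a mean squared error. The predictor is the standard KRR formula K_ts (K_ss + reg·I)^-1 y_s, where s indexes support points and t indexes target points.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def rbf_kernel(a: Vector, b: Vector, gamma: float = 0.5) -> float:
    """Stand-in kernel (assumption); the papers use neural kernels instead."""
    return math.exp(-gamma * sum((x - y) ** 2 for x, y in zip(a, b)))


def solve(a: Matrix, b: Matrix) -> Matrix:
    """Solve a @ x = b by Gauss-Jordan elimination with partial pivoting."""
    n = len(a)
    aug = [list(a[i]) + list(b[i]) for i in range(n)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        scale = aug[col][col]
        aug[col] = [v / scale for v in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [v - factor * p for v, p in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def krr_predict(x_support: Matrix, y_support: Matrix, x_target: Matrix,
                reg: float = 1e-6) -> Matrix:
    """Closed-form KRR prediction: K_ts (K_ss + reg * I)^-1 y_s."""
    k_ss = [[rbf_kernel(p, q) + (reg if i == j else 0.0)
             for j, q in enumerate(x_support)]
            for i, p in enumerate(x_support)]
    alpha = solve(k_ss, y_support)  # one row per support point
    num_outputs = len(y_support[0])
    preds = []
    for t in x_target:
        k_t = [rbf_kernel(t, s) for s in x_support]
        preds.append([sum(k * alpha[i][c] for i, k in enumerate(k_t))
                      for c in range(num_outputs)])
    return preds


def krr_loss(x_support: Matrix, y_support: Matrix, x_target: Matrix,
             y_target: Matrix, reg: float = 1e-6) -> float:
    """Mean squared error between target labels and KRR predictions."""
    preds = krr_predict(x_support, y_support, x_target, reg)
    errors = [(p - y) ** 2
              for pred, target in zip(preds, y_target)
              for p, y in zip(pred, target)]
    return sum(errors) / len(errors)


# Toy data: three support points and two target points, two classes,
# with mean-centered labels.
x_support = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
y_support = [[0.5, -0.5], [-0.5, 0.5], [0.5, -0.5]]
x_target = [[0.1, 0.1], [0.9, 0.1]]
y_target = [[0.5, -0.5], [-0.5, 0.5]]
loss = krr_loss(x_support, y_support, x_target, y_target)
```

Because `krr_loss` is an explicit function of the support data, a gradient-based method can differentiate it with respect to `x_support` (and optionally `y_support`), which is the role the outer loop plays in the description above.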

Example of labels obtained by label solving. Left and Middle: Sample images with possible labels listed below. The raw, one-hot label is shown in blue and the final LS generated dense label is shown in orange. Right: The covariance matrix between original labels and learned labels. Here, 500 labels were distilled from the CIFAR-10 dataset. A test accuracy of 69.7% is achieved using these labels for kernel ridge-regression.

Distributed Computation
For simplicity, we focus on architectures that consist of convolutional neural networks with pooling layers. Specifically, we focus on the so-called “ConvNet” architecture and its variants because it has been featured in other dataset distillation studies. We used a slightly modified version of ConvNet that has a simple architecture given by three blocks of convolution, ReLU, and 2x2 average pooling and then a final linear readout layer, with an additional 3x3 convolution and ReLU layer prepended (see our GitHub for precise details).

ConvNet architecture used in DC/DSA. Ours has an additional 3x3 Conv and ReLU prepended.

To compute the neural kernels needed in our work, we used the Neural Tangents library.

The first stage of this work, in which we applied KRR, focused on fully-connected networks, whose kernel elements are cheap to compute. But a hurdle facing neural kernels for models with convolutional layers plus pooling is that the computation of each kernel element between two images scales as the square of the number of input pixels (due to the capturing of pixel-pixel correlations by the kernel). So, for the second stage of this work, we needed to distribute the computation of the kernel elements and their gradients across many devices.

Distributed computation for large scale metalearning.

We invoke a client-server model of distributed computation in which a server distributes independent workloads to a large pool of client workers. A key part of this is to divide the backpropagation step in a way that is computationally efficient (explained in detail in the paper).

We accomplish this using the open-source tools Courier (part of DeepMind’s Launchpad), which allows us to distribute computations across GPUs working in parallel, and JAX, for which novel usage of the jax.vjp function enables computationally efficient gradients. This distributed framework allows us to utilize hundreds of GPUs per distillation of the dataset, for both the KIP and LS algorithms. Given the compute required for such experiments, we are releasing our distilled datasets to benefit the wider research community.

Examples
Our first set of distilled images above used KIP to distill CIFAR-10 down to 1 image per class while keeping the labels fixed. Next, in the figure below, we compare the test accuracy of training on natural MNIST images, KIP-distilled images with labels fixed, and KIP-distilled images with labels optimized. We highlight that learning the labels provides an effective, albeit mysterious, benefit to distilling datasets. Indeed, the resulting set of images provides the best test performance (for infinite-width networks) despite being less interpretable.

MNIST dataset distillation with trainable and non-trainable labels. Top: Natural MNIST data. Middle: Kernel Inducing Point distilled data with fixed labels. Bottom: Kernel Inducing Point distilled data with learned labels.

Results
Our distilled datasets achieve state-of-the-art performance on benchmark image classification datasets, improving on the previous state-of-the-art methods that used convolutional architectures: Dataset Condensation (DC) and Dataset Condensation with Differentiable Siamese Augmentation (DSA). In particular, for CIFAR-10 classification, a model trained on a dataset consisting of only 10 distilled data entries (1 image / class, 0.02% of the whole dataset) achieves 64% test set accuracy. Here, learning labels and an additional image preprocessing step lead to a significant increase in performance beyond the 50% test accuracy shown in our first figure (see our paper for details). With 500 images (50 images / class, 1% of the whole dataset), the model reaches 80% test set accuracy. While these numbers are with respect to neural kernels (using the KRR infinite-width limit), these distilled datasets can be used to train finite-width neural networks as well. In particular, on CIFAR-10, a finite-width ConvNet achieves 50% test accuracy with 10 images and 68% test accuracy with 500 images, which are still state-of-the-art results. We provide a simple Colab notebook demonstrating this transfer to a finite-width neural network.

Dataset distillation using Kernel Inducing Points (KIP) with a convolutional architecture outperforms prior state-of-the-art models (DC/DSA) in all benchmark settings on image classification tasks. Label Solve (LS, middle columns), while distilling information only into the labels, can often outperform prior state-of-the-art models as well (e.g., on CIFAR-10 with 10 or 50 data points per class).

In some cases, our learned datasets are more effective than a natural dataset one hundred times larger in size.

Conclusion
We believe that our work on dataset distillation opens up many interesting future directions. For instance, our algorithms KIP and LS have demonstrated the effectiveness of using learned labels, an area that remains relatively underexplored. Furthermore, we expect that utilizing efficient kernel approximation methods can help to reduce computational burden and scale up to larger datasets. We hope this work encourages researchers to explore other applications of dataset distillation, including neural architecture search and continual learning, and even potential applications to privacy.

Anyone interested in the KIP and LS learned datasets for further analysis is encouraged to check out our papers [ICLR 2021, NeurIPS 2021] and the open-sourced code and datasets available on GitHub.

Acknowledgement
This project was done in collaboration with Zhourong Chen, Roman Novak and Lechao Xiao. Special thanks to Samuel S. Schoenholz, who proposed and helped develop the overall strategy for our distributed KIP learning methodology.


¹ Now at DeepMind.

Source: Google AI Blog


More Efficient In-Context Learning with GLaM

Large language models (e.g., GPT-3) have many significant capabilities, such as performing few-shot learning across a wide array of tasks, including reading comprehension and question answering with very few or no training examples. While these models can perform better by simply using more parameters, training and serving these large models can be very computationally intensive. Is it possible to train and use these models more efficiently?

In pursuit of that question, today we introduce the Generalist Language Model (GLaM), a trillion-weight model that can be trained and served efficiently (in terms of computation and energy use) thanks to sparsity, and that achieves competitive performance on multiple few-shot learning tasks. GLaM’s performance compares favorably to a dense language model, GPT-3 (175B), with significantly improved learning efficiency across 29 public NLP benchmarks in seven categories, spanning language completion, open-domain question answering, and natural language inference tasks.

Dataset
To build GLaM, we began by building a high-quality 1.6 trillion token dataset containing language usage representative of a wide range of downstream use-cases for the model. Web pages constitute the vast quantity of data in this unlabelled corpus, but their quality ranges from professional writing to low-quality comment and forum pages. We then developed a text quality filter that was trained on a collection of text from Wikipedia and books (both of which are generally higher quality sources) to determine the quality of the content for a webpage. Finally, we applied this filter to generate the final subset of webpages and combined this with books and Wikipedia to create the final training dataset.

Model and Architecture
GLaM is a mixture of experts (MoE) model, a type of model that can be thought of as having different submodels (or experts) that are each specialized for different inputs. The experts in each layer are controlled by a gating network that activates experts based on the input data. For each token (generally a word or part of a word), the gating network selects the two most appropriate experts to process the data. The full version of GLaM has 1.2T total parameters across 64 experts per MoE layer with 32 MoE layers in total, but only activates a subnetwork of 97B (8% of 1.2T) parameters per token prediction during inference.

The architecture of GLaM where each input token is dynamically routed to two selected expert networks out of 64 for prediction.

Similar to the GShard MoE Transformer, we replace the single feedforward network (the simplest layer of an artificial neural network, “Feedforward or FFN” in the blue boxes) of every other transformer layer with a MoE layer. This MoE layer has multiple experts, each a feedforward network with identical architecture but different weight parameters. Even though this MoE layer has many more parameters, the experts are sparsely activated, meaning that for a given input token, only two experts are used, giving the model more capacity while limiting computation. During training, each MoE layer's gating network is trained to use its input to activate the best two experts for each token, which are then used for inference. For a MoE layer of E experts, this essentially provides a collection of E×(E-1) different feedforward network combinations (instead of one as in the classic Transformer architecture), leading to more computational flexibility.
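
The routing described above can be sketched for a single token as follows. This is a toy illustration, not GLaM's implementation: the linear gate, the scaling "experts" that stand in for feedforward networks, the function names, and the choice to renormalize the two selected gate probabilities so they sum to one are all assumptions made for the example.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Expert = Callable[[Vector], Vector]


def softmax(logits: Sequence[float]) -> Vector:
    """Numerically stable softmax."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def make_expert(scale: float) -> Expert:
    """Toy stand-in for an expert FFN: scales its input."""
    return lambda x: [scale * v for v in x]


def top2_moe(token: Vector, gate_weights: List[Vector],
             experts: List[Expert]) -> Tuple[Vector, List[int]]:
    """Route one token to its two highest-scoring experts and mix the outputs."""
    logits = [sum(w * x for w, x in zip(row, token)) for row in gate_weights]
    probs = softmax(logits)
    top2 = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:2]
    norm = probs[top2[0]] + probs[top2[1]]  # assumption: renormalize over top 2
    output = [0.0] * len(token)
    for e in top2:
        weight = probs[e] / norm
        expert_out = experts[e](token)  # only the two chosen experts run
        output = [o + weight * v for o, v in zip(output, expert_out)]
    return output, top2


experts = [make_expert(s) for s in (1.0, 2.0, 3.0, 4.0)]
gate_weights = [[2.0, 0.0], [1.0, 0.0], [0.0, 0.0], [-1.0, 0.0]]
output, chosen = top2_moe([1.0, 0.0], gate_weights, experts)
```

Only the two selected experts are evaluated, which is why per-token computation stays bounded as the number of experts grows. For E = 64, the E×(E-1) count from the text evaluates to 4,032.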

The final learned representation of a token will be the weighted combination of the outputs from the two experts. This allows different experts to activate on different types of inputs. To enable scaling to larger models, each expert within the GLaM architecture can span multiple computational devices. We use the GSPMD compiler backend to solve the challenges in scaling the experts and train several variants (based on expert size and number of experts) of this architecture to understand the scaling effects of sparsely activated language models.

Evaluation
We use a zero-shot and one-shot setting where the tasks are never seen during training. The benchmarks for evaluation include (1) cloze and completion tasks [1,2,3]; (2) Open-domain question answering [4,5,6]; (3) Winograd-style tasks [7,8]; (4) commonsense reasoning [9,10,11]; (5) in-context reading comprehension [12,13,14,15,16]; (6) the SuperGLUE tasks; and (7) natural language inference [17]. In total, there are eight natural language generation tasks (NLG) where the generated phrases are evaluated against the ground truth targets via Exact Match (EM) accuracy and F1 measure, and 21 language understanding tasks (NLU) where the prediction from several options is chosen via conditional log-likelihood. Some tasks have variants and SuperGLUE consists of multiple tasks. Both EM accuracy and F1 are scaled from 0 to 100 across all our results and averaged for the NLG score below. The NLU score is an average of accuracy and F1 scores.
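
For readers unfamiliar with the two generation metrics, here is a minimal sketch of Exact Match and a token-overlap F1, both scaled from 0 to 100. The post does not spell out its text normalization or F1 variant, so the lowercasing, the whitespace tokenization, and the token-level F1 definition used here are assumptions.

```python
from collections import Counter
from typing import List


def normalize(text: str) -> List[str]:
    """Lowercase and split on whitespace (a deliberately simple normalization)."""
    return text.lower().split()


def exact_match(prediction: str, target: str) -> float:
    """100 if the normalized strings match exactly, else 0."""
    return 100.0 if normalize(prediction) == normalize(target) else 0.0


def f1_score(prediction: str, target: str) -> float:
    """Token-overlap F1 between prediction and target, scaled to 0-100."""
    pred_tokens = normalize(prediction)
    target_tokens = normalize(target)
    overlap = sum((Counter(pred_tokens) & Counter(target_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(target_tokens)
    return 100.0 * 2 * precision * recall / (precision + recall)


em = exact_match("the Eiffel Tower", "Eiffel Tower")
f1 = f1_score("the Eiffel Tower", "Eiffel Tower")
```

In this example the prediction contains one extra token, so Exact Match gives no credit while F1 gives partial credit.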

Results
GLaM reduces to a basic dense Transformer-based language model architecture when each MoE layer only has one expert. In all experiments, we adopt the notation of (base dense model size) / (number of experts per MoE layer) to describe the GLaM model. For example, 1B/64E represents the architecture of a 1B parameter dense model with every other layer replaced by a 64 expert MoE layer. In the following sections, we explore GLaM’s performance and scaling properties, including baseline dense models trained on the same datasets. Compared with the recently announced Megatron-Turing model, GLaM is on-par on the seven respective tasks if using a 5% margin, while using 5x less computation during inference.

Below, we show the 1.2T-parameter sparsely activated model (GLaM) achieved higher results on average and on more tasks than the 175B-parameter dense GPT-3 model while using less computation during inference.

Average score for GLaM and GPT-3 on NLG (left) and NLU (right) tasks (higher is better).

Below we show a summary of the performance on 29 benchmarks compared to the dense model (GPT-3, 175B). GLaM exceeds or is on-par with the performance of the dense model on almost 80% of zero-shot tasks and almost 90% of one-shot tasks.

Evaluation   Higher (>+5%)   On-par (within 5%)   Lower (<-5%)
Zero-shot    13              11                   5
One-shot     14              10                   5

Moreover, while the full version of GLaM has 1.2T total parameters, it only activates a subnetwork of 97B parameters (8% of 1.2T) per token during inference.

Parameter count        GLaM (64B/64E)   GPT-3 (175B)
Total Parameters       1.162T           0.175T
Activated Parameters   0.097T           0.175T

Scaling Behavior
GLaM has two ways to scale: 1) scale the number of experts per layer, where each expert is hosted within one computation device, or 2) scale the size of each expert to go beyond the limit of a single device. To evaluate the scaling properties, we compare each GLaM model with the respective dense model (FFN layers instead of MoE layers) of similar FLOPs per token at inference time.

Average zero-shot and one-shot performance by increasing the size of each expert. The FLOPs per token prediction at inference time increase as the expert size grows.

As shown above, performance across tasks scales with the size of the experts. GLaM sparsely activated models also perform better than dense models for similar FLOPs during inference for generation tasks. For understanding tasks, we observed that they perform similarly at smaller scales, but sparsely activated models outperform at larger scales.

Data Efficiency
Training large language models is computationally intensive, so efficiency improvements are useful to reduce energy consumption.

Below we show the computation costs for the full version of GLaM.

Computation cost in GFLOPs, both for inference per token (left) and for training (right).

These compute costs show that GLaM uses more computation during training since it trains on more tokens, but uses significantly less computation during inference. We show comparisons using different numbers of tokens to train below.

We also evaluated the learning curves of our models compared to the dense baseline.

Average zero-shot and one-shot performance of sparsely-activated and dense models on eight generative tasks as more tokens are processed in training.
Average zero-shot and one-shot performance of sparsely-activated and dense models on 21 understanding tasks as more tokens are processed in training.

The results above show that sparsely activated models need to train with significantly less data than dense models to reach similar zero-shot and one-shot performance, and if the same amount of data is used, sparsely activated models perform significantly better.

Finally, we assessed the energy efficiency of GLaM.

Comparison of power consumption during training.

While GLaM uses more computation during training, thanks to the more efficient software implementation powered by GSPMD and the advantage of TPUv4, it uses less power to train than other models.

Conclusions
Our large-scale sparsely activated language model, GLaM, achieves competitive results on zero-shot and one-shot learning and is a more efficient model than prior monolithic dense counterparts. We also show quantitatively that a high-quality dataset is essential for large language models. We hope that our work will spark more research into compute-efficient language models.

Acknowledgements
We wish to thank Claire Cui, Zhifeng Chen, Yonghui Wu, Quoc Le, Macduff Hughes, Fernando Pereira, Zoubin Ghahramani and Jeff Dean for their support and invaluable input. Special thanks to our collaborators: Yanping Huang, Simon Tong, Yanqi Zhou, Yuanzhong Xu, Dmitry Lepikhin, Orhan Firat, Maxim Krikun, Tao Wang, Noam Shazeer, Barret Zoph, Liam Fedus, Maarten Bosma, Kun Zhang, Emma Wang, David Patterson, Zongwei Zhou, Naveen Kumar, Adams Yu, Laurent Shafey, Jonathan Shen, Ben Lee, Anmol Gulati, David So, Marie Pellat, Kellie Webster, Kevin Robinson, Kathy Meier-Hellstern, Toju Duke, Lucas Dixon, Aakanksha Chowdhery, Sharan Narang, Erica Moreira and Eric Ni for helpful discussions and inspirations; and the larger Google Research team. We would also like to thank Tom Small for the animated figure used in this post.

Source: Google AI Blog


General and Scalable Parallelization for Neural Networks

Scaling neural networks, whether it be the amount of training data used, the model size or the computation being utilized, has been critical for improving model quality in many real-world machine learning applications, such as computer vision, language understanding and neural machine translation. This, in turn, has motivated recent studies to scrutinize the factors that play a critical role in the success of scaling a neural model. Although increasing model capacity can be a sound approach to improve model quality, doing so presents a number of systems and software engineering challenges that must be overcome. For instance, in order to train large models that exceed the memory capacity of an accelerator, it becomes necessary to partition the weights and the computation of the model across multiple accelerators. This process of parallelization increases the network communication overhead and can result in device under-utilization. Moreover, a given algorithm for parallelization, which typically requires a significant amount of engineering effort, may not work with different model architectures.

To address these scaling challenges, we present “GSPMD: General and Scalable Parallelization for ML Computation Graphs”, in which we describe an open-source automatic parallelization system based on the XLA compiler. GSPMD is capable of scaling most deep learning network architectures and has already been applied to many deep learning models, such as GShard-M4, LaMDA, BigSSL, ViT, and MetNet-2, leading to state-of-the-art results across several domains. GSPMD has also been integrated into multiple ML frameworks, including TensorFlow and JAX, which use XLA as a shared compiler.

Overview
GSPMD separates the task of programming an ML model from the challenge of parallelization. It allows model developers to write programs as if they were run on a single device with very high memory and computation capacity — the user simply needs to add a few lines of annotation code to a subset of critical tensors in the model code to indicate how to partition the tensors. For example, to train a large model-parallel Transformer, one may only need to annotate fewer than 10 tensors (less than 1% of all tensors in the entire computation graph), one line of additional code per tensor. Then GSPMD runs a compiler pass that determines the entire graph's parallelization plan, and transforms it into a mathematically equivalent, parallelized computation that can be executed on each device. This allows users to focus on model building instead of parallelization implementation, and enables easy porting of existing single-device programs to run at a much larger scale.

The separation of model programming and parallelism also allows developers to minimize code duplication. With GSPMD, developers may employ different parallelism algorithms for different use cases without the need to reimplement the model. For example, the model code that powered the GShard-M4 and LaMDA models can apply a variety of parallelization strategies appropriate for different models and cluster sizes with the same model implementation. Similarly, by applying GSPMD, the BigSSL large speech models can share the same implementation with previous smaller models.

Generality and Flexibility
Because different model architectures may be better suited to different parallelization strategies, GSPMD is designed to support a large variety of parallelism algorithms appropriate for different use cases. For example, with smaller models that fit within the memory of a single accelerator, data parallelism is preferred, in which devices train the same model using different input data. In contrast, models that are larger than a single accelerator’s memory capacity are better suited for a pipelining algorithm (like that employed by GPipe) that partitions the model into multiple, sequential stages, or operator-level parallelism (e.g., Mesh-TensorFlow), in which individual computation operators in the model are split into smaller, parallel operators.

GSPMD supports all the above parallelization algorithms with a uniform abstraction and implementation. Moreover, GSPMD supports nested patterns of parallelism. For example, it can be used to partition models into individual pipeline stages, each of which can be further partitioned using operator-level parallelism.

GSPMD also facilitates innovation on parallelism algorithms by allowing performance experts to focus on algorithms that best utilize the hardware, instead of the implementation that involves lots of cross-device communications. For example, for large Transformer models, we found a novel operator-level parallelism algorithm that partitions multiple dimensions of tensors on a 2D mesh of devices. It reduces peak accelerator memory usage linearly with the number of training devices, while maintaining a high utilization of accelerator compute due to its balanced data distribution over multiple dimensions.

To illustrate this, consider a simplified feedforward layer in a Transformer model that has been annotated in the above way. To execute the first matrix multiply on fully partitioned input data, GSPMD applies an MPI-style AllGather communication operator to partially merge with partitioned data from another device. It then executes the matrix multiply locally and produces a partitioned result. Before the second matrix multiply, GSPMD adds another AllGather on the right-hand side input, and executes the matrix multiply locally, yielding intermediate results that will then need to be combined and partitioned. For this, GSPMD adds an MPI-style ReduceScatter communication operator that accumulates and partitions these intermediate results. While the tensors generated with the AllGather operator at each stage are larger than the original partition size, they are short-lived and the corresponding memory buffers will be freed after use, which does not affect peak memory usage in training.
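The communication pattern described above can be mimicked on a single machine. The sketch below is only a simulation with plain Python lists, not GSPMD or XLA code: the 2x2 mesh, the tensor sizes, the helper names, and the exact shard layouts (including gathering the first weight matrix along X) are assumptions made for illustration, and the activation function is omitted. It checks that AllGather, local matrix multiplies, and a ReduceScatter reproduce the unpartitioned result.

```python
# Simulate a 2x2 device mesh with plain Python lists; no real devices are used.
X, Y = 2, 2        # mesh dimensions
B, M, H = 4, 4, 6  # batch, model and hidden sizes (illustrative values)

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def block(t, r, nr, c, nc):
    """Block (r, c) of matrix t when it is split into an nr x nc grid."""
    rh, cw = len(t) // nr, len(t[0]) // nc
    return [row[c * cw:(c + 1) * cw] for row in t[r * rh:(r + 1) * rh]]

def hcat(blocks):  # concatenate blocks along columns
    return [sum((b[i] for b in blocks), []) for i in range(len(blocks[0]))]

def vcat(blocks):  # concatenate blocks along rows
    return [row for b in blocks for row in b]

def add(a, b):
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]

# Full (unpartitioned) tensors with integer entries so comparisons are exact.
inp = [[i + 2 * j + 1 for j in range(M)] for i in range(B)]
w1 = [[i - j for j in range(H)] for i in range(M)]
w2 = [[(i * j) % 3 - 1 for j in range(M)] for i in range(H)]
reference = matmul(matmul(inp, w1), w2)

# Assumed layouts: device (x, y) holds one block of each tensor.
inp_shard = [[block(inp, x, X, y, Y) for y in range(Y)] for x in range(X)]  # [B/X, M/Y]
w1_shard = [[block(w1, x, X, y, Y) for y in range(Y)] for x in range(X)]    # [M/X, H/Y]
w2_shard = [[block(w2, y, Y, x, X) for y in range(Y)] for x in range(X)]    # [H/Y, M/X]

out_shard = [[None] * Y for _ in range(X)]
for x in range(X):
    # AllGather the input along Y: every device in row x gets [B/X, M].
    gathered_in = hcat([inp_shard[x][y] for y in range(Y)])
    partials = []
    for y in range(Y):
        # AllGather the weights along X, then multiply locally.
        gathered_w1 = vcat([w1_shard[xx][y] for xx in range(X)])  # [M, H/Y]
        hidden = matmul(gathered_in, gathered_w1)                 # [B/X, H/Y]
        gathered_w2 = hcat([w2_shard[xx][y] for xx in range(X)])  # [H/Y, M]
        partials.append(matmul(hidden, gathered_w2))              # partial sums, [B/X, M]
    # ReduceScatter along Y: sum the partial results, keep one column block each.
    total = partials[0]
    for p in partials[1:]:
        total = add(total, p)
    for y in range(Y):
        out_shard[x][y] = block(total, 0, 1, y, Y)                # [B/X, M/Y]

reassembled = vcat([hcat(out_shard[x]) for x in range(X)])
print(reassembled == reference)
```

Note that the gathered tensors exist only inside the loop body, which mirrors the point that they are short-lived, while the tensors each device keeps (`inp_shard`, `w1_shard`, `w2_shard`, `out_shard`) stay at partition size.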

Left: A simplified feedforward layer of a Transformer model. Blue rectangles represent tensors with dashed red & blue lines overlaid representing the desired partitioning across a 2x2 mesh of devices. Right: A single partition, after GSPMD has been applied.

A Transformer Example with Nested Parallelism
As a shared, robust mechanism for different parallelism modes, GSPMD allows users to conveniently switch between modes in different parts of a model. This is particularly valuable for models that may have different components with distinct performance characteristics, for example, multimodal models that handle both images and audio. Consider a model with the Transformer encoder-decoder architecture, which has an embedding layer, an encoder stack with Mixture-of-Expert layers, a decoder stack with dense feedforward layers, and a final softmax layer. In GSPMD, a complex combination of several parallelism modes that treats each layer separately can be achieved with simple configurations.

In the figure below, we show a partitioning strategy over 16 devices organized as a logical 4x4 mesh. Blue represents partitioning along the first mesh dimension X, and yellow represents partitioning along the second mesh dimension Y. X and Y are repurposed for different model components to achieve different parallelism modes. For example, the X dimension is used for data parallelism in the embedding and softmax layers, but used for pipeline parallelism in the encoder and decoder. The Y dimension is also used in different ways to partition the vocabulary, batch or model expert dimensions.

Computation Efficiency
GSPMD provides industry-leading performance in large model training. Parallel models require extra communication to coordinate multiple devices to do the computation. So parallel model efficiency can be estimated by examining the fraction of time spent on communication overhead — the higher percentage utilization and the less time spent on communication, the better. In the recent MLPerf set of performance benchmarks, a BERT-like encoder-only model with ~500 billion parameters to which we applied GSPMD for parallelization over 2048 TPU-V4 chips yielded highly competitive results (see table below), utilizing up to 63% of the peak FLOPS that the TPU-V4s offer. We also provide efficiency benchmarks for some representative large models in the table below. These example model configs are open sourced in the Lingvo framework along with instructions to run them on Google Cloud. More benchmark results can be found in the experiment section of our paper.

| Model Family | Parameter Count | % of Model Activated* | No. of Experts** | No. of Layers | No. of TPUs | FLOPS Utilization |
|---|---|---|---|---|---|---|
| Dense Decoder (LaMDA) | 137B | 100% | 1 | 64 | 1024 TPUv3 | 56.5% |
| Dense Encoder (MLPerf-Bert) | 480B | 100% | 1 | 64 | 2048 TPUv4 | 63% |
| Sparsely Activated Encoder-Decoder (GShard-M4) | 577B | 0.25% | 2048 | 32 | 1024 TPUv3 | 46.8% |
| Sparsely Activated Decoder | 1.2T | 8% | 64 | 64 | 1024 TPUv3 | 53.8% |
*The fraction of the model activated during inference, which is a measure of model sparsity.
**Number of experts included in the Mixture of Experts layer. A value of 1 corresponds to a standard Transformer, without a Mixture of Experts layer.
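As a rough reading of the sparsity column, the parameter count can be multiplied by the fraction of the model activated to estimate how many parameters take part in a single inference pass. The sketch below does only that arithmetic on the numbers in the table; the dictionary layout and variable names are our own, and the result is a back-of-the-envelope estimate rather than a reported figure.

```python
# (parameter count, fraction of model activated), copied from the table above.
models = {
    "Dense Decoder (LaMDA)": (137e9, 1.00),
    "Dense Encoder (MLPerf-Bert)": (480e9, 1.00),
    "Sparsely Activated Encoder-Decoder (GShard-M4)": (577e9, 0.0025),
    "Sparsely Activated Decoder": (1.2e12, 0.08),
}

# Estimated number of parameters active per inference pass.
activated = {name: params * frac for name, (params, frac) in models.items()}

for name, count in activated.items():
    print(f"{name}: ~{count / 1e9:.1f}B parameters activated")
```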

Conclusion
The ongoing development and success of many useful machine learning applications, such as NLP, speech recognition, machine translation, and autonomous driving, depend on achieving the highest accuracy possible. As this often requires building larger and even more complex models, we are pleased to share the GSPMD paper and the corresponding open-source library to the broader research community, and we hope it is useful for efficient training of large-scale deep neural networks.

Acknowledgements
We wish to thank Claire Cui, Zhifeng Chen, Yonghui Wu, Naveen Kumar, Macduff Hughes, Zoubin Ghahramani and Jeff Dean for their support and invaluable input. Special thanks to our collaborators Dmitry Lepikhin, HyoukJoong Lee, Dehao Chen, Orhan Firat, Maxim Krikun, Blake Hechtman, Rahul Joshi, Andy Li, Tao Wang, Marcello Maggioni, David Majnemer, Noam Shazeer, Ankur Bapna, Sneha Kudugunta, Quoc Le, Mia Chen, Shibo Wang, Jinliang Wei, Ruoming Pang, Zongwei Zhou, David So, Yanqi Zhou, Ben Lee, Jonathan Shen, James Qin, Yu Zhang, Wei Han, Anmol Gulati, Laurent El Shafey, Andrew Dai, Kun Zhang, Nan Du, James Bradbury, Matthew Johnson, Anselm Levskaya, Skye Wanderman-Milne, and Qiao Zhang for helpful discussions and inspirations.

Source: Google AI Blog


Decisiveness in Imitation Learning for Robots

Despite considerable progress in robot learning over the past several years, some policies for robotic agents can still struggle to decisively choose actions when trying to imitate precise or complex behaviors. Consider a task in which a robot tries to slide a block across a table to precisely position it into a slot. There are many possible ways to solve this task, each requiring precise movements and corrections. The robot must commit to just one of these options, but must also be capable of changing plans each time the block ends up sliding farther than expected. Although one might expect such a task to be easy, that is often not the case for modern learning-based robots, which often learn behavior that expert observers describe as indecisive or imprecise.

Example of a baseline explicit behavior cloning model struggling on a task where the robot needs to slide a block across a table and then precisely insert it into a fixture.

To encourage robots to be more decisive, researchers often utilize a discretized action space, which forces the robot to choose option A or option B, without oscillating between options. For example, discretization was a key element of our recent Transporter Networks architecture, and is also inherent in many notable achievements by game-playing agents, such as AlphaGo, AlphaStar, and OpenAI’s Dota bot. But discretization brings its own limitations — for robots that operate in the spatially continuous real world, there are at least two downsides to discretization: (i) it limits precision, and (ii) it triggers the curse of dimensionality, since considering discretizations along many different dimensions can dramatically increase memory and compute requirements. Related to this, in 3D computer vision much recent progress has been powered by continuous, rather than discretized, representations.
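To make the two downsides concrete, the following sketch counts how many distinct actions a uniform discretization produces and how coarse each bin is. The choice of 100 bins per dimension and a [-1, 1] action range are assumptions for illustration only.

```python
def num_discrete_actions(bins_per_dim, action_dims):
    # Every combination of per-dimension bins is a separate discrete action.
    return bins_per_dim ** action_dims

def bin_width(low, high, bins_per_dim):
    # Best-case resolution of a uniform discretization along one dimension.
    return (high - low) / bins_per_dim

for dims in (2, 7, 30):
    print(dims, "dims ->", num_discrete_actions(100, dims), "actions")

print("resolution per dimension:", bin_width(-1.0, 1.0, 100))
```

Halving the bin width to gain precision doubles the bins per dimension, which multiplies the total action count by 2 raised to the number of dimensions.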

With the goal of learning decisive policies without the drawbacks of discretization, today we announce our open source implementation of Implicit Behavioral Cloning (Implicit BC), which is a new, simple approach to imitation learning and was presented last week at CoRL 2021. We found that Implicit BC achieves strong results on both simulated benchmark tasks and on real-world robotic tasks that demand precise and decisive behavior. This includes achieving state-of-the-art (SOTA) results on human-expert tasks from our team’s recent benchmark for offline reinforcement learning, D4RL. On six out of seven of these tasks, Implicit BC outperforms the best previous method for offline RL, Conservative Q Learning. Interestingly, Implicit BC achieves these results without requiring any reward information, i.e., it can use relatively simple supervised learning rather than more-complex reinforcement learning.

Implicit Behavioral Cloning
Our approach is a type of behavior cloning, which is arguably the simplest way for robots to learn new skills from demonstrations. In behavior cloning, an agent learns how to mimic an expert’s behavior using standard supervised learning. Traditionally, behavior cloning involves training an explicit neural network (shown below, left), which takes in observations and outputs expert actions.

The key idea behind Implicit BC is to instead train a neural network to take in both observations and actions, and output a single number that is low for expert actions and high for non-expert actions (below, right), turning behavioral cloning into an energy-based modeling problem. After training, the Implicit BC policy generates actions by finding the action input that has the lowest score for a given observation.
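The inference step can be sketched with a derivative-free search: sample candidate actions, score each with the energy function, and return the lowest-scoring one. The energy function here is hand-written for illustration (a trained network would take its place), and the sampling range, sample count, and function names are assumptions rather than details of the released implementation.

```python
import random

def energy(observation, action):
    # Hypothetical stand-in for a trained network: lowest when action == 2 * observation.
    return (action - 2.0 * observation) ** 2

def implicit_policy(observation, num_samples=2048, low=-5.0, high=5.0, seed=0):
    """Return the sampled action with the lowest energy for this observation."""
    rng = random.Random(seed)
    candidates = [rng.uniform(low, high) for _ in range(num_samples)]
    return min(candidates, key=lambda a: energy(observation, a))

action = implicit_policy(1.5)
print(action)  # close to 3.0, the minimizer of the toy energy
```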

Depiction of the difference between explicit (left) and implicit (right) policies. In the implicit policy, the “argmin” means the action that, when paired with a particular observation, minimizes the value of the energy function.

To train Implicit BC models, we use an InfoNCE loss, which trains the network to output low energy for expert actions in the dataset, and high energy for all others (see below). It is interesting to note that this idea of using models that take in both observations and actions is common in reinforcement learning, but not so in supervised policy learning.
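A minimal sketch of an InfoNCE-style loss for a single training example follows. It treats negated energies as logits and computes the negative log-probability of the expert action among a set of sampled counter-example actions. How the counter-examples are drawn, and the function name and signature, are assumptions here and not taken from the released code.

```python
import math

def info_nce_loss(energy_expert, energies_negative):
    """-log( exp(-E_expert) / (exp(-E_expert) + sum_j exp(-E_negative_j)) )."""
    logits = [-energy_expert] + [-e for e in energies_negative]
    # Log-sum-exp with the max subtracted for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z + energy_expert

# Low energy on the expert action and high energy elsewhere gives a small loss.
print(info_nce_loss(0.0, [5.0, 5.0, 5.0]))
# The reverse assignment is penalized heavily.
print(info_nce_loss(5.0, [0.0, 0.0, 0.0]))
```

Minimizing this quantity pushes the expert action's energy down relative to the counter-examples, which is the behavior described above.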

Animation of how implicit models can fit discontinuities — in this case, training an implicit model to fit a step (Heaviside) function. Left: 2D plot fitting the black (X) training points — the colors represent the values of the energies (blue is low, brown is high). Middle: 3D plot of the energy model during training. Right: Training loss curve.

Once trained, we find that implicit models are particularly good at precisely modeling discontinuities (above) on which prior explicit models struggle (as in the first figure of this post), resulting in policies that are newly capable of switching decisively between different behaviors.

But why do conventional explicit models struggle? Modern neural networks almost always use continuous activation functions; for example, TensorFlow, JAX, and PyTorch all ship only with continuous activation functions. In attempting to fit discontinuous data, explicit networks built with these activation functions cannot represent discontinuities, so they must draw continuous curves between data points. A key aspect of implicit models is that they gain the ability to represent sharp discontinuities, even though the network itself is composed only of continuous layers.

We also establish theoretical foundations for this aspect, specifically a notion of universal approximation. This characterizes the class of functions that implicit neural networks can represent, which can help justify and guide future research.
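A tiny example shows how a continuous function can still yield a discontinuous output once an argmin is taken over it. The bilinear energy below is hand-picked for illustration and is not a trained model; the candidate grid over [-1, 1] is likewise an assumption.

```python
def energy(x, y):
    # Continuous (bilinear) in both arguments.
    return -x * y

def implicit_fit(x, num_candidates=201):
    """argmin over y in [-1, 1] of energy(x, y), evaluated on a uniform grid."""
    candidates = [-1.0 + 2.0 * i / (num_candidates - 1) for i in range(num_candidates)]
    return min(candidates, key=lambda y: energy(x, y))

for x in (-0.5, -0.01, 0.01, 0.5):
    print(x, implicit_fit(x))
```

The minimizer jumps from -1 to 1 as x crosses zero, so the implicit function behaves like a step, even though `energy` itself has no jump anywhere.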

Examples of fitting discontinuous functions, for implicit models (top) compared to explicit models (bottom). The red highlighted insets show that implicit models represent discontinuities (a) and (b) while the explicit models must draw continuous lines (c) and (d) in between the discontinuities.

One challenge faced by our initial attempts at this approach was “high action dimensionality”, which means that a robot must decide how to coordinate many motors all at the same time. To scale to high action dimensionality, we use either autoregressive models or Langevin dynamics.
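For intuition about the Langevin option, the sketch below runs a basic noisy gradient-descent update on a 30-dimensional action, using a toy quadratic energy whose gradient is known in closed form. The energy, step size, noise scale, and step count are all assumptions chosen for illustration; a real policy would use the gradient of the learned energy network, and the schedule used in our experiments is not reproduced here.

```python
import math
import random

TARGET = [0.1 * i for i in range(30)]  # hypothetical low-energy action, 30 dimensions

def grad_energy(action):
    # Gradient of the toy energy E(a) = 0.5 * sum((a_i - target_i)^2).
    return [a - t for a, t in zip(action, TARGET)]

def langevin_minimize(grad_fn, dim, steps=500, step_size=0.05, noise_scale=0.01, seed=0):
    """Follow the negative energy gradient with Gaussian noise added at each step."""
    rng = random.Random(seed)
    action = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
    for _ in range(steps):
        grad = grad_fn(action)
        action = [
            a - step_size * g + math.sqrt(2.0 * step_size) * noise_scale * rng.gauss(0.0, 1.0)
            for a, g in zip(action, grad)
        ]
    return action

action = langevin_minimize(grad_energy, dim=len(TARGET))
max_error = max(abs(a - t) for a, t in zip(action, TARGET))
print(max_error)
```

Because every dimension is updated at once from the gradient, the cost per step grows only linearly with the number of action dimensions, unlike exhaustive sampling over all of them.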

Highlights
In our experiments, we found Implicit BC does particularly well in the real world, including performing an order of magnitude (10x) better on the 1mm-precision slide-then-insert task than a baseline explicit BC model. On this task the implicit model makes several consecutive precise adjustments (below) before sliding the block into place. This task demands multiple elements of decisiveness: there are many different possible solutions due to the symmetry of the block and the arbitrary ordering of push maneuvers, and the robot needs to discontinuously decide when the block has been pushed far “enough” before switching to slide it in a different direction. This is in contrast to the indecisiveness that is often associated with continuous-controlled robots.

Example task of sliding a block across a table and precisely inserting it into a slot. These are autonomous behaviors of our Implicit BC policies, using only images (from the shown camera) as input.
A diverse set of different strategies for accomplishing this task. These are autonomous behaviors from our Implicit BC policies, using only images as input.

In another challenging task, the robot needs to sort blocks by color, which presents a large number of possible solutions due to the arbitrary ordering of sorting. On this task the explicit models are customarily indecisive, while implicit models perform considerably better.

Comparison of implicit (left) and explicit (right) BC models on a challenging continuous multi-item sorting task. (4x speed)

In our testing, implicit BC models can also exhibit robust reactive behavior, even when we try to interfere with the robot, despite the model never seeing human hands.

Robust behavior of the implicit BC model despite interfering with the robot.

Overall, we find that Implicit BC policies can achieve strong results compared to state-of-the-art offline reinforcement learning methods across several different task domains. These results include challenging tasks that have a low number of demonstrations (as few as 19), high observation dimensionality with image-based observations, and/or high action dimensionality of up to 30, which is a large number of actuators to have on a robot.

Policy learning results of Implicit BC compared to baselines across several domains.

Conclusion
Despite its limitations, behavioral cloning with supervised learning remains one of the simplest ways for robots to learn from examples of human behaviors. As we showed here, replacing explicit policies with implicit policies when doing behavioral cloning allows robots to overcome the "struggle of decisiveness", enabling them to imitate much more complex and precise behaviors. While the focus of our results here was on robot learning, the ability of implicit functions to model sharp discontinuities and multimodal labels may have broader interest in other application domains of machine learning as well.

Acknowledgements
Pete and Corey summarized research performed together with other co-authors: Andy Zeng, Oscar Ramirez, Ayzaan Wahid, Laura Downs, Adrian Wong, Johnny Lee, Igor Mordatch, and Jonathan Tompson. The authors would also like to thank Vikas Sindhwani for project direction advice; Steve Xu, Robert Baruch, Arnab Bose for robot software infrastructure; Jake Varley, Alexa Greenberg for ML infrastructure; and Kamyar Ghasemipour, Jon Barron, Eric Jang, Stephen Tu, Sumeet Singh, Jean-Jacques Slotine, Anirudha Majumdar, Vincent Vanhoucke for helpful feedback and discussions.

Source: Google AI Blog


Permutation-Invariant Neural Networks for Reinforcement Learning

“The brain is able to use information coming from the skin as if it were coming from the eyes. We don’t see with the eyes or hear with the ears, these are just the receptors, seeing and hearing in fact goes on in the brain.”
Paul Bach-y-Rita [1]

People have the amazing ability to use one sensory modality (e.g., touch) to supply environmental information normally gathered by another sense (e.g., vision). This adaptive ability, called sensory substitution, is a phenomenon well-known to neuroscience. Difficult adaptations, such as adjusting to seeing things upside-down, learning to ride a “backwards” bicycle, or learning to “see” by interpreting visual information emitted from a grid of electrodes placed on one’s tongue, can take weeks, months or even years to master, but people are eventually able to adjust to sensory substitutions.

Examples of Sensory Substitution. Left: Tongue Display Unit (Maris and Bach-y-Rita, 2001; Image: Kaczmarek, 2011). Right: “Upside down goggles” initially conceived by Erismann and Kohler in 1931. (Image Wikipedia).

In contrast, most neural networks are not able to adapt to sensory substitutions at all. For instance, most reinforcement learning (RL) agents require their inputs to be in a pre-specified format, or else they will fail. They expect fixed-size inputs and assume that each element of the input carries a precise meaning, such as the pixel intensity at a specified location, or state information, like position or velocity. In popular RL benchmark tasks (e.g., Ant or Cart-pole), an agent trained using current RL algorithms will fail if its sensory inputs are changed or if the agent is fed additional noisy inputs that are unrelated to the task at hand.

In “The Sensory Neuron as a Transformer: Permutation-Invariant Neural Networks for Reinforcement Learning”, a spotlight paper at NeurIPS 2021, we explore permutation invariant neural network agents, which require each of their sensory neurons (receptors that receive sensory inputs from the environment) to figure out the meaning and context of its input signal, rather than explicitly assuming a fixed meaning. Our experiments show that such agents are robust to observations that contain additional redundant or noisy information, and to observations that are corrupt and incomplete.

Permutation invariant reinforcement learning agents adapting to sensory substitutions. Left: The ordering of the ant’s 28 observations is randomly shuffled every 200 time-steps. Unlike the standard policy, our policy is not affected by the suddenly permuted inputs. Right: Cart-pole agent given many redundant noisy inputs (Interactive web-demo).

In addition to adapting to sensory substitutions in state-observation environments (like the ant and cart-pole examples), we show that these agents can also adapt to sensory substitutions in complex visual-observation environments (such as a CarRacing game that uses only pixel observations) and can perform when the stream of input images is constantly being reshuffled:

We partition the visual input from CarRacing into a 2D grid of small patches, and shuffle their ordering. Without any additional training, our agent still performs even when the original training background (left) is replaced with new images (right).

Method
Our approach takes observations from the environment at each time-step and feeds each element of the observation into distinct, but identical neural networks (called “sensory neurons”), each with no fixed relationship with one another. Each sensory neuron integrates information over time from only its particular sensory input channel. Because each sensory neuron receives only a small part of the full picture, the neurons need to self-organize through communication in order for a globally coherent behavior to emerge.

Illustration of observation segmentation. We segment each input into elements, which are then fed to independent sensory neurons. For non-vision tasks where the inputs are usually 1D vectors, each element is a scalar. For vision tasks, we crop each input image into non-overlapping patches.

We encourage neurons to communicate with each other by training them to broadcast messages. While receiving information locally, each individual sensory neuron also continually broadcasts an output message at each time-step. These messages are consolidated and combined into an output vector, called the global latent code, using an attention mechanism similar to that applied in the Transformer architecture. A policy network then uses the global latent code to produce the action that the agent will use to interact with the environment. This action is also fed back into each sensory neuron in the next time-step, closing the communication loop.

Overview of the permutation-invariant RL method. We first feed each individual observation (o_t) into a particular sensory neuron (along with the agent’s previous action, a_{t-1}). Each neuron then produces and broadcasts a message independently, and an attention mechanism summarizes them into a global latent code (m_t) that is given to the agent's downstream policy network (π) to produce the agent’s action a_t.

Why is this system permutation invariant? Each sensory neuron is an identical neural network that is not confined to only process information from one particular sensory input. In fact, in our setup, the inputs to each sensory neuron are not defined. Instead, each neuron must figure out the meaning of its input signal by paying attention to the inputs received by the other sensory neurons, rather than explicitly assuming a fixed meaning. This encourages the agent to process the entire input as an unordered set, making the system permutation invariant with respect to its input. Furthermore, in principle, the agent can use as many sensory neurons as required, thus enabling it to process observations of arbitrary length. Both of these properties will help the agent adapt to sensory substitutions.
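The invariance can be checked numerically with a stripped-down version of this idea. In the sketch below, one shared function maps every scalar input to a key and a value, and a small set of fixed queries pools them with softmax attention. The weights, the query values, and the function names are hypothetical, and the per-neuron memory and previous-action input of the full method are left out.

```python
import math

# Fixed, input-independent queries (hypothetical values).
QUERIES = [[1.0, -0.5], [-0.3, 0.8]]

def sensory_neuron(x):
    """Shared network applied to every input element; returns (key, value)."""
    key = [math.tanh(0.7 * x + 0.1), math.tanh(-1.2 * x + 0.4)]
    value = [x, x * x]
    return key, value

def global_latent_code(observation):
    """Attention-pool all neuron messages into a fixed-size code."""
    keys, values = zip(*(sensory_neuron(x) for x in observation))
    code = []
    for q in QUERIES:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(len(q)) for k in keys]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        code.append([
            sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))
        ])
    return code

obs = [0.3, -1.2, 2.0, 0.5, -0.7]
original = global_latent_code(obs)
permuted = global_latent_code(obs[::-1])          # same elements, different order
longer = global_latent_code(obs + [0.0, 1.1])     # more inputs, same code size
print(original, permuted)
```

Because the softmax-weighted sum runs over an unordered collection of messages, reordering the inputs leaves the code unchanged up to floating-point rounding, and the code keeps the same shape for any number of inputs.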

Results
We demonstrate the robustness and flexibility of this approach in simpler, state-observation environments, where the observations the agent receives as inputs are low-dimensional vectors holding information about the agent’s states, such as the position or velocity of its components. The agent in the popular Ant locomotion task has a total of 28 inputs with information that includes positions and velocities. We shuffle the order of the input vector several times during a trial and show that the agent is rapidly able to adapt and is still able to walk forward.

In cart-pole, the agent’s goal is to swing up a pole mounted at the center of the cart and balance it upright. Normally the agent sees only five inputs, but we modify the cart-pole environment to provide 15 shuffled input signals, 10 of which are pure noise, and the remainder of which are the actual observations from the environment. The agent is still able to perform the task, demonstrating the system’s capacity to work with a large number of inputs and attend only to channels it deems useful. Such flexibility may find useful applications for processing a large unspecified number of signals, most of which are noise, from ill-defined systems.

We also apply this approach to high-dimensional vision-based environments where the observation is a stream of pixel images. Here, we investigate screen-shuffled versions of vision-based RL environments, where each observation frame is divided into a grid of patches, and like a puzzle, the agent must process the patches in a shuffled order to determine a course of action to take. To demonstrate our approach on vision-based tasks, we created a shuffled version of Atari Pong.

Shuffled Pong results. Left: Pong agent trained to play using only 30% of the patches matches performance of Atari opponent. Right: Without extra training, when we give the agent more puzzle pieces, its performance increases.

Here the agent’s input is a variable-length list of patches, so unlike typical RL agents, the agent only gets to “see” a subset of patches from the screen. In the puzzle Pong experiment, we pass to the agent a random sample of patches across the screen, which are then fixed through the remainder of the game. We find that we can discard 70% of the patches (at these fixed-random locations) and still train the agent to perform well against the built-in Atari opponent. Interestingly, if we then reveal additional information to the agent (e.g., allowing it access to more image patches), its performance increases, even without additional training. When the agent receives all the patches, in shuffled order, it wins 100% of the time, achieving the same result as agents that are trained while seeing the entire screen.
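The observation pipeline for this kind of experiment can be sketched as follows: cut each frame into a grid of patches, choose a fixed random subset of patch locations once per game, and hand the agent those patches in shuffled order. The frame size, patch size, function names, and the choice to reshuffle on every frame are assumptions for illustration and are not taken from our released code.

```python
import random

def to_patches(frame, patch):
    """Split a 2D frame (list of rows) into non-overlapping patch x patch blocks."""
    h, w = len(frame), len(frame[0])
    return [
        [row[c:c + patch] for row in frame[r:r + patch]]
        for r in range(0, h, patch)
        for c in range(0, w, patch)
    ]

def visible_patch_indices(num_patches, keep_fraction, seed=0):
    """Pick the patch locations the agent may see; fixed for the whole game."""
    rng = random.Random(seed)
    k = max(1, round(num_patches * keep_fraction))
    return rng.sample(range(num_patches), k)

def observe(frame, patch, kept, rng):
    """Return only the kept patches, in shuffled order."""
    patches = to_patches(frame, patch)
    visible = [patches[i] for i in kept]
    rng.shuffle(visible)
    return visible

frame = [[r * 8 + c for c in range(8)] for r in range(8)]    # toy 8x8 "screen"
kept = visible_patch_indices(num_patches=16, keep_fraction=0.3)
observation = observe(frame, patch=2, kept=kept, rng=random.Random(1))
print(len(observation), "of 16 patches visible")
```

Revealing more of the screen later only requires a larger `keep_fraction`; the agent's input is a list, so its length is free to change.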

We find that imposing additional difficulty during training by using unordered observations has additional benefits, such as improving generalization to unseen variations of the task, like when the background of the CarRacing training environment is replaced with a novel image.

Shuffled CarRacing results. The agent has learned to focus its attention (indicated by the highlighted patches) on the road boundaries. Left: Training environment. Right: Test environment with new background.

Conclusion
The permutation invariant neural network agents presented here can handle ill-defined, varying observation spaces. Our agents are robust to observations that contain redundant or noisy information, or observations that are corrupt and incomplete. We believe that permutation invariant systems open up numerous possibilities in reinforcement learning.

If you’re interested in learning more about this work, we invite you to read our interactive article (pdf version) or watch our video. We also released code to reproduce our experiments.



[1] Quoted in Livewired, by David Eagleman.

Source: Google AI Blog