Nested Hierarchical Transformer: Towards Accurate, Data-Efficient, and Interpretable Visual Understanding

In visual understanding, the Vision Transformer (ViT) and its variants have received significant attention recently due to their superior performance on many core visual applications, such as image classification, object detection, and video understanding. The core idea of ViT is to utilize the power of self-attention layers to learn global relationships between small patches of images. However, the number of connections between patches increases quadratically with the number of patches, and therefore grows rapidly with image size. Such a design has been observed to be data inefficient: although the original ViT can outperform convolutional networks when pre-trained on hundreds of millions of images, such a data requirement is not always practical, and it still underperforms convolutional networks when given less data. Many researchers are exploring more suitable architectural re-designs that can learn visual representations effectively, such as adding convolutional layers and building hierarchical structures with local self-attention.
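To make the scaling concrete, here is a minimal sketch that counts patch-to-patch connections for one global self-attention layer and for attention restricted to non-overlapping blocks. The function names and the `patches_per_block_side` parameter are illustrative assumptions, not taken from the released code.

```python
def attention_pairs(image_size: int, patch_size: int) -> int:
    """Patch-to-patch connections in one global self-attention layer."""
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    # Every patch attends to every patch, itself included.
    return num_patches ** 2


def local_attention_pairs(image_size: int, patch_size: int,
                          patches_per_block_side: int) -> int:
    """Connections when attention is limited to non-overlapping square blocks."""
    patches_per_side = image_size // patch_size
    blocks_per_side = patches_per_side // patches_per_block_side
    num_blocks = blocks_per_side ** 2
    patches_per_block = patches_per_block_side ** 2
    # Attention is computed independently inside each block.
    return num_blocks * patches_per_block ** 2


if __name__ == "__main__":
    for size in (224, 448):
        print(size, attention_pairs(size, 16), local_attention_pairs(size, 16, 7))
```

With a fixed patch size, doubling the image side length quadruples the number of patches and multiplies the global connection count by 16, while the block-local count grows only in proportion to the number of blocks.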

The principle of hierarchical structure is one of the core ideas in vision models, where bottom layers learn more local object structures on the high-dimensional pixel space and top layers learn more abstracted and high-level knowledge at low-dimensional feature space. Existing ViT-based methods focus on designing a variety of modifications inside self-attention layers to achieve such a hierarchy, but while these offer promising performance improvements, they often require substantial architectural re-designs. Moreover, these approaches lack an interpretable design, so it is difficult to explain the inner-workings of trained models.

To address these challenges, in “Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and Interpretable Visual Understanding”, we present a rethinking of existing hierarchical structure–driven designs, and provide a novel and orthogonal approach to significantly simplify them. The central idea of this work is to decouple feature learning and feature abstraction (pooling) components: nested transformer layers encode visual knowledge of image patches separately, and then the processed information is aggregated. This process is repeated in a hierarchical manner, resulting in a pyramid network structure. The resulting architecture achieves competitive results on ImageNet and outperforms prior results on data-efficient benchmarks. We show that such a design can meaningfully improve data efficiency with faster convergence and provide valuable interpretability benefits. Moreover, we introduce GradCAT, a new technique for interpreting the decision process of a trained model at inference time.

Architecture Design
The overall architecture is simple to implement by adding just a few lines of Python code to the source code of the original ViT. The original ViT architecture divides an input image into small patches, projects the pixels of each patch to a vector with a predefined dimension, and then feeds the sequence of all vectors to a stack of identical transformer layers. While every layer in ViT processes information from the whole image, in this new method stacked transformer layers process only a region (i.e., block) of the image containing a few spatially adjacent image patches. This step is independent for each block and is where feature learning occurs. A new computational layer called block aggregation then combines spatially adjacent blocks. After each block aggregation, the features corresponding to four spatially adjacent blocks are fed to another module with a stack of transformer layers, which processes those four blocks jointly. This design naturally builds a pyramid hierarchical structure, where bottom layers can focus on local features (such as textures) and upper layers focus on global features (such as object shape) at reduced dimensionality thanks to the block aggregation.

A visualization of the network processing an image: Given an input image, the network first partitions the image into blocks, where each block contains 4 image patches. Image patches in every block are linearly projected as vectors and processed by a stack of identical transformer layers. Then the proposed block aggregation layer aggregates information from each block and reduces its spatial size by a factor of 4. The number of blocks is reduced to 1 at the top of the hierarchy, and classification is performed on its output.
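The partition, process, and aggregate loop described above can be sketched in plain Python. This is a toy illustration under several assumptions: each patch carries one scalar feature, `process_block` is a placeholder for the per-block transformer stack, and `aggregate` uses 2x2 mean pooling as a simplified stand-in for the block aggregation layer. All function names are hypothetical.

```python
def blockify(grid, block_side):
    """Split an H x W grid of patch features into non-overlapping blocks.

    Returns a grid of blocks; each block is a flat, row-major list of patches.
    """
    h, w = len(grid), len(grid[0])
    assert h % block_side == 0 and w % block_side == 0
    return [
        [
            [grid[r][c]
             for r in range(br, br + block_side)
             for c in range(bc, bc + block_side)]
            for bc in range(0, w, block_side)
        ]
        for br in range(0, h, block_side)
    ]


def unblockify(blocks, block_side):
    """Inverse of blockify: stitch blocks back into one H x W grid."""
    rows = []
    for block_row in blocks:
        for r in range(block_side):
            row = []
            for block in block_row:
                row.extend(block[r * block_side:(r + 1) * block_side])
            rows.append(row)
    return rows


def aggregate(grid):
    """Stand-in block aggregation: 2x2 mean pooling (4x fewer positions)."""
    h, w = len(grid), len(grid[0])
    return [
        [(grid[r][c] + grid[r][c + 1] + grid[r + 1][c] + grid[r + 1][c + 1]) / 4.0
         for c in range(0, w, 2)]
        for r in range(0, h, 2)
    ]


def process_block(block):
    """Placeholder for the transformer layers applied to one block."""
    return list(block)


def nested_forward(grid, block_side=2):
    """Run the hierarchy until one block remains; also return blocks per level."""
    block_counts = []
    while True:
        blocks = blockify(grid, block_side)
        block_counts.append(len(blocks) * len(blocks[0]))
        blocks = [[process_block(b) for b in row] for row in blocks]
        grid = unblockify(blocks, block_side)
        if len(grid) == block_side:  # a single block is left at the top
            return grid, block_counts
        grid = aggregate(grid)
```

For a square grid whose side is `block_side` times a power of two, the number of blocks shrinks by a factor of 4 per level (for example 16, 4, 1 for an 8x8 grid of patches) while the number of patches per block stays constant.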

Interpretability
This architecture has a non-overlapping information processing mechanism, independent at every node. This design resembles a decision tree-like structure, which manifests unique interpretability capabilities because every tree node contains independent information of an image block that is being received by its parent nodes. We can trace the information flow through the nodes to understand the importance of each feature. In addition, our hierarchical structure retains the spatial structure of images throughout the network, leading to learned spatial feature maps that are effective for interpretation. Below we showcase two kinds of visual interpretability.

First, we present a method to interpret the trained model on test images, called gradient-based class-aware tree-traversal (GradCAT). GradCAT traces the feature importance of each block (a tree node) from top to bottom of the hierarchy structure. The main idea is to find the most valuable traversal from the root node at the top layer to a child node at the bottom layer that contributes the most to the classification outcomes. Since each node processes information from a certain region of the image, such traversal can be easily mapped to the image space for interpretation (as shown by the overlaid dots and lines in the image below).
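The top-down traversal can be sketched as a greedy descent over a quadtree. This sketch assumes the per-node importance values have already been computed for the target class; how GradCAT derives them from gradients is not reproduced here, and the names `gradcat_path` and `node_to_image_region` are hypothetical.

```python
def gradcat_path(levels):
    """Greedy root-to-leaf traversal over per-level node importance grids.

    levels[0] is a 1x1 grid (the root); each following level doubles the
    side length, so node (r, c) has children (2r + dr, 2c + dc), dr, dc in {0, 1}.
    """
    r, c = 0, 0
    path = [(r, c)]
    for grid in levels[1:]:
        children = [(2 * r + dr, 2 * c + dc) for dr in (0, 1) for dc in (0, 1)]
        # Follow the child with the highest importance for the target class.
        r, c = max(children, key=lambda rc: grid[rc[0]][rc[1]])
        path.append((r, c))
    return path


def node_to_image_region(level_index, r, c, image_size):
    """Map a node to its (top, left, bottom, right) pixel region."""
    cell = image_size // (2 ** level_index)
    return (r * cell, c * cell, (r + 1) * cell, (c + 1) * cell)


if __name__ == "__main__":
    levels = [
        [[1.0]],
        [[0.1, 0.7],
         [0.2, 0.0]],
        [[0.1, 0.1, 0.3, 0.2],
         [0.1, 0.1, 0.8, 0.4],
         [0.1, 0.1, 0.1, 0.1],
         [9.9, 0.1, 0.1, 0.1]],
    ]
    path = gradcat_path(levels)
    print(path, node_to_image_region(len(path) - 1, *path[-1], image_size=224))
```

Because each step only considers the four children of the current node, a high value elsewhere in the grid (such as the 9.9 in the demo) does not divert the path. This is what makes the result readable as a single decision path through the tree.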

The following is an example of the model's top-4 predictions and corresponding interpretability results on the left input image (containing four animals). As shown below, GradCAT highlights the decision path along the hierarchical structure as well as the corresponding visual cues in local image regions.

Given the left input image (containing four objects), the figure visualizes the interpretability results of the top-4 prediction classes. The traversal locates the model decision path along the tree and simultaneously locates the corresponding image patch (shown by the dotted line on images) that has the highest impact on the predicted target class.

Moreover, the following figures visualize results on the ImageNet validation set and show how this approach enables some intuitive observations. For instance, the example of the lighter below (upper left panel) is particularly interesting because the ground truth class — lighter/matchstick — actually defines the bottom-right matchstick object, while the most salient visual features (with the highest node values) are actually from the upper-left red light, which conceptually shares visual cues with a lighter. This can also be seen from the overlaid red lines, which indicate the image patches with the highest impact on the prediction. Thus, although the visual cue is a mistake, the output prediction is correct. In addition, the four child nodes of the wooden spoon below have similar feature importance values (see numbers visualized in the nodes; higher indicates more importance), which is because the wooden texture of the table is similar to that of the spoon.

Visualization of the results obtained by the proposed GradCAT. Images are from the ImageNet validation dataset.

Second, unlike the original ViT, our hierarchical architecture retains spatial relationships in learned representations. The top layers output low-resolution feature maps of input images, enabling the model to easily perform attention-based interpretation by applying Class Attention Map (CAM) on the learned representations at the top hierarchical level. This enables high-quality weakly-supervised object localization with just image-level labels. See the following figure for examples.

Visualization of CAM-based attention results. Warmer colors indicate higher attention. Images are from the ImageNet validation dataset.
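As a rough illustration of CAM-style localization, the sketch below weights each top-level feature map by the classifier weight of the target class and sums them per spatial position. It assumes a linear classifier applied after global average pooling, which is the standard setting for this kind of map. The function names and toy values are hypothetical.

```python
def class_activation_map(feature_maps, class_weights):
    """Weighted sum of K feature maps ([K][H][W]) using one class's weights ([K])."""
    k = len(feature_maps)
    h, w = len(feature_maps[0]), len(feature_maps[0][0])
    assert len(class_weights) == k
    return [
        [sum(class_weights[i] * feature_maps[i][y][x] for i in range(k))
         for x in range(w)]
        for y in range(h)
    ]


def peak_location(cam):
    """Return the (row, col) of the strongest response in the map."""
    positions = [(y, x) for y in range(len(cam)) for x in range(len(cam[0]))]
    return max(positions, key=lambda p: cam[p[0]][p[1]])


if __name__ == "__main__":
    feature_maps = [
        [[1.0, 0.0], [0.0, 0.0]],  # channel 0 fires top-left
        [[0.0, 0.0], [0.0, 2.0]],  # channel 1 fires bottom-right
    ]
    cam = class_activation_map(feature_maps, class_weights=[0.5, 1.0])
    print(cam, peak_location(cam))
```

Upsampling the resulting low-resolution map to the input size and thresholding it is one common way to turn it into a coarse object location using only image-level labels.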

Convergence Advantages
With this design, feature learning only happens at local regions independently, and feature abstraction happens inside the aggregation function. This design and simple implementation is general enough for other types of visual understanding tasks beyond classification. It also improves the model convergence speed greatly, significantly reducing the training time to reach the desired maximum accuracy.

We validate this advantage in two ways. First, we compare ImageNet accuracy against the original ViT for different numbers of total training epochs. The results, shown on the left side of the figure below, demonstrate much faster convergence than the original ViT, e.g., around a 20% improvement in accuracy over ViT with 30 total training epochs.

Second, we modify the architecture to conduct unconditional image generation tasks, since training ViT-based models for image generation tasks is challenging due to convergence and speed issues. Creating such a generator is straightforward by transposing the proposed architecture: the input is an embedding vector, the output is a full image in RGB channels, and the block aggregation is replaced by a block de-aggregation component supported by Pixel Shuffling. Surprisingly, we find our generator is easy to train and demonstrates faster convergence speed, as well as better FID score (which measures how similar generated images are to real ones), than the capacity-comparable SAGAN.
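The pixel shuffle operation behind block de-aggregation trades channels for spatial resolution. Below is a minimal pure-Python sketch of the standard operation on a `[C * r * r][H][W]` nested list; how the generator wires it in is not specified above, so treat this as an illustration of the building block only.

```python
def pixel_shuffle(x, r):
    """Rearrange [C*r*r][H][W] into [C][H*r][W*r] (standard pixel shuffle)."""
    in_c, h, w = len(x), len(x[0]), len(x[0][0])
    assert in_c % (r * r) == 0, "channel count must be divisible by r*r"
    out_c = in_c // (r * r)
    out = [[[0.0] * (w * r) for _ in range(h * r)] for _ in range(out_c)]
    for c in range(out_c):
        for i in range(r):
            for j in range(r):
                # Each group of r*r input channels fills one r x r output cell.
                src = x[c * r * r + i * r + j]
                for row in range(h):
                    for col in range(w):
                        out[c][row * r + i][col * r + j] = src[row][col]
    return out


if __name__ == "__main__":
    # Four 1x1 channels become one 2x2 channel.
    print(pixel_shuffle([[[1.0]], [[2.0]], [[3.0]], [[4.0]]], r=2))
```

With `r = 2`, each step quadruples the number of spatial positions, mirroring in reverse the 4x reduction performed by block aggregation in the classifier.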

Left: ImageNet accuracy given different numbers of total training epochs, compared with the standard ViT architecture. Right: ImageNet 64x64 image generation FID scores (lower is better) with a single 1000-epoch training run. On both tasks, our method shows better convergence speed.

Conclusion
In this work we demonstrate the simple idea that decoupling feature learning and feature information extraction in this nested hierarchy design leads to better feature interpretability through a new gradient-based class-aware tree-traversal method. Moreover, the architecture improves convergence not only on classification tasks but also on image generation tasks. The proposed idea focuses on the aggregation function and is therefore orthogonal to advanced architecture designs for self-attention. We hope this new research encourages future architecture designers to explore more interpretable and data-efficient ViT-based models for visual understanding, such as the adoption of this work for high-resolution image generation. We have also released the source code for the image classification portion of this work.

Acknowledgements
We gratefully acknowledge the contributions of other co-authors, including Han Zhang, Long Zhao, Ting Chen, Sercan Arik, and Tomas Pfister. We also thank Xiaohua Zhai, Jeremy Kubica, Kihyuk Sohn, and Madeleine Udell for their valuable feedback on the work.

Source: Google AI Blog


Robot See, Robot Do

People learn to do things by watching others, from mimicking new dance moves to watching YouTube cooking videos. We’d like robots to do the same, i.e., to learn new skills by watching people do things during training. Today, however, the predominant paradigm for teaching robots is to remote control them using specialized hardware for teleoperation and then train them to imitate pre-recorded demonstrations. This limits both who can provide the demonstrations (programmers and roboticists) and where they can be provided (lab settings). If robots could instead self-learn new tasks by watching humans, this capability could allow them to be deployed in more unstructured settings like the home, and make it dramatically easier for anyone to teach or communicate with them, expert or otherwise. Perhaps one day, they might even be able to use YouTube videos to grow their collection of skills over time.

Our motivation is to have robots watch people do tasks, naturally with their hands, and then use that data as demonstrations for learning. Video by Teh Aik Hui and Nathaniel Lim. License: CC-BY

However, an obvious but often overlooked problem is that a robot is physically different from a human, which means it often completes tasks differently than we do. For example, in the pen manipulation task below, the hand can grab all the pens together and quickly transfer them between containers, whereas the two-fingered gripper must transport one at a time. Prior research assumes that humans and robots can do the same task similarly, which makes manually specifying one-to-one correspondences between human and robot actions easy. But with stark differences in physique, defining such correspondences for seemingly easy tasks can be surprisingly difficult and sometimes impossible.

Physically different end-effectors (i.e., “grippers”, the parts that interact with the environment) induce different control strategies when solving the same task. Left: The hand grabs all pens and quickly transfers them between containers. Right: The two-fingered gripper transports one pen at a time.

In “XIRL: Cross-Embodiment Inverse RL”, presented as an oral paper at CoRL 2021, we explore these challenges further and introduce a self-supervised method for Cross-embodiment Inverse Reinforcement Learning (XIRL). Rather than focusing on how individual human actions should correspond to robot actions, XIRL learns the high-level task objective from videos, and summarizes that knowledge in the form of a reward function that is invariant to embodiment differences, such as shape, actions and end-effector dynamics. The learned rewards can then be used together with reinforcement learning to teach the task to agents with new physical embodiments through trial and error. Our approach is general and scales autonomously with data — the more embodiment diversity presented in the videos, the more invariant and robust the reward functions become. Experiments show that our learned reward functions lead to significantly more sample efficient (roughly 2 to 4 times) reinforcement learning on new embodiments compared to alternative methods. To extend and build on our work, we are releasing an accompanying open-source implementation of our method along with X-MAGICAL, our new simulated benchmark for cross-embodiment imitation.

Cross-Embodiment Inverse Reinforcement Learning (XIRL)
The underlying observation in this work is that in spite of the many differences induced by different embodiments, there still exist visual cues that reflect progression towards a common task objective. For example, in the pen manipulation task above, the presence of pens in the cup but not the mug, or the absence of pens on the table, are key frames that are common to different embodiments and indirectly provide cues for how close to being complete a task is. The key idea behind XIRL is to automatically discover these key moments in videos of different length and cluster them meaningfully to encode task progression. This motivation shares many similarities with unsupervised video alignment research, from which we can leverage a method called Temporal Cycle Consistency (TCC), which aligns videos accurately while learning useful visual representations for fine-grained video understanding without requiring any ground-truth correspondences.

We leverage TCC to train an encoder to temporally align video demonstrations of different experts performing the same task. The TCC loss tries to maximize the number of cycle-consistent frames (or mutual nearest-neighbors) between pairs of sequences using a differentiable formulation of soft nearest-neighbors. Once the encoder is trained, we define our reward function as simply the negative Euclidean distance between the current observation and the goal observation in the learned embedding space. We can subsequently insert the reward into a standard MDP and use an RL algorithm to learn the demonstrated behavior. Surprisingly, we find that this simple reward formulation is effective for cross-embodiment imitation.
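The two ingredients in this paragraph can be sketched in a few lines. The reward follows the definition above (negative Euclidean distance to the goal in embedding space). The cycle-consistency check uses hard nearest neighbors for readability, whereas the TCC loss relies on a differentiable soft version, shown separately. Function names, the `temperature` parameter, and the toy embeddings are assumptions for illustration, and the encoder that produces the embeddings is out of scope.

```python
import math


def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def xirl_reward(obs_embedding, goal_embedding):
    """Reward = negative distance between current and goal embeddings."""
    return -euclidean(obs_embedding, goal_embedding)


def nearest(query, sequence):
    """Index of the frame embedding in `sequence` closest to `query`."""
    return min(range(len(sequence)), key=lambda i: euclidean(query, sequence[i]))


def cycle_consistent_frames(seq_u, seq_v):
    """Frames of seq_u that map to seq_v and back to themselves (hard version)."""
    consistent = []
    for i, u in enumerate(seq_u):
        j = nearest(u, seq_v)
        if nearest(seq_v[j], seq_u) == i:
            consistent.append(i)
    return consistent


def soft_nearest_neighbor(query, sequence, temperature=1.0):
    """Differentiable neighbor: softmax-weighted average of `sequence`."""
    logits = [-euclidean(query, v) ** 2 / temperature for v in sequence]
    top = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(l - top) for l in logits]
    total = sum(weights)
    dim = len(sequence[0])
    return [sum(w * v[d] for w, v in zip(weights, sequence)) / total
            for d in range(dim)]


if __name__ == "__main__":
    seq_u = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]
    seq_v = [[0.1, 0.0], [2.2, 0.0]]
    print(cycle_consistent_frames(seq_u, seq_v))
    print(xirl_reward(seq_u[0], seq_u[-1]))
```

The reward reaches its maximum of zero at the goal embedding and decreases as the observation moves away from it, so a standard RL algorithm can maximize it without any embodiment-specific action labels.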

XIRL self-supervises reward functions from expert demonstrations using temporal cycle consistency (TCC), then uses them for downstream reinforcement learning to learn new skills from third-person demonstrations.

X-MAGICAL Benchmark
To evaluate the performance of XIRL and baseline alternatives (e.g., TCN, LIFS, Goal Classifier) in a consistent environment, we created X-MAGICAL, which is a simulated benchmark for cross-embodiment imitation. X-MAGICAL features a diverse set of agent embodiments, with differences in their shapes and end-effectors, designed to solve tasks in different ways. This leads to differences in execution speeds and state-action trajectories, which poses challenges for current imitation learning techniques, e.g., ones that use time as a heuristic for weak correspondences between two trajectories. The ability to generalize across embodiments is precisely what X-MAGICAL evaluates.

The SweepToTop task we considered for our experiments is a simplified 2D equivalent of a common household robotic sweeping task, where an agent has to push three objects into a goal zone in the environment. We chose this task specifically because its long-horizon nature highlights how different agent embodiments can generate entirely different trajectories (shown below). X-MAGICAL features a Gym API and is designed to be easily extendable to new tasks and embodiments. You can try it out today with pip install x-magical.

Different agent shapes in the SweepToTop task in the X-MAGICAL benchmark need to use different strategies to reposition objects into the target area (pink), i.e., to “clear the debris”. For example, the long-stick can clear them all in one fell swoop, whereas the short-stick needs to do multiple consecutive back-and-forths.
Left: Heatmap of state visitation for each embodiment across all expert demonstrations. Right: Examples of expert trajectories for each embodiment.

Highlights
In our first set of experiments, we checked whether our learned embodiment-invariant reward function can enable successful reinforcement learning, when the expert demonstrations are provided through the agent itself. We find that XIRL significantly outperforms alternative methods especially on the tougher agents (e.g., short-stick and gripper).

Same-embodiment setting: Comparison of XIRL with baseline reward functions, using SAC for RL policy learning. XIRL is roughly 2 to 4 times more sample efficient than some of the baselines on the harder agents (short-stick and gripper).

We also find that our approach shows great potential for learning reward functions that generalize to novel embodiments. For instance, when reward learning is performed on embodiments that are different from the ones on which the policy is trained, we find that it results in significantly more sample efficient agents compared to the same baselines. Below, in the gripper subplot (bottom right) for example, the reward is first learned on demonstration videos from long-stick, medium-stick and short-stick, after which the reward function is used to train the gripper agent.

Cross-embodiment setting: XIRL performs favorably when compared with other baseline reward functions, trained on observation-only demonstrations from different embodiments. Each agent (long-stick, medium-stick, short-stick, gripper) had its reward trained using demonstrations from the other three embodiments.

We also find that we can train on real-world human demonstrations, and use the learned reward to train a Sawyer arm in simulation to push a puck to a designated target zone. In these experiments as well, our method outperforms baseline alternatives. For example, our XIRL variant trained only on the real-world demonstrations (purple in the plots below) reaches 80% of the total performance roughly 85% faster than the RLV baseline (orange).

What Do The Learned Reward Functions Look Like?
To further explore the qualitative nature of our learned rewards in more challenging real-world scenarios, we collect a dataset of the pen transfer task using various household tools.

Below, we show rewards extracted from a successful (top) and unsuccessful (bottom) demonstration. Both demonstrations follow a similar trajectory at the start of the task execution. The successful one nets a high reward for placing the pens consecutively into the mug then into the glass cup, while the unsuccessful one obtains a low reward because it drops the pens outside the glass cup towards the end of the execution (orange circle). These results are promising because they show that our learned encoder can represent fine-grained visual differences relevant to a task.

Conclusion
We highlighted XIRL, our approach to tackling the cross-embodiment imitation problem. XIRL learns an embodiment-invariant reward function that encodes task progress using a temporal cycle-consistency objective. Policies learned using our reward functions are significantly more sample-efficient than baseline alternatives. Furthermore, the reward functions do not require manually paired video frames between the demonstrator and the learner, giving them the ability to scale to an arbitrary number of embodiments or experts with varying skill levels. Overall, we are excited about this direction of work, and hope that our benchmark promotes further research in this area. For more details, please check out our paper and download the code from our GitHub repository.

Acknowledgments
Kevin and Andy summarized research performed together with Pete Florence, Jonathan Tompson, Jeannette Bohg (faculty at Stanford University) and Debidatta Dwibedi. All authors would additionally like to thank Alex Nichol, Nick Hynes, Sean Kirmani, Brent Yi, Jimmy Wu, Karl Schmeckpeper and Minttu Alakuijala for fruitful technical discussions, and Sam Toyer for invaluable help with setting up the simulated benchmark.

Source: Google AI Blog


Applying Differential Privacy to Large Scale Image Classification

Machine learning (ML) models are becoming increasingly valuable for improved performance across a variety of consumer products, from recommendations to automatic image classification. However, despite aggregating large amounts of data, in theory it is possible for models to encode characteristics of individual entries from the training set. For example, experiments in controlled settings have shown that language models trained using email datasets may sometimes encode sensitive information included in the training data and may have the potential to reveal the presence of a particular user’s data in the training set. As such, it is important to prevent the encoding of such characteristics from individual training entries. To these ends, researchers are increasingly employing federated learning approaches.

Differential privacy (DP) provides a rigorous mathematical framework that allows researchers to quantify and understand the privacy guarantees of a system or an algorithm. Within the DP framework, privacy guarantees of a system are usually characterized by a positive parameter ε, called the privacy loss bound, with smaller ε corresponding to better privacy. One usually trains a model with DP guarantees using DP-SGD, a specialized training algorithm that provides DP guarantees for the trained model.
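For reference, the guarantee that ε bounds can be written out. The following is the standard (ε, δ) formulation of differential privacy, not something spelled out in this post: the symbols M (the randomized training algorithm), D and D' (two datasets differing in a single entry), S (any set of possible outputs) and δ (a small failure probability) are notation introduced here for illustration.

```latex
\Pr[\mathcal{M}(D) \in S] \;\le\; e^{\varepsilon} \, \Pr[\mathcal{M}(D') \in S] + \delta
\qquad \text{for all neighboring } D, D' \text{ and all output sets } S
```

With small ε the factor e^ε is close to 1, so the output distribution barely changes when any single training entry is added or removed, which is why smaller ε corresponds to better privacy.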

However, training with DP-SGD typically has two major drawbacks. First, most existing implementations of DP-SGD are inefficient and slow, which makes it hard to use on large datasets. Second, DP-SGD training often significantly impacts utility (such as model accuracy) to the point that models trained with DP-SGD may become unusable in practice. As a result, most DP research papers evaluate DP algorithms on very small datasets (MNIST, CIFAR-10, or UCI) and don’t even attempt evaluation on larger datasets, such as ImageNet.

In “Toward Training at ImageNet Scale with Differential Privacy”, we share initial results from our ongoing effort to train a large image classification model on ImageNet using DP while maintaining high accuracy and minimizing computational cost. We show that the combination of various training techniques, such as careful choice of the model and hyperparameters, large batch training, and transfer learning from other datasets, can significantly boost accuracy of an ImageNet model trained with DP. To substantiate these discoveries and encourage follow-up research, we are also releasing the associated source code.

Testing Differential Privacy on ImageNet
We choose ImageNet classification as a demonstration of the practicality and efficacy of DP because: (1) it is an ambitious task for DP, for which no prior work shows sufficient progress; and (2) it is a public dataset on which other researchers can operate, so it represents an opportunity to collectively improve the utility of real-life DP training. Classification on ImageNet is challenging for DP because it requires large networks with many parameters. This translates into a significant amount of noise added into the computation, because the noise added scales with the size of the model.

Scaling Differential Privacy with JAX
Exploring multiple architectures and training configurations to research what works for DP can be debilitatingly slow. To streamline our efforts, we used JAX, a high-performance computational library based on XLA that can do efficient auto-vectorization and just-in-time compilation of the mathematical computations. Using these JAX features was previously recommended as a good way to speed up DP-SGD in the context of smaller datasets such as CIFAR-10.

We created our own implementation of DP-SGD on JAX and benchmarked it against the large ImageNet dataset (the code is included in our release). The implementation in JAX was relatively simple and resulted in noticeable performance gains simply from using the XLA compiler. Compared to other implementations of DP-SGD, such as that in TensorFlow Privacy, the JAX implementation is consistently several times faster. It is typically even faster than the custom-built and optimized PyTorch Opacus.

Each step of our DP-SGD implementation takes approximately two forward-backward passes through the network. While this is slower than non-private training, which requires only a single forward-backward pass, it is still the most efficient known approach to train with the per-example gradients necessary for DP-SGD. The graph below shows training runtimes for two models on ImageNet with DP-SGD vs. non-private SGD, each on JAX. Overall, we find DP-SGD on JAX sufficiently fast to run large experiments just by slightly reducing the number of training runs used to find optimal hyperparameters compared to non-private training. This is significantly better than alternatives, such as TensorFlow Privacy, which we found to be ~5x–10x slower on our CIFAR-10 and MNIST benchmarks.
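To make the role of per-example gradients concrete, here is a minimal sketch of a single DP-SGD update in plain Python. It is not the released JAX implementation: the linear model, the squared-error loss, and the names `per_example_grad`, `clip`, `dp_sgd_step`, `max_norm` and `noise_multiplier` are all assumptions chosen for illustration. The sketch only shows the usual structure of the algorithm, which is to clip each example's gradient, sum the clipped gradients, add Gaussian noise, and then average.

```python
import math
import random


def per_example_grad(w, x, y):
    """Gradient of the squared error 0.5 * (w.x - y)^2 for a single example."""
    err = sum(wi * xi for wi, xi in zip(w, x)) - y
    return [err * xi for xi in x]


def clip(grad, max_norm):
    """Scale grad down so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


def dp_sgd_step(w, batch, lr, max_norm, noise_multiplier, rng):
    """One update: clip per-example gradients, sum, add noise, average."""
    total = [0.0] * len(w)
    for x, y in batch:
        g = clip(per_example_grad(w, x, y), max_norm)
        total = [t + gi for t, gi in zip(total, g)]
    # Noise standard deviation is proportional to the clipping norm.
    sigma = noise_multiplier * max_norm
    noisy = [t + rng.gauss(0.0, sigma) for t in total]
    return [wi - lr * ni / len(batch) for wi, ni in zip(w, noisy)]


# Hypothetical toy data: two features, three examples.
rng = random.Random(0)
batch = [([1.0, 0.0], 1.0), ([0.0, 1.0], -1.0), ([1.0, 1.0], 0.0)]
weights = dp_sgd_step([0.0, 0.0], batch, lr=0.1, max_norm=1.0,
                      noise_multiplier=1.1, rng=rng)
```

The explicit loop over examples is what makes naive implementations slow. A vectorizing transformation such as the auto-vectorization mentioned above computes the same per-example quantities in one batched pass.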

Time in seconds per training epoch on ImageNet using a ResNet-18 or ResNet-50 architecture with 8 V100 GPUs.

Combining Techniques for Improved Accuracy
It is possible that future training algorithms may improve DP’s privacy-utility tradeoff. However, with current algorithms, such as DP-SGD, our experience points to an engineering “bag-of-tricks” approach to make DP more practical on challenging tasks like ImageNet.

Because we can train models faster with JAX, we can iterate quickly and explore multiple configurations to find what works well for DP. We report the following combination of techniques as useful to achieve non-trivial accuracy and privacy on ImageNet:

  • Full-batch training

    Theoretically, it is known that larger minibatch sizes improve the utility of DP-SGD, with full-batch training (i.e., where a full dataset is one batch) giving the best utility [1, 2], and empirical results are emerging to support this theory. Indeed, our experiments demonstrate that increasing the batch size along with the number of training epochs leads to a decrease in ε while still maintaining accuracy. However, training with extremely large batches is non-trivial as the batch cannot fit into GPU/TPU memory. So, we employed virtual large-batch training by accumulating gradients for multiple steps before updating the weights instead of applying gradient updates on each training step.

    Batch size             1024        4 × 1024    16 × 1024   64 × 1024
    Number of epochs       10          40          160         640
    Accuracy               56%         57.5%       57.9%       57.2%
    Privacy loss bound ε   9.8 × 10^8  6.1 × 10^7  3.5 × 10^6  6.7 × 10^4

  • Transfer learning from public data

    Pre-training on public data followed by DP fine-tuning on private data has previously been shown to improve accuracy on other benchmarks [3, 4]. A question that remains is what public data to use for a given task to optimize transfer learning. In this work we simulate a private/public data split by using ImageNet as “private” data and using Places365, another image classification dataset, as a proxy for “public” data. We pre-trained our models on Places365 before fine-tuning them with DP-SGD on ImageNet. Places365 only has images of landscapes and buildings, whereas ImageNet also contains animals, so the two datasets are quite different. This makes Places365 a good candidate to demonstrate the ability of the model to transfer to a different but related domain.

    We found that transfer learning from Places365 gave us 47.5% accuracy on ImageNet with a reasonable level of privacy (ε = 10). This is low compared to the 70% accuracy of a similar non-private model, but compared to naïve DP training on ImageNet, which yields either very low accuracy (2–5%) or no privacy (ε = 10^9), this is quite good.

Privacy-accuracy tradeoff for ResNet-18 on ImageNet using large-batch training with transfer learning from Places365.
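The virtual large-batch training described under full-batch training above can be sketched as follows. This is a minimal illustration in plain Python and not the released code: the scalar model, the loss, and the names `grad_sum` and `virtual_batch_step` are assumptions. The weights stay fixed while gradients are accumulated over several micro-batches, and a single update is applied at the end, so the result matches one step on the combined batch.

```python
def grad_sum(w, examples):
    """Sum of per-example gradients of 0.5 * (w * x - y)^2 for a scalar weight w."""
    return sum((w * x - y) * x for x, y in examples)


def virtual_batch_step(w, micro_batches, lr):
    """One weight update from gradients accumulated over several micro-batches."""
    accumulated = 0.0
    count = 0
    for micro in micro_batches:
        # The weights are not updated here; only the gradient sum grows.
        accumulated += grad_sum(w, micro)
        count += len(micro)
    return w - lr * accumulated / count


# Hypothetical toy data: (x, y) pairs, split into two micro-batches.
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
micro_batches = [data[:2], data[2:]]

w_virtual = virtual_batch_step(0.5, micro_batches, lr=0.01)
w_full = virtual_batch_step(0.5, [data], lr=0.01)
```

Only one micro-batch needs to be held in memory at a time. In a DP-SGD setting one would typically clip each per-example gradient before accumulating and add noise once per virtual batch. That composition is an assumption here and is not shown in the sketch.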

Next Steps
We hope these early results and source code provide an impetus for other researchers to work on improving DP for ambitious tasks such as ImageNet as a proxy for challenging production-scale tasks. With the much faster DP-SGD on JAX, we urge DP and ML researchers to explore diverse training regimes, model architectures, and algorithms to make DP more practical. To continue advancing the state of the field, we recommend researchers start with a baseline that incorporates full-batch training plus transfer learning.

Acknowledgments
This work was carried out with the support of the Google Visiting Researcher Program while Prof. Geambasu, an Associate Professor with Columbia University, was on sabbatical with Google Research. This work received substantial contributions from Steve Chien, Shuang Song, Andreas Terzis and Abhradeep Guha Thakurta.

Source: Google AI Blog


Applying Differential Privacy to Large Scale Image Classification

Machine learning (ML) models are becoming increasingly valuable for improved performance across a variety of consumer products, from recommendations to automatic image classification. However, despite aggregating large amounts of data, in theory it is possible for models to encode characteristics of individual entries from the training set. For example, experiments in controlled settings have shown that language models trained using email datasets may sometimes encode sensitive information included in the training data and may have the potential to reveal the presence of a particular user’s data in the training set. As such, it is important to prevent the encoding of such characteristics from individual training entries. To these ends, researchers are increasingly employing federated learning approaches.

Differential privacy (DP) provides a rigorous mathematical framework that allows researchers to quantify and understand the privacy guarantees of a system or an algorithm. Within the DP framework, privacy guarantees of a system are usually characterized by a positive parameter ε, called the privacy loss bound, with smaller ε corresponding to better privacy. One usually trains a model with DP guarantees using DP-SGD, a specialized training algorithm that provides DP guarantees for the trained model.

However training with DP-SGD typically has two major drawbacks. First, most existing implementations of DP-SGD are inefficient and slow, which makes it hard to use on large datasets. Second, DP-SGD training often significantly impacts utility (such as model accuracy) to the point that models trained with DP-SGD may become unusable in practice. As a result most DP research papers evaluate DP algorithms on very small datasets (MNIST, CIFAR-10, or UCI) and don’t even try to perform evaluation of larger datasets, such as ImageNet.

In “Toward Training at ImageNet Scale with Differential Privacy”, we share initial results from our ongoing effort to train a large image classification model on ImageNet using DP while maintaining high accuracy and minimizing computational cost. We show that the combination of various training techniques, such as careful choice of the model and hyperparameters, large batch training, and transfer learning from other datasets, can significantly boost accuracy of an ImageNet model trained with DP. To substantiate these discoveries and encourage follow-up research, we are also releasing the associated source code.

Testing Differential Privacy on ImageNet
We choose ImageNet classification as a demonstration of the practicality and efficacy of DP because: (1) it is an ambitious task for DP, for which no prior work shows sufficient progress; and (2) it is a public dataset on which other researchers can operate, so it represents an opportunity to collectively improve the utility of real-life DP training. Classification on ImageNet is challenging for DP because it requires large networks with many parameters. This translates into a significant amount of noise added into the computation, because the noise added scales with the size of the model.

Scaling Differential Privacy with JAX
Exploring multiple architectures and training configurations to research what works for DP can be debilitatingly slow. To streamline our efforts, we used JAX, a high-performance computational library based on XLA that can do efficient auto-vectorization and just-in-time compilation of the mathematical computations. Using these JAX features was previously recommended as a good way to speed up DP-SGD in the context of smaller datasets such as CIFAR-10.

We created our own implementation of DP-SGD on JAX and benchmarked it against the large ImageNet dataset (the code is included in our release). The implementation in JAX was relatively simple and resulted in noticeable performance gains simply because of using the XLA compiler. Compared to other implementations of DP-SGD, such as that in Tensorflow Privacy, the JAX implementation is consistently several times faster. It is typically even faster compared to the custom-built and optimized PyTorch Opacus.

Each step of our DP-SGD implementation takes approximately two forward-backward passes through the network. While this is slower than non-private training, which requires only a single forward-backward pass, it is still the most efficient known approach to train with the per-example gradients necessary for DP-SGD. The graph below shows training runtimes for two models on ImageNet with DP-SGD vs. non-private SGD, each on JAX. Overall, we find DP-SGD on JAX sufficiently fast to run large experiments just by slightly reducing the number of training runs used to find optimal hyperparameters compared to non-private training. This is significantly better than alternatives, such as Tensorflow Privacy, which we found to be ~5x–10x slower on our CIFAR10 and MNIST benchmarks.

Time in seconds per training epoch on ImageNet using a Resnet18 or Resnet50 architecture with 8 V100 GPUs.
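
The per-example gradient handling that makes DP-SGD more expensive than ordinary SGD can be illustrated with a minimal sketch in plain Python. It assumes the standard DP-SGD recipe (clip each example's gradient to an L2 norm bound, sum, add Gaussian noise, average). The names `dp_sgd_step`, `clip_norm` and `noise_multiplier` are illustrative and are not taken from the released JAX code.

```python
import math
import random


def clip_by_l2_norm(grad, clip_norm):
    """Scale a gradient vector down so its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, clip_norm / norm) if norm > 0.0 else 1.0
    return [g * scale for g in grad]


def dp_sgd_step(params, per_example_grads, lr, clip_norm, noise_multiplier, rng):
    """One illustrative DP-SGD update from a list of per-example gradients."""
    batch_size = len(per_example_grads)
    clipped = [clip_by_l2_norm(g, clip_norm) for g in per_example_grads]
    # Sum the clipped gradients coordinate by coordinate.
    summed = [sum(col) for col in zip(*clipped)]
    # Gaussian noise with standard deviation noise_multiplier * clip_norm is
    # added to every coordinate, so the total noise grows with model size.
    noisy = [s + rng.gauss(0.0, noise_multiplier * clip_norm) for s in summed]
    return [p - lr * n / batch_size for p, n in zip(params, noisy)]


rng = random.Random(0)
grads = [[3.0, 4.0], [0.3, 0.4]]  # toy per-example gradients
# noise_multiplier=0.0 switches the noise off so the clipping effect is visible.
new_params = dp_sgd_step([1.0, 1.0], grads, lr=0.1, clip_norm=1.0,
                         noise_multiplier=0.0, rng=rng)
```

In this toy run the first gradient (norm 5) is scaled down to norm 1 while the second (norm 0.5) is left untouched, which is why per-example gradients, rather than a single batch gradient, are needed.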

Combining Techniques for Improved Accuracy
It is possible that future training algorithms may improve DP’s privacy-utility tradeoff. However, with current algorithms, such as DP-SGD, our experience points to an engineering “bag-of-tricks” approach to make DP more practical on challenging tasks like ImageNet.

Because we can train models faster with JAX, we can iterate quickly and explore multiple configurations to find what works well for DP. We report the following combination of techniques as useful to achieve non-trivial accuracy and privacy on ImageNet:

  • Full-batch training

    Theoretically, it is known that larger minibatch sizes improve the utility of DP-SGD, with full-batch training (i.e., where a full dataset is one batch) giving the best utility [1, 2], and empirical results are emerging to support this theory. Indeed, our experiments demonstrate that increasing the batch size along with the number of training epochs leads to a decrease in ε while still maintaining accuracy. However, training with extremely large batches is non-trivial as the batch cannot fit into GPU/TPU memory. So, we employed virtual large-batch training by accumulating gradients for multiple steps before updating the weights instead of applying gradient updates on each training step.

    Batch size             1024        4 × 1024    16 × 1024   64 × 1024
    Number of epochs       10          40          160         640
    Accuracy               56%         57.5%       57.9%       57.2%
    Privacy loss bound ε   9.8 × 10^8  6.1 × 10^7  3.5 × 10^6  6.7 × 10^4

  • Transfer learning from public data

    Pre-training on public data followed by DP fine-tuning on private data has previously been shown to improve accuracy on other benchmarks [3, 4]. A question that remains is what public data to use for a given task to optimize transfer learning. In this work we simulate a private/public data split by using ImageNet as “private” data and using Places365, another image classification dataset, as a proxy for “public” data. We pre-trained our models on Places365 before fine-tuning them with DP-SGD on ImageNet. Places365 only has images of landscapes and buildings, not of animals as ImageNet does, so it is quite different, making it a good candidate to demonstrate the ability of the model to transfer to a different but related domain.

    We found that transfer learning from Places365 gave us 47.5% accuracy on ImageNet with a reasonable level of privacy (ε = 10). This is low compared to the 70% accuracy of a similar non-private model, but compared to naïve DP training on ImageNet, which yields either very low accuracy (2–5%) or no privacy (ε = 10^9), this is quite good.

Privacy-accuracy tradeoff for Resnet-18 on ImageNet using large-batch training with transfer learning from Places365.
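
The virtual large-batch training described under full-batch training above can be sketched as gradient accumulation: gradients from several small micro-batches are summed, and the weights are updated only once per virtual batch. The one-parameter model and the function names below are hypothetical and chosen only to keep the example self-contained. In DP-SGD the per-example clipping would happen inside each micro-batch and the noise would be added once per virtual batch, which this sketch leaves out.

```python
def grad_mse(w, batch):
    """Sum over the batch of d/dw (w * x - y)^2 for a one-parameter linear model."""
    return sum(2.0 * (w * x - y) * x for x, y in batch)


def accumulated_update(w, data, micro_batch_size, lr):
    """Apply one weight update after accumulating gradients over micro-batches."""
    grad_sum = 0.0
    for start in range(0, len(data), micro_batch_size):
        micro_batch = data[start:start + micro_batch_size]
        grad_sum += grad_mse(w, micro_batch)  # accumulate, do not update yet
    return w - lr * grad_sum / len(data)      # single update for the virtual batch


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
# Two micro-batches of size 2 form one virtual batch of size 4.
w_virtual = accumulated_update(0.0, data, micro_batch_size=2, lr=0.01)
# Reference: the whole dataset processed as one batch.
w_full = accumulated_update(0.0, data, micro_batch_size=len(data), lr=0.01)
```

Both calls produce the same updated weight, which is the point of the technique: the effective batch size is decoupled from what fits in accelerator memory at once.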

Next Steps
We hope these early results and source code provide an impetus for other researchers to work on improving DP for ambitious tasks such as ImageNet as a proxy for challenging production-scale tasks. With the much faster DP-SGD on JAX, we urge DP and ML researchers to explore diverse training regimes, model architectures, and algorithms to make DP more practical. To continue advancing the state of the field, we recommend researchers start with a baseline that incorporates full-batch training plus transfer learning.

Acknowledgments
This work was carried out with the support of the Google Visiting Researcher Program while Prof. Geambasu, an Associate Professor with Columbia University, was on sabbatical with Google Research. This work received substantial contributions from Steve Chien, Shuang Song, Andreas Terzis and Abhradeep Guha Thakurta.

Source: Google AI Blog


Controlling Neural Networks with Rule Representations

Deep neural networks (DNNs) provide more accurate results as the size and coverage of their training data increases. While investing in high-quality and large-scale labeled datasets is one path to model improvement, another is leveraging prior knowledge, concisely referred to as “rules” — reasoning heuristics, equations, associative logic, or constraints. Consider a common example from physics where a model is given the task of predicting the next state in a double pendulum system. While the model may learn to estimate the total energy of the system at a given point in time only from empirical data, it will frequently overestimate the energy unless also provided an equation that reflects the known physical constraints, e.g., energy conservation. The model fails to capture such well-established physical rules on its own. How could one effectively teach such rules so that DNNs absorb the relevant knowledge beyond simply learning from the data?

In “Controlling Neural Networks with Rule Representations”, published at NeurIPS 2021, we present Deep Neural Networks with Controllable Rule Representations (DeepCTRL), an approach for providing rules to a model that is agnostic to data type and model architecture and can be applied to any kind of rule defined for inputs and outputs. The key advantage of DeepCTRL is that it does not require retraining to adapt the rule strength. At inference, the user can adjust rule strength based on the desired operating point for accuracy. We also propose a novel input perturbation method, which helps generalize DeepCTRL to non-differentiable constraints. In real-world domains where incorporating rules is critical — such as physics and healthcare — we demonstrate the effectiveness of DeepCTRL in teaching rules for deep learning. DeepCTRL ensures that models follow rules more closely while also providing accuracy gains at downstream tasks, thus improving reliability and user trust in the trained models. Additionally, DeepCTRL enables novel use cases, such as hypothesis testing of the rules on data samples and unsupervised adaptation based on shared rules between datasets.

The benefits of learning from rules are multifaceted:

  • Rules can provide extra information for cases with minimal data, improving the test accuracy.
  • A major bottleneck for widespread use of DNNs is the lack of understanding the rationale behind their reasoning and inconsistencies. By minimizing inconsistencies, rules can improve the reliability of and user trust in DNNs.
  • DNNs are sensitive to slight input changes that are human-imperceptible. With rules, the impact of these changes can be minimized as the model search space is further constrained to reduce underspecification.

Learning Jointly from Rules and Tasks
The conventional approach to implementing rules is to include them in the calculation of the loss. There are three limitations of this approach that we aim to address: (i) rule strength needs to be defined before learning (thus the trained model cannot operate flexibly based on how much the data satisfies the rule); (ii) rule strength is not adaptable to target data at inference if there is any mismatch with the training setup; and (iii) the rule-based objective needs to be differentiable with respect to learnable parameters (to enable learning from labeled data).

DeepCTRL modifies canonical training by creating rule representations, coupled with data representations, which is the key to enabling the rule strength to be controlled at inference time. During training, these representations are stochastically concatenated with a control parameter, indicated by α, into a single representation. The influence of the rule on the output decision can be strengthened by increasing the value of α. By modifying α at inference, users can control the behavior of the model to adapt to unseen data.

DeepCTRL pairs a data encoder and rule encoder, which produce two latent representations, which are coupled with corresponding objectives. The control parameter α is adjustable at inference to control the relative weight of each encoder.
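
As a rough illustration of how a control parameter can weight two representations, the sketch below scales the rule representation by α and the data representation by 1 − α before concatenating them. This weighting form, the uniform sampling of α during training, and all function names are assumptions made for the example; the post only states that the representations are stochastically concatenated with α.

```python
import random


def combine_representations(z_rule, z_data, alpha):
    """Concatenate the rule and data representations, weighted by alpha."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return [alpha * v for v in z_rule] + [(1.0 - alpha) * v for v in z_data]


def sample_training_alpha(rng):
    """Draw a random alpha for one training step (uniform is an assumption)."""
    return rng.random()


z_rule = [1.0, -2.0]  # toy output of a rule encoder
z_data = [0.5, 4.0]   # toy output of a data encoder

# At inference the user picks alpha directly instead of sampling it.
z_task_only = combine_representations(z_rule, z_data, alpha=0.0)
z_rule_only = combine_representations(z_rule, z_data, alpha=1.0)
z_mixed = combine_representations(z_rule, z_data, alpha=0.25)

# During training, alpha would be re-drawn at every step.
alpha_train = sample_training_alpha(random.Random(0))
```

Because the model sees many values of α during training, a single trained network can be queried at any α afterwards, which is what removes the need for retraining.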

Integrating Rules via Input Perturbations
Training with rule-based objectives requires the objectives to be differentiable with respect to the learnable parameters of the model. There are many valuable rules that are non-differentiable with respect to input. For example, “blood pressure higher than 140 is likely to lead to cardiovascular disease” is a rule that is hard to combine with conventional DNNs. To handle such rules, we introduce a novel input perturbation method that generalizes DeepCTRL to non-differentiable constraints by adding small perturbations (random noise) to input features and constructing a rule-based constraint based on whether the outcome is in the desired direction.
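
One way to picture the perturbation idea is the following minimal sketch: a single input feature is nudged upward by a small random amount, and a penalty is incurred only if the model output moves in the wrong direction. The hinge-style penalty, the uniform noise, and the two toy models are assumptions for illustration and are not the exact formulation from the paper.

```python
import random


def perturbation_rule_penalty(model, x, feature_index, rng, max_delta=0.1):
    """Penalty that is positive when raising one feature lowers the output."""
    delta = rng.uniform(0.0, max_delta)  # small positive perturbation
    x_perturbed = list(x)
    x_perturbed[feature_index] += delta
    # The rule wants model(x_perturbed) >= model(x); penalize any shortfall.
    return max(0.0, model(x) - model(x_perturbed))


def follows_rule(x):
    """Toy risk model whose output rises with feature 0."""
    return 0.02 * x[0] + 0.1 * x[1]


def violates_rule(x):
    """Toy risk model whose output falls with feature 0."""
    return -0.02 * x[0] + 0.1 * x[1]


rng = random.Random(0)
sample = [140.0, 1.0]  # [systolic blood pressure, another feature]
penalty_ok = perturbation_rule_penalty(follows_rule, sample, 0, rng)
penalty_bad = perturbation_rule_penalty(violates_rule, sample, 0, rng)
```

The penalty depends only on model outputs at two nearby inputs, so it can be used as a training signal even when the rule itself has no usable gradient with respect to the input.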

Use Cases
We evaluate DeepCTRL on machine learning use cases from physics and healthcare, where utilization of rules is particularly important.

  • Improved Reliability Given Known Principles in Physics

    We quantify reliability of a model with the verification ratio, which is the fraction of output samples that satisfy the rules. Operating at a better verification ratio could be beneficial, especially if the rules are known to be always valid, as in natural sciences. By adjusting the control parameter α, a higher rule verification ratio, and thus more reliable predictions, can be achieved.

    To demonstrate this, we consider the time-series data generated from double pendulum dynamics with friction from a given initial state. We define the task as predicting the next state of the double pendulum from the current state while imposing the rule of energy conservation. To quantify how much the rule is learned, we evaluate the verification ratio.

    DeepCTRL enables controlling a model's behavior after learning, without retraining. For the example of a double pendulum, conventional learning imposes no constraints to ensure the model follows physical laws, e.g., conservation of energy. The situation is similar for the case of DeepCTRL where the rule strength is low. So, the total energy of the system predicted at time t+1 (blue) can sometimes be greater than that measured at time t (red), which is physically disallowed (bottom left). If rule strength in DeepCTRL is high, the model may follow the given rule but lose accuracy (discrepancy between red and blue is larger; bottom right). If rule strength is between the two extremes, the model may achieve higher accuracy (blue curve is close to red) and follow the rule properly (blue curve is lower than red one).

    We compare the performance of DeepCTRL on this task to conventional baselines of training with a fixed rule-based constraint as a regularization term added to the objective, weighted by a coefficient λ. The highest of these regularization coefficients provides the highest verification ratio (shown by the green line in the second graph below); however, the prediction error is slightly worse than that of λ = 0.1 (orange line). We find that the lowest prediction error of the fixed baseline is comparable to that of DeepCTRL, but the highest verification ratio of the fixed baseline is still lower, which implies that DeepCTRL could provide accurate predictions while following the law of energy conservation. In addition, we consider the benchmark of imposing the rule-constraint with Lagrangian Dual Framework (LDF) and demonstrate two results where its hyperparameters are chosen by the lowest mean absolute error (LDF-MAE) and the highest rule verification ratio (LDF-Ratio) on the validation set. The performance of the LDF method is highly sensitive to what the main constraint is and its output is not reliable (black and pink dashed lines).

    Experimental results for the double pendulum task, showing the task-based mean absolute error (MAE), which measures the discrepancy between the ground truth and the model prediction, versus DeepCTRL as a function of the control parameter α. TaskOnly doesn’t have a rule constraint and Task & Rule has different rule strength (λ). LDF enforces rules by solving a constraint optimization problem.
    As above, but showing the verification ratio from different models.
    Experimental results for the double pendulum task showing the current and predicted energy at time t and t + 1, respectively.

    Additionally, the figures above illustrate the advantage DeepCTRL has over conventional approaches. For example, increasing the rule strength λ from 0.1 to 1.0 improves the verification ratio (from 0.7 to 0.9), but does not improve the mean absolute error. Arbitrarily increasing λ will continue to drive the verification ratio closer to 1, but will result in worse accuracy. Thus, finding the optimal value of λ will require many training runs through the baseline model, whereas DeepCTRL can find the optimal value for the control parameter α much more quickly.

  • Adapting to Distribution Shifts in Healthcare

    The strengths of some rules may differ between subsets of the data. For example, in disease prediction, the correlation between cardiovascular disease and higher blood pressure is stronger for older patients than younger patients. In such situations, when the task is shared but data distribution and the validity of the rule differ between datasets, DeepCTRL can adapt to the distribution shifts by controlling α.

    Exploring this example, we focus on the task of predicting whether cardiovascular disease is present or not using a cardiovascular disease dataset. Given that higher systolic blood pressure is known to be strongly associated with cardiovascular disease, we consider the rule: “higher risk if the systolic blood pressure is higher”. Based on this, we split the patients into two groups: (1) unusual, where a patient has high blood pressure, but no disease or lower blood pressure, but has disease; and (2) usual, where a patient has high blood pressure and disease or low blood pressure, but no disease.

    We demonstrate below that the source data do not always follow the rule, and thus the effect of incorporating the rule can depend on the source data. The test cross entropy, which indicates classification accuracy (lower cross entropy is better), vs. rule strength for source or target datasets with varying usual / unusual ratio are visualized below. The error monotonically increases as α → 1 because the enforcement of the imposed rule, which doesn’t accurately reflect the source data, becomes more strict.

    Test cross entropy vs. rule strength for a source dataset with usual / unusual ratio of 0.30.

    When a trained model is transferred to the target domain, the error can be reduced by controlling α. To demonstrate this, we show three domain-specific datasets, which we call Target 1, 2, and 3. In Target 1, where the majority of patients are from the usual group, as α is increased, the rule-based representation has more weight and the resultant error decreases monotonically.

    As above, but for a Target dataset (1) with a usual / unusual ratio of 0.77.

    When the ratio of usual patients is decreased in Target 2 and 3, the optimal α is an intermediate value between 0 and 1. These demonstrate the capability to adapt the trained model via α.

    As above, but for Target 2 with a usual / unusual ratio of 0.50.
    As above, but for Target 3 with a usual / unusual ratio of 0.40.
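
Two quantities from the use cases above lend themselves to a short sketch: the verification ratio (the fraction of outputs that satisfy a rule) and the split of patients into usual and unusual groups. The energy rule, the blood pressure threshold of 140, the toy data, and the reading of the reported usual / unusual ratio as the share of usual patients are all assumptions made for illustration.

```python
def verification_ratio(outputs, satisfies_rule):
    """Fraction of outputs for which the rule predicate holds."""
    if not outputs:
        raise ValueError("outputs must be non-empty")
    return sum(1 for o in outputs if satisfies_rule(o)) / len(outputs)


def energy_not_increasing(pair):
    """Rule for a pendulum with friction: predicted energy must not exceed current."""
    current_energy, predicted_energy = pair
    return predicted_energy <= current_energy


def usual_fraction(patients, pressure_threshold):
    """Share of patients whose label agrees with the blood-pressure rule."""
    usual = sum(
        1 for pressure, has_disease in patients
        if (pressure >= pressure_threshold) == has_disease
    )
    return usual / len(patients)


# (energy at time t, predicted energy at time t + 1); the second pair breaks the rule.
energy_pairs = [(10.0, 9.5), (9.5, 9.6), (9.6, 9.0), (9.0, 8.8)]
ratio = verification_ratio(energy_pairs, energy_not_increasing)

# (systolic blood pressure, disease present); hypothetical records.
patients = [(150, True), (120, False), (155, False), (118, True), (160, True)]
share_usual = usual_fraction(patients, pressure_threshold=140)
```

Here three of the four predictions respect the energy rule, and three of the five patients fall in the usual group, so a rule-heavy setting of α would be expected to help more on data with a higher usual share.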

Conclusions
Learning from rules can be crucial for constructing interpretable, robust, and reliable DNNs. We propose DeepCTRL, a new methodology used to incorporate rules into data-learned DNNs. DeepCTRL enables controllability of rule strength at inference without retraining. We propose a novel perturbation-based rule encoding method to integrate arbitrary rules into meaningful representations. We demonstrate three use cases of DeepCTRL: improving reliability given known principles, examining candidate rules, and domain adaptation using the rule strength.

Acknowledgements
We greatly appreciate the contributions of Jinsung Yoon, Xiang Zhang, Kihyuk Sohn and Tomas Pfister.

Source: Google AI Blog


Learning to Route by Task for Efficient Inference

Scaling large language models has resulted in significant quality improvements in natural language understanding (T5), generation (GPT-3) and multilingual neural machine translation (M4). One common approach to building a larger model is to increase the depth (number of layers) and width (layer dimensionality), simply enlarging existing dimensions of the network. Such dense models take an input sequence (divided into smaller components, called tokens) and pass every token through the full network, activating every layer and parameter. While these large, dense models have achieved state-of-the-art results on multiple natural language processing (NLP) tasks, their training cost increases linearly with model size.

An alternative, and increasingly popular, approach is to build sparsely activated models based on a mixture of experts (MoE) (e.g., GShard-M4 or GLaM), where each token passed to the network follows a separate subnetwork by skipping some of the model parameters. The choice of how to distribute the input tokens to each subnetwork (the “experts”) is determined by small router networks that are trained together with the rest of the network. This allows researchers to increase model size (and hence, performance) without a proportional increase in training cost.

While this is an effective strategy at training time, sending tokens of a long sequence to multiple experts again makes inference computationally expensive because the experts have to be distributed among a large number of accelerators. For example, serving the 1.2T parameter GLaM model requires 256 TPU-v3 chips. Much like dense models, the number of processors needed to serve an MoE model still scales linearly with respect to the model size, increasing compute requirements while also resulting in significant communication overhead and added engineering complexity.

In “Beyond Distillation: Task-level Mixture-of-Experts for Efficient Inference”, we introduce a method called Task-level Mixture-of-Experts (TaskMoE), that takes advantage of the quality gains of model scaling while still being efficient to serve. Our solution is to train a large multi-task model from which we then extract smaller, stand-alone per-task subnetworks suitable for inference with no loss in model quality and with significantly reduced inference latency. We demonstrate the effectiveness of this method for multilingual neural machine translation (NMT) compared to other mixture of experts models and to models compressed using knowledge distillation.

Training Large Sparsely Activated Models with Task Information
We train a sparsely activated model, where router networks learn to send tokens of each task-specific input to different subnetworks of the model associated with the task of interest. For example, in the case of multilingual NMT, every token of a given language is routed to the same subnetwork. This differs from other recent approaches, such as the sparsely gated mixture of expert models (e.g., TokenMoE), where router networks learn to send different tokens in an input to different subnetworks independent of task.

Inference: Bypassing Distillation by Extracting Subnetworks
A consequence of this difference in training between TaskMoE and models like TokenMoE is in how we approach inference. Because TokenMoE follows the practice of distributing tokens of the same task to many experts at both training and inference time, it is still computationally expensive at inference.

For TaskMoE, we dedicate a smaller subnetwork to a single task identity during training and inference. At inference time, we extract subnetworks by discarding unused experts for each task. TaskMoE and its variants enable us to train a single large multi-task network and then use a separate subnetwork at inference time for each task without using any additional compression methods post-training. We illustrate the process of training a TaskMoE network and then extracting per-task subnetworks for inference below.

During training, tokens of the same language are routed to the same expert based on language information (either source, target or both) in task-based MoE. Later, during inference we extract subnetworks for each task and discard unused experts.

To demonstrate this approach, we train models based on the Transformer architecture. Similar to GShard-M4 and GLaM, we replace the feedforward network of every other transformer layer with a Mixture-of-Experts (MoE) layer that consists of multiple identical feedforward networks, the “experts”. For each task, the routing network, trained along with the rest of the model, keeps track of the task identity for all input tokens and chooses a certain number of experts per layer (two in this case) to form the task-specific subnetwork. The baseline dense Transformer model has 143M parameters and 6 layers on both the encoder and decoder. The TaskMoE and TokenMoE that we train are also both 6 layers deep but with 32 experts for every MoE layer and have a total of 533M parameters. We train our models using publicly available WMT datasets, with over 431M sentences across 30 language pairs from different language families and scripts. We point the reader to the full paper for further details.
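
The extraction step can be pictured with a small sketch in which each MoE layer is a dictionary of experts and each task has a fixed choice of two experts per layer. The layer count, parameter sizes, language pairs, and routing table are hypothetical; in the actual model the routing is learned, and shared non-expert parameters (not modeled here) also contribute to the subnetwork size.

```python
def extract_task_subnetwork(moe_layers, task_routes):
    """Keep only the experts that the router assigned to one task."""
    return [
        {expert_id: layer[expert_id] for expert_id in chosen}
        for layer, chosen in zip(moe_layers, task_routes)
    ]


def count_params(layers):
    """Total parameter count across all experts in all layers."""
    return sum(sum(layer.values()) for layer in layers)


NUM_EXPERTS = 32
PARAMS_PER_EXPERT = 1_000  # hypothetical size, for illustration only
# Three MoE layers, each mapping expert id -> parameter count.
moe_layers = [{e: PARAMS_PER_EXPERT for e in range(NUM_EXPERTS)} for _ in range(3)]

# Hypothetical routing table: two experts per MoE layer for each task.
routes = {
    "en-fr": [(0, 5), (3, 17), (8, 31)],
    "en-de": [(0, 6), (2, 17), (9, 30)],
}

subnet = extract_task_subnetwork(moe_layers, routes["en-fr"])
full_expert_params = count_params(moe_layers)  # 3 layers * 32 experts * 1000
subnet_expert_params = count_params(subnet)    # 3 layers * 2 experts * 1000
```

Because every token of a task uses the same experts, the unused experts can simply be dropped for that task, and no further compression or retraining step is involved.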

Results
In order to demonstrate the advantage of using TaskMoE at inference time, we compare the throughput, or the number of tokens decoded per second, for TaskMoE, TokenMoE, and a baseline dense model. Once the subnetwork for each task is extracted, TaskMoE is 7x smaller than the 533M parameter TokenMoE model, and it can be served on a single TPUv3 core, instead of 64 cores required for TokenMoE. We see that TaskMoE has a peak throughput twice as high as that of TokenMoE models. In addition, on inspecting the TokenMoE model, we find that 25% of the inference time has been spent in inter-device communication, while virtually no time is spent in communication by TaskMoE.
Comparing the throughput of TaskMoE with TokenMoE across different batch sizes. The maximum batch size for TokenMoE is 1024 as opposed to 4096 for TaskMoE and the dense baseline model. Here, TokenMoE has one instance distributed across 64 TPUv3 cores, while TaskMoE and the baseline model have one instance on each of the 64 cores.

A popular approach to building a smaller network that still performs well is through knowledge distillation, in which a large teacher model trains a smaller student model with the goal of matching the teacher’s performance. However, this method comes at the cost of additional computation needed to train the student from the teacher. So, we also compare TaskMoE to a baseline TokenMoE model that we compress using knowledge distillation. The compressed TokenMoE model has a size comparable to the per-task subnetwork extracted from TaskMoE.

We find that in addition to being a simpler method that does not need any additional training, TaskMoE improves upon a distilled TokenMoE model by 2.1 BLEU on average across all languages in our multilingual translation model. We note that distillation retains 43% of the performance gains achieved from scaling a dense multilingual model to a TokenMoE, whereas extracting the smaller subnetwork from the TaskMoE model results in no loss of quality.

BLEU scores (higher is better) comparing a distilled TokenMoE model to the TaskMoE and TokenMoE models with 12 layers (6 on the encoder and 6 on the decoder) and 32 experts. While both approaches improve upon a multilingual dense baseline, TaskMoE improves upon the baseline by 3.1 BLEU on average while distilling from TokenMoE improves upon the baseline by 1.0 BLEU on average.

Next Steps
The quality improvements often seen with scaling machine learning models have incentivized the research community to work toward advancing scaling technology to enable efficient training of large models. The emerging need to train models capable of generalizing to multiple tasks and modalities only increases the need for scaling models even further. However, the practicality of serving these large models remains a major challenge. Efficiently deploying large models is an important direction of research, and we believe TaskMoE is a promising step towards more inference-friendly algorithms that retain the quality gains of scaling.

Acknowledgements
We would like to first thank our coauthors: Yanping Huang, Ankur Bapna, Maxim Krikun, Dmitry Lepikhin and Minh-Thang Luong. We would also like to thank Wolfgang Macherey, Yuanzhong Xu, Zhifeng Chen and Macduff Richard Hughes for their helpful feedback. Special thanks to the Translate and Brain teams for their useful input and discussions, and the entire GShard development team for their foundational contributions to this project. We would also like to thank Tom Small for creating the animations for the blog post.

Source: Google AI Blog


DeepNull: an open-source method to improve the discovery power of genetic association studies

In our paper “DeepNull models non-linear covariate effects to improve phenotypic prediction and association power,” we proposed a new method, DeepNull, to model the complex effects of covariates on phenotypes and thereby improve genome-wide association study (GWAS) results. We have released DeepNull as open source software, with a Colab notebook tutorial for its use.

Human Genetics 101

Each individual’s genetic data carries health information such as why certain individuals have a lower risk of developing skin cancer compared to others or why certain drugs differ in effectiveness between individuals. Genetic data is encoded in the human genome (a DNA sequence), a chain roughly 3 billion nucleotides long built from four possible nucleotides (A, C, G, and T). Only a small subset of the genome (~4-5 million positions) varies between two individuals. One of the goals of genetic studies is to detect variants that are associated with different phenotypes (e.g., risk of diseases such as glaucoma or observed phenotypic values such as high-density lipoprotein (HDL), low-density lipoprotein (LDL), height, etc.).

Genome-wide association studies

GWAS are used to associate genetic variants with complex traits and diseases. To more accurately determine an association strength between genotype and phenotype, the interactions between phenotypes (such as age and sex) and principal components (PCs) of genotypes must be adjusted for as covariates. Covariate adjustment in GWAS can increase precision and correct for confounding. In the linear model setting, adjustment for a covariate will improve precision (i.e., statistical power) if the distribution of the phenotype differs across levels of the covariate. For example, when performing GWAS on height, males and females have different means. All state-of-the-art methods (e.g., BOLT-LMM, regenie) perform GWAS assuming that the effect of genotypes and covariates on phenotype is linear and additive. However, we know that the assumption of linear and additive contributions of covariates often does not reflect underlying biology, so we sought a method to more comprehensively model and adjust for the interactions between phenotypes for GWAS.

DeepNull method overview

We proposed a new method, DeepNull, to relax the linear assumption of covariate effects on phenotypes. DeepNull trains a deep neural network (DNN) to predict phenotype using all covariates in a 5-fold cross-validation. After training the DeepNull model, we make phenotype predictions for all individuals and add this prediction as one additional covariate in the association test. Major advantages of DeepNull are its simplicity to use and that it requires only a minimal change to existing GWAS pipeline implementations. In other words, to use DeepNull, we just need to add one additional covariate, which is computed by DeepNull, to the existing pipeline to perform GWAS.
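
The cross-validated prediction step can be sketched as follows. To keep the example dependency-free, a k-nearest-neighbor average stands in for the deep neural network; this swap, the fold assignment by index, the single toy covariate, and all names are assumptions for illustration and do not reflect the released DeepNull code.

```python
def knn_predict(train, x, k=3):
    """Stand-in non-linear model: mean phenotype of the k nearest covariates."""
    nearest = sorted(train, key=lambda pair: abs(pair[0] - x))[:k]
    return sum(y for _, y in nearest) / len(nearest)


def out_of_fold_predictions(covariates, phenotypes, num_folds=5, k=3):
    """Predict each individual's phenotype from a model fit on the other folds."""
    n = len(covariates)
    predictions = [0.0] * n
    for fold in range(num_folds):
        held_out = [i for i in range(n) if i % num_folds == fold]
        train = [(covariates[i], phenotypes[i])
                 for i in range(n) if i % num_folds != fold]
        for i in held_out:
            predictions[i] = knn_predict(train, covariates[i], k)
    return predictions


ages = [float(a) for a in range(30, 70)]             # toy covariate
phenotype = [(a - 50.0) ** 2 / 100.0 for a in ages]  # non-linear in age
predicted = out_of_fold_predictions(ages, phenotype)

# Rows for the association test: the original covariate plus the prediction.
augmented_covariates = [[a, p] for a, p in zip(ages, predicted)]
```

Each individual's prediction comes from a model that never saw that individual, and the existing association pipeline only has to accept one extra covariate column.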

DeepNull improves statistical power

We simulated data under different genetic architectures (genetic conditions) to first check that DeepNull controls type I error and then compare DeepNull statistical power with current state of the art methods (hereafter referred to as “Baseline”). First, we simulated data under genetic architectures where covariates have a linear effect on phenotype and observed that both Baseline and DeepNull have tight control of type I error. It is interesting that DeepNull power does not decrease compared to Baseline under a setting in which covariates have only a linear effect on phenotype. Next, we simulated data under genetic architectures where covariates have non-linear effects on phenotype. Both Baseline and DeepNull have tight control of type I error while DeepNull increases the statistical power depending on the genetic architecture. We observed that for certain genetic architectures, DeepNull increases the statistical power up to 20%. Below, we compare the -log p-value of test statistics computed from DeepNull versus Baseline for Apolipoprotein B (ApoB) levels obtained from UK Biobank:
Figure 1. Significance level comparison of DeepNull vs Baseline. X-axis is the -log p-value of Baseline and Y-axis is the -log p-value of DeepNull. The orange dots indicate variants that are significant for Baseline but not significant for DeepNull and green dots indicate variants that are significant for DeepNull but not significant for Baseline.

DeepNull improves phenotype prediction

We applied DeepNull to predict phenotypes by utilizing polygenic risk score (PRS) and existing covariates such as age and sex. We considered 10 phenotypes obtained from UK Biobank. We observed that DeepNull on average increased the phenotype prediction (R², where R is the Pearson correlation) by 23%. More strikingly, in the case of glaucoma referral probability, which is computed from fundus images (Phene et al. Ophthalmology 2019, Alipanahi et al. AJHG 2021), DeepNull improves the phenotype prediction by 83.4%, and in the case of LDL, DeepNull improves the phenotype prediction by 40.3%. A summary of DeepNull results versus Baseline is shown in Figure 2 below:

 
 

Figure 2. DeepNull improves phenotype prediction compared to Baseline. The Y-axis is the R², where R is the Pearson correlation between true and predicted value of phenotypes. Phenotypic abbreviations: alkaline phosphatase (ALP), alanine aminotransferase (ALT), aspartate aminotransferase (AST), apolipoprotein B (ApoB), glaucoma referral probability (GRP), LDL cholesterol (LDL), sex hormone-binding globulin (SHBG), and triglycerides (TG).
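
For readers who want to reproduce the metric, the sketch below computes R² as the squared Pearson correlation between true and predicted values, together with a relative improvement between two R² values. Reading the reported percentages as relative improvements is an assumption, and the numbers used are hypothetical.

```python
import math


def pearson_r(xs, ys):
    """Pearson correlation between two equal-length sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


def relative_improvement(r2_baseline, r2_new):
    """Relative gain in R^2, expressed as a fraction of the baseline value."""
    return (r2_new - r2_baseline) / r2_baseline


true_values = [1.0, 2.0, 3.0, 4.0]
predicted_values = [2.0, 4.0, 6.0, 8.0]   # perfectly correlated toy predictions
r_squared = pearson_r(true_values, predicted_values) ** 2
gain = relative_improvement(0.10, 0.123)  # hypothetical R^2 values
```

Note that the squared Pearson correlation is insensitive to the scale of the predictions, so the toy predictions above score perfectly even though they are twice the true values.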

Conclusion

We proposed a new framework, DeepNull, that can model the nonlinear effect of covariates on phenotypes when such nonlinearity exists. We show that DeepNull can substantially improve phenotype prediction. In addition, we show that DeepNull achieves results similar to a standard GWAS when the effect of covariate on the phenotype is linear and can significantly outperform a standard GWAS when the covariate effects are nonlinear. DeepNull is open source and is available for download from GitHub or installation via PyPI.

By Farhad Hormozdiari and Andrew Carroll – Genomics team in HealthAI

Acknowledgments

This blog summarizes the work of the following Google contributors, who we would like to thank: Zachary R. McCaw, Thomas Colthurst, Ted Yun, Nick Furlotte, Babak Alipanahi, and Cory Y. McLean. In addition, we would like to thank Alkes Price, Babak Behsaz, and Justin Cosentino for their invaluable comments and suggestions.