Learning the importance of training data under concept drift

The constantly changing nature of the world around us poses a significant challenge for the development of AI models. Often, models are trained on longitudinal data with the hope that the training data used will accurately represent inputs the model may receive in the future. More generally, the default assumption that all training data are equally relevant often breaks in practice. For example, the figure below shows images from the CLEAR nonstationary learning benchmark, and it illustrates how visual features of objects evolve significantly over a 10 year span (a phenomenon we refer to as slow concept drift), posing a challenge for object categorization models.

Sample images from the CLEAR benchmark. (Adapted from Lin et al.)

Alternative approaches, such as online and continual learning, repeatedly update a model with small amounts of recent data in order to keep it current. This implicitly prioritizes recent data, as the learnings from past data are gradually erased by subsequent updates. However, in the real world, different kinds of information lose relevance at different rates, so these approaches have two key issues: 1) By design, they focus exclusively on the most recent data and lose any signal from older data that is erased. 2) Contributions from data instances decay uniformly over time, irrespective of the contents of the data.

In our recent work, “Instance-Conditional Timescales of Decay for Non-Stationary Learning”, we propose to assign each instance an importance score during training in order to maximize model performance on future data. To accomplish this, we employ an auxiliary model that produces these scores using the training instance as well as its age. This model is jointly learned with the primary model. We address both the above challenges and achieve significant gains over other robust learning methods on a range of benchmark datasets for nonstationary learning. For instance, on a recent large-scale benchmark for nonstationary learning (~39M photos over a 10 year period), we show up to 15% relative accuracy gains through learned reweighting of training data.

The challenge of concept drift for supervised learning

To gain quantitative insight into slow concept drift, we built classifiers on a recent photo categorization task, comprising roughly 39M photographs sourced from social media websites over a 10 year period. We compared offline training, which iterated over all the training data multiple times in random order, and continual training, which iterated multiple times over each month of data in sequential (temporal) order. We measured model accuracy both during the training period and during a subsequent period where both models were frozen, i.e., not updated further on new data (shown below). At the end of the training period (left panel, x-axis = 0), both approaches have seen the same amount of data, but show a large performance gap. This is due to catastrophic forgetting, a problem in continual learning where a model’s knowledge of data from early on in the training sequence is diminished in an uncontrolled manner. On the other hand, forgetting has its advantages: over the test period (shown on the right), the continually trained model degrades much less rapidly than the offline model because it is less dependent on older data. The decay of both models’ accuracy in the test period confirms that the data is indeed evolving over time, and that both models become increasingly less relevant.

Comparing offline and continually trained models on the photo classification task.

Time-sensitive reweighting of training data

We design a method combining the benefits of offline learning (the flexibility of effectively reusing all available data) and continual learning (the ability to downplay older data) to address slow concept drift. We build upon offline learning, then add careful control over the influence of past data and an optimization objective, both designed to reduce model decay in the future.

Suppose we wish to train a model, M, given some training data collected over time. We propose to also train a helper model that assigns a weight to each point based on its contents and age. This weight scales the contribution from that data point in the training objective for M. The objective of the weights is to improve the performance of M on future data.
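As a rough illustration of how a per-instance weight scales the training objective for M, here is a minimal sketch in plain Python. The function names and the normalization by the weight sum are assumptions made for this example, not details taken from the paper.

```python
import math

def cross_entropy(probs, label):
    """Negative log-likelihood of the true label under the predicted probabilities."""
    return -math.log(probs[label])

def weighted_objective(batch_probs, labels, weights):
    """Weighted average of per-instance losses.

    Each weight scales the contribution of one training point. Dividing by the
    weight sum is an assumption here; it keeps the loss scale comparable
    across batches.
    """
    total = sum(w * cross_entropy(p, y)
                for p, y, w in zip(batch_probs, labels, weights))
    return total / sum(weights)

# Hypothetical batch: two instances, two classes. The second instance is
# poorly predicted; down-weighting it reduces its pull on the objective.
batch_probs = [[0.9, 0.1], [0.2, 0.8]]
labels = [0, 0]
uniform = weighted_objective(batch_probs, labels, [1.0, 1.0])
downweighted = weighted_objective(batch_probs, labels, [1.0, 0.1])
```

In the method described here, the weights would come from the helper model rather than being set by hand.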

In our work, we describe how the helper model can be meta-learned, i.e., learned alongside M in a manner that helps the learning of the model M itself. A key design choice of the helper model is that we separated out instance- and age-related contributions in a factored manner. Specifically, we set the weight by combining contributions from multiple different fixed timescales of decay, and learn an approximate “assignment” of a given instance to its most suited timescales. We find in our experiments that this form of the helper model outperforms many other alternatives we considered, ranging from unconstrained joint functions to a single timescale of decay (exponential or linear), due to its combination of simplicity and expressivity. Full details may be found in the paper.
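The factored form described above can be sketched as follows. This is only an illustration under stated assumptions: the exponential decay shape, the three timescale values, and the softmax assignment are choices made for this example, and the assignment logits stand in for the output of a learned network that looks at the instance.

```python
import math

# Hypothetical fixed timescales of decay (for example, in months).
TIMESCALES = [1.0, 6.0, 36.0]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def instance_weight(assignment_logits, age, timescales=TIMESCALES):
    """Weight = sum_k a_k(x) * exp(-age / tau_k).

    The instance only influences the soft assignment a_k(x); the age only
    enters through the fixed decay curves, so the two contributions are
    factored.
    """
    assignment = softmax(assignment_logits)
    return sum(a * math.exp(-age / tau)
               for a, tau in zip(assignment, timescales))

# Two instances of the same age, assigned to different timescales.
w_fast = instance_weight([4.0, 0.0, 0.0], age=12.0)  # mostly fast decay
w_slow = instance_weight([0.0, 0.0, 4.0], age=12.0)  # mostly slow decay
```

With this form, two instances of the same age can receive very different weights depending on which timescale they are assigned to.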

Instance weight scoring

The top figure below shows that our learned helper model indeed up-weights more modern-looking objects in the CLEAR object recognition challenge; older-looking objects are correspondingly down-weighted. On closer examination (bottom figure below, gradient-based feature importance assessment), we see that the helper model focuses on the primary object within the image, as opposed to, e.g., background features that may spuriously be correlated with instance age.

Sample images from the CLEAR benchmark (camera & computer categories) assigned the highest and lowest weights respectively by our helper model.

Feature importance analysis of our helper model on sample images from the CLEAR benchmark.


Gains on large-scale data

We first study the large-scale photo categorization task (PCAT) on the YFCC100M dataset discussed earlier, using the first five years of data for training and the next five years as test data. Our method (shown in red below) improves substantially over the no-reweighting baseline (black) as well as many other robust learning techniques. Interestingly, our method deliberately trades off accuracy on the distant past (training data unlikely to reoccur in the future) in exchange for marked improvements in the test period. Also, as desired, our method degrades less than other baselines in the test period.

Comparison of our method and relevant baselines on the PCAT dataset.

Broad applicability

We validated our findings on a wide range of nonstationary learning challenge datasets sourced from the academic literature (see 1, 2, 3, 4 for details) that span data sources and modalities (photos, satellite images, social media text, medical records, sensor readings, tabular data) and sizes (ranging from 10k to 39M instances). We report significant gains in the test period when compared to the nearest published benchmark method for each dataset (shown below). Note that the previous best-known method may be different for each dataset. These results showcase the broad applicability of our approach.

Performance gain of our method on a variety of tasks studying natural concept drift. Our reported gains are over the previous best-known method for each dataset.

Extensions to continual learning

Finally, we consider an interesting extension of our work. The work above described how offline learning can be extended to handle concept drift using ideas inspired by continual learning. However, sometimes offline learning is infeasible — for example, if the amount of training data available is too large to maintain or process. We adapted our approach to continual learning in a straightforward manner by applying temporal reweighting within the context of each bucket of data being used to sequentially update the model. This proposal still retains some limitations of continual learning, e.g., model updates are performed only on most-recent data, and all optimization decisions (including our reweighting) are only made over that data. Nevertheless, our approach consistently beats regular continual learning as well as a wide range of other continual learning algorithms on the photo categorization benchmark (see below). Since our approach is complementary to the ideas in many baselines compared here, we anticipate even larger gains when combined with them.

Results of our method adapted to continual learning, compared to the latest baselines.


We addressed the challenge of data drift in learning by combining the strengths of previous approaches — offline learning with its effective reuse of data, and continual learning with its emphasis on more recent data. We hope that our work helps improve model robustness to concept drift in practice, and generates increased interest and new ideas in addressing the ubiquitous problem of slow concept drift.


We thank Mike Mozer for many interesting discussions in the early phase of this work, as well as very helpful advice and feedback during its development.

Intervening on early readouts for mitigating spurious features and simplicity bias

Machine learning models in the real world are often trained on limited data that may contain unintended statistical biases. For example, in the CelebA celebrity image dataset, a disproportionate number of female celebrities have blond hair, leading to classifiers incorrectly predicting “blond” as the hair color for most female faces. Here, gender is a spurious feature for predicting hair color. Such unfair biases could have significant consequences in critical applications such as medical diagnosis.

Surprisingly, recent work has also discovered an inherent tendency of deep networks to amplify such statistical biases, through the so-called simplicity bias of deep learning. This bias is the tendency of deep networks to identify weakly predictive features early in the training, and continue to anchor on these features, failing to identify more complex and potentially more accurate features.

With the above in mind, we propose simple and effective fixes to this dual challenge of spurious features and simplicity bias by applying early readouts and feature forgetting. First, in “Using Early Readouts to Mediate Featural Bias in Distillation”, we show that making predictions from early layers of a deep network (referred to as “early readouts”) can automatically signal issues with the quality of the learned representations. In particular, these predictions are more often wrong, and more confidently wrong, when the network is relying on spurious features. We use this erroneous confidence to improve outcomes in model distillation, a setting where a larger “teacher” model guides the training of a smaller “student” model. Then in “Overcoming Simplicity Bias in Deep Networks using a Feature Sieve”, we intervene directly on these indicator signals by making the network “forget” the problematic features and consequently look for better, more predictive features. This substantially improves the model’s ability to generalize to unseen domains compared to previous approaches. Our AI Principles and our Responsible AI practices guide how we research and develop these advanced applications and help us address the challenges posed by statistical biases.

Animation comparing hypothetical responses from two models trained with and without the feature sieve.

Early readouts for debiasing distillation

We first illustrate the diagnostic value of early readouts and their application in debiased distillation, i.e., making sure that the student model inherits the teacher model’s resilience to feature bias through distillation. We start with a standard distillation framework where the student is trained with a mixture of label matching (minimizing the cross-entropy loss between student outputs and the ground-truth labels) and teacher matching (minimizing the KL divergence loss between student and teacher outputs for any given input).
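A minimal sketch of this standard distillation objective is shown below. The mixing coefficient `lam`, the direction of the KL term (teacher relative to student), and the absence of a softmax temperature are assumptions made for the example.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(probs, label):
    """Label matching: negative log-probability of the ground-truth label."""
    return -math.log(probs[label])

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same labels."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(student_logits, teacher_logits, label, lam=0.5):
    """Mixture of label matching and teacher matching for one instance."""
    student = softmax(student_logits)
    teacher = softmax(teacher_logits)
    label_term = cross_entropy(student, label)
    teacher_term = kl_divergence(teacher, student)
    return (1 - lam) * label_term + lam * teacher_term

loss = distillation_loss([2.0, 0.5, -1.0], [1.5, 1.0, -2.0], label=0)
```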

Suppose one trains a linear decoder, i.e., a small auxiliary neural network named Aux, on top of an intermediate representation of the student model. We refer to the output of this linear decoder as an early readout of the network representation. Our finding is that early readouts make more errors on instances that contain spurious features, and further, the confidence on those errors is higher than the confidence associated with other errors. This suggests that confidence on errors from early readouts is a fairly strong, automated indicator of the model’s dependence on potentially spurious features.

Illustrating the usage of early readouts (i.e., output from the auxiliary layer) in debiasing distillation. Instances that are confidently mispredicted in the early readouts are upweighted in the distillation loss.

We used this signal to modulate the contribution of the teacher in the distillation loss on a per-instance basis, and found significant improvements in the trained student model as a result.
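One way to picture this per-instance modulation is sketched below. The exponential form, the `base` coefficient, and the `beta` sharpness parameter are hypothetical choices for illustration; the exact rule used in the paper may differ.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def readout_signal(readout_logits, label):
    """Confidence of the early readout when it is wrong, 0 when it is right."""
    probs = softmax(readout_logits)
    predicted = max(range(len(probs)), key=probs.__getitem__)
    if predicted == label:
        return 0.0
    return probs[predicted]

def teacher_weight(readout_logits, label, base=0.5, beta=1.0):
    """Hypothetical rule: raise the teacher-matching coefficient for
    instances that the early readout mispredicts with high confidence."""
    return base * math.exp(beta * readout_signal(readout_logits, label))

correct = teacher_weight([0.0, 3.0], label=1)          # readout is right
mildly_wrong = teacher_weight([0.2, 0.0], label=1)     # wrong, low confidence
confidently_wrong = teacher_weight([3.0, 0.0], label=1)
```

The resulting value would replace a fixed teacher-matching coefficient in the distillation loss, one value per training instance.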

We evaluated our approach on standard benchmark datasets known to contain spurious correlations (Waterbirds, CelebA, CivilComments, MNLI). Each of these datasets contains groupings of data that share an attribute potentially correlated with the label in a spurious manner. As an example, the CelebA dataset mentioned above includes groups such as {blond male, blond female, non-blond male, non-blond female}, with models typically performing the worst on the {non-blond female} group when predicting hair color. Thus, a measure of model performance is its worst group accuracy, i.e., the lowest accuracy among all known groups present in the dataset. We improved the worst group accuracy of student models on all datasets; moreover, we also improved overall accuracy in three of the four datasets, showing that our improvement on any one group does not come at the expense of accuracy on other groups. More details are available in our paper.
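Worst group accuracy is simple to compute once each example carries a group label. The sketch below uses a small made-up set of predictions purely to show the calculation; the data are not from any of the benchmarks.

```python
from collections import defaultdict

def worst_group_accuracy(predictions, labels, groups):
    """Return (worst_group, its_accuracy, per_group_accuracy)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for pred, label, group in zip(predictions, labels, groups):
        total[group] += 1
        correct[group] += int(pred == label)
    per_group = {g: correct[g] / total[g] for g in total}
    worst = min(per_group, key=per_group.get)
    return worst, per_group[worst], per_group

# Toy hair-color predictions (1 = blond, 0 = non-blond), two examples per group.
groups = ["blond male", "blond male", "blond female", "blond female",
          "non-blond male", "non-blond male",
          "non-blond female", "non-blond female"]
labels = [1, 1, 1, 1, 0, 0, 0, 0]
predictions = [1, 1, 1, 1, 0, 0, 1, 0]

worst, worst_acc, per_group = worst_group_accuracy(predictions, labels, groups)
overall_acc = sum(p == y for p, y in zip(predictions, labels)) / len(labels)
```

In this toy case the overall accuracy is 0.875 while the worst group sits at 0.5, which is why the two numbers are reported separately.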

Comparison of Worst Group Accuracies of different distillation techniques relative to that of the Teacher model. Our method outperforms other methods on all datasets.

Overcoming simplicity bias with a feature sieve

In a second, closely related project, we intervene directly on the information provided by early readouts, to improve feature learning and generalization. The workflow alternates between identifying problematic features and erasing identified features from the network. Our primary hypothesis is that early features are more prone to simplicity bias, and that by erasing (“sieving”) these features, we allow richer feature representations to be learned.

Training workflow with feature sieve. We alternate between identifying problematic features (using training iteration) and erasing them from the network (using forgetting iteration).

We describe the identification and erasure steps in more detail:

  • Identifying simple features: We train the primary model and the readout model (AUX above) in conventional fashion via forward- and back-propagation. Note that feedback from the auxiliary layer does not back-propagate to the main network. This is to force the auxiliary layer to learn from already-available features rather than create or reinforce them in the main network.
  • Applying the feature sieve: We aim to erase the identified features in the early layers of the neural network with the use of a novel forgetting loss, L_f, which is simply the cross-entropy between the readout and a uniform distribution over labels. Essentially, all information that leads to nontrivial readouts is erased from the primary network. In this step, the auxiliary network and upper layers of the main network are kept unchanged.
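The forgetting loss in the second step can be written down in a few lines. The sketch below assumes the uniform distribution is the target and the readout is the prediction, so the loss is -(1/K) Σ_k log p_k for K labels; it covers only the loss value, not the gradient step that updates the early layers.

```python
import math

def log_softmax(logits):
    """Numerically stable log-probabilities from readout logits."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]

def forgetting_loss(readout_logits):
    """Cross-entropy between a uniform target and the readout distribution.

    The loss is smallest, at log(K), when the readout is uniform, i.e. when
    the early features carry no information about the label.
    """
    log_probs = log_softmax(readout_logits)
    return -sum(log_probs) / len(log_probs)

uninformative = forgetting_loss([0.0, 0.0, 0.0])  # readout already uniform
informative = forgetting_loss([5.0, 0.0, 0.0])    # readout is confident
```

Minimizing this quantity with respect to the early layers, while holding the auxiliary network fixed, pushes the readout toward a trivial prediction.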

We can control specifically how the feature sieve is applied to a given dataset through a small number of configuration parameters. By changing the position and complexity of the auxiliary network, we control the complexity of the identified and erased features. By modifying the mixing of learning and forgetting steps, we control the degree to which the model is challenged to learn more complex features. These choices, which are dataset-dependent, are made via hyperparameter search to maximize validation accuracy, a standard measure of generalization. Since we include “no-forgetting” (i.e., the baseline model) in the search space, we expect to find settings that are at least as good as the baseline.

Below we show features learned by the baseline model (middle row) and our model (bottom row) on two benchmark datasets: biased activity recognition (BAR) and animal categorization (NICO). Feature importance was estimated using post-hoc gradient-based importance scoring (GRAD-CAM), with the orange-red end of the spectrum indicating high importance, while green-blue indicates low importance. As shown below, our trained models focus on the primary object of interest, whereas the baseline model tends to focus on background features that are simpler and spuriously correlated with the label.

Feature importance scoring using GRAD-CAM on activity recognition (BAR) and animal categorization (NICO) generalization benchmarks. Our approach (last row) focuses on the relevant objects in the image, whereas the baseline (ERM; middle row) relies on background features that are spuriously correlated with the label.

Through this ability to learn better, more generalizable features, we show substantial gains over a range of relevant baselines on real-world spurious feature benchmark datasets: BAR, CelebA Hair, NICO and ImageNet-A, by margins up to 11% (see figure below). More details are available in our paper.

Our feature sieve method improves accuracy by significant margins relative to the nearest baseline for a range of feature generalization benchmark datasets.


We hope that our work on early readouts and their use in feature sieving for generalization will both spur the development of a new class of adversarial feature learning approaches and help improve the generalization capability and robustness of deep learning systems.


The work on applying early readouts to debiasing distillation was conducted in collaboration with our academic partners Durga Sivasubramanian, Anmol Reddy and Prof. Ganesh Ramakrishnan at IIT Bombay. We extend our sincere gratitude to Praneeth Netrapalli and Anshul Nasery for their feedback and recommendations. We are also grateful to Nishant Jain, Shreyas Havaldar, Rachit Bansal, Kartikeya Badola, Amandeep Kaur and the whole cohort of pre-doctoral researchers at Google Research India for taking part in research discussions. Special thanks to Tom Small for creating the animation used in this post.

An ML-based approach to better characterize lung diseases

The combination of the environment an individual experiences and their genetic predispositions determines the majority of their risk for various diseases. Large national efforts, such as the UK Biobank, have created large, public resources to better understand the links between environment, genetics, and disease. This has the potential to help individuals better understand how to stay healthy, clinicians to treat illnesses, and scientists to develop new medicines.

One challenge in this process is how we make sense of the vast amount of clinical measurements — the UK Biobank has many petabytes of imaging, metabolic tests, and medical records spanning 500,000 individuals. To best use this data, we need to be able to represent the information present as succinct, informative labels about meaningful diseases and traits, a process called phenotyping. That is where we can use the ability of ML models to pick up on subtle intricate patterns in large amounts of data.

We’ve previously demonstrated the ability to use ML models to quickly phenotype at scale for retinal diseases. Nonetheless, these models were trained using labels from clinician judgment, and access to clinical-grade labels is a limiting factor due to the time and expense needed to create them.

In “Inference of chronic obstructive pulmonary disease with deep learning on raw spirograms identifies new genetic loci and improves risk models”, published in Nature Genetics, we’re excited to highlight a method for training accurate ML models for genetic discovery of diseases, even when using noisy and unreliable labels. We demonstrate the ability to train ML models that can phenotype directly from raw clinical measurement and unreliable medical record information. This reduced reliance on medical domain experts for labeling greatly expands the range of applications for our technique to a panoply of diseases and has the potential to improve their prevention, diagnosis, and treatment. We showcase this method with ML models that can better characterize lung function and chronic obstructive pulmonary disease (COPD). Additionally, we show the usefulness of these models by demonstrating a better ability to identify genetic variants associated with COPD, improved understanding of the biology behind the disease, and successful prediction of outcomes associated with COPD.

ML for deeper understanding of exhalation

For this demonstration, we focused on COPD, the third leading cause of worldwide death in 2019, in which airway inflammation and impeded airflow can progressively reduce lung function. Lung function for COPD and other diseases is measured by recording an individual’s exhalation volume over time (the record is called a spirogram; see an example below). Although there are guidelines (called GOLD) for determining COPD status from exhalation, these use only a few, specific data points in the curve and apply fixed thresholds to those values. Much of the rich data from these spirograms is discarded in this analysis of lung function.

We reasoned that ML models trained to classify spirograms would be able to use the rich data present more completely and result in more accurate and comprehensive measures of lung function and disease, similar to what we have seen in other classification tasks like mammography or histology. We trained ML models to predict whether an individual has COPD using the full spirograms as inputs.

Spirometry and COPD status overview. Spirograms from lung function test showing a forced expiratory volume-time spirogram (left), a forced expiratory flow-time spirogram (middle), and an interpolated forced expiratory flow-volume spirogram (right). The profiles of individuals with and without COPD differ.

The common method of training models for this problem, supervised learning, requires samples to be associated with labels. Determining those labels can require the effort of very time-constrained experts. For this work, to show that we do not necessarily need medically graded labels, we decided to use a variety of widely available sources of medical record information to create those labels without medical expert review. These labels are noisy and less reliable for two reasons. First, there are gaps in the medical records of individuals because they use multiple health services. Second, COPD is often undiagnosed, meaning many with the disease will not be labeled as having it even if we compile the complete medical records. Nonetheless, we trained a model to predict these noisy labels from the spirogram curves and treat the model predictions as a quantitative COPD liability or risk score.

Noisy COPD status labels were derived using various medical record sources (clinical data). A COPD liability model is then trained to predict COPD status from raw flow-volume spirograms.

Predicting COPD outcomes

We then investigated whether the risk scores produced by our model could better predict a variety of binary COPD outcomes (for example, an individual’s COPD status, whether they were hospitalized for COPD or died from it). For comparison, we benchmarked the model relative to expert-defined measurements required to diagnose COPD, specifically FEV1/FVC, which compares specific points on the spirogram curve with a simple mathematical ratio. We observed an improvement in the ability to predict these outcomes as seen in the precision-recall curves below.
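For readers unfamiliar with the baseline measure, the sketch below computes FEV1/FVC from a volume-time spirogram: FEV1 is the volume exhaled in the first second and FVC is the total volume exhaled. The sample curves are invented for illustration, and the 0.70 cutoff is the commonly cited fixed threshold for airflow obstruction, included here as an assumption about how such a ratio is typically used.

```python
def volume_at(times, volumes, t):
    """Linearly interpolate exhaled volume (liters) at time t (seconds)."""
    for i in range(len(times) - 1):
        t0, t1 = times[i], times[i + 1]
        if t0 <= t <= t1:
            frac = (t - t0) / (t1 - t0)
            return volumes[i] + frac * (volumes[i + 1] - volumes[i])
    raise ValueError("t is outside the recorded time range")

def fev1_fvc_ratio(times, volumes):
    """Ratio of volume exhaled in the first second to total exhaled volume."""
    fev1 = volume_at(times, volumes, 1.0)
    fvc = max(volumes)
    return fev1 / fvc

# Invented volume-time curves sampled at a handful of time points.
times = [0.0, 0.5, 1.0, 2.0, 4.0, 6.0]
typical = [0.0, 1.8, 2.6, 3.1, 3.3, 3.4]
obstructed = [0.0, 0.9, 1.5, 2.2, 2.9, 3.2]

typical_ratio = fev1_fvc_ratio(times, typical)
obstructed_ratio = fev1_fvc_ratio(times, obstructed)
OBSTRUCTION_CUTOFF = 0.70  # assumed fixed threshold
```

The ratio uses only two points of the curve, which is the information loss that motivates feeding the full spirogram to a model.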

Precision-recall curves for COPD status and outcomes for our ML model (green) compared to traditional measures. Confidence intervals are shown by lighter shading.

We also observed that separating populations by their COPD model score was predictive of all-cause mortality. This plot suggests that individuals with higher COPD risk are more likely to die earlier from any cause and the risk probably has implications beyond just COPD.
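A stratified survival analysis of this kind can be sketched with a quartile split followed by a Kaplan-Meier estimate per stratum. The Kaplan-Meier estimator is a standard choice and is assumed here; the source does not state which estimator produced the plot.

```python
def split_by_quartile(scores):
    """Quartile index per individual (0 = lowest risk, 3 = highest risk)."""
    order = sorted(range(len(scores)), key=scores.__getitem__)
    quartile = [0] * len(scores)
    for rank, idx in enumerate(order):
        quartile[idx] = rank * 4 // len(scores)
    return quartile

def kaplan_meier(durations, events):
    """Survival curve as [(time, survival_probability)] at each death time.

    events[i] is True if a death was observed at durations[i] and False if
    follow-up ended without a death (censoring).
    """
    at_risk = len(durations)
    survival = 1.0
    curve = []
    for t in sorted(set(durations)):
        deaths = sum(1 for d, e in zip(durations, events) if d == t and e)
        leaving = sum(1 for d in durations if d == t)
        if deaths:
            survival *= 1 - deaths / at_risk
            curve.append((t, survival))
        at_risk -= leaving
    return curve

# Small made-up stratum: follow-up in years, with one censored individual.
curve = kaplan_meier([1, 2, 3, 4], [True, False, True, True])
quartiles = split_by_quartile([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
```

Running `kaplan_meier` separately on the individuals in each quartile gives one curve per risk stratum, which can then be compared.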

Survival analysis of a cohort of UK Biobank individuals stratified by their COPD model’s predicted risk quartile. The decrease of the curve indicates individuals in the cohort dying over time. For example, p100 represents the 25% of the cohort with greatest predicted risk, while p50 represents the 2nd quartile.

Identifying the genetic links with COPD

Since the goal of large scale biobanks is to bring together large amounts of both phenotype and genetic data, we also performed a test called a genome-wide association study (GWAS) to identify the genetic links with COPD and genetic predisposition. A GWAS measures the strength of the statistical association between a given genetic variant — a change in a specific position of DNA — and the observations (e.g., COPD) across a cohort of cases and controls. Genetic associations discovered in this manner can inform drug development that modifies the activity or products of a gene, as well as expand our understanding of the biology for a disease.
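To make the idea of a single-variant association test concrete, the sketch below regresses a quantitative phenotype (for example, a model risk score) on the number of copies of a variant each individual carries. It is deliberately simplified: real GWAS pipelines adjust for covariates such as age, sex and ancestry and repeat the test across millions of variants, none of which is shown here, and the data are invented.

```python
import math

def variant_association(dosages, phenotype):
    """Simple linear regression of phenotype on allele dosage (0, 1 or 2).

    Returns (effect_size, t_statistic). The effect size is the estimated
    change in phenotype per additional copy of the variant.
    """
    n = len(dosages)
    g_mean = sum(dosages) / n
    y_mean = sum(phenotype) / n
    sxx = sum((g - g_mean) ** 2 for g in dosages)
    sxy = sum((g - g_mean) * (y - y_mean)
              for g, y in zip(dosages, phenotype))
    beta = sxy / sxx
    intercept = y_mean - beta * g_mean
    rss = sum((y - intercept - beta * g) ** 2
              for g, y in zip(dosages, phenotype))
    std_err = math.sqrt(rss / (n - 2) / sxx)
    return beta, beta / std_err

# Invented cohort of six individuals.
dosages = [0, 0, 1, 1, 2, 2]
risk_scores = [0.1, 0.3, 0.5, 0.7, 0.9, 1.1]
effect_size, t_stat = variant_association(dosages, risk_scores)
```

A large test statistic for a variant indicates a strong statistical association with the phenotype, and the effect sizes are the quantities compared between the ML-based and manual studies.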

We showed with our ML-phenotyping method that not only do we rediscover almost all known COPD variants found by manual phenotyping, but we also find many novel genetic variants significantly associated with COPD. In addition, we see good agreement on the effect sizes for the variants discovered by both our ML approach and the manual one (R² = 0.93), which provides strong evidence for validity of the newly found variants.

Left: A plot comparing the statistical power of genetic discovery using the labels for our ML model (y-axis) with the statistical power of the manual labels from a traditional study (x-axis). A value above the y = x line indicates greater statistical power in our method. Green points indicate significant findings in our method that are not found using the traditional approach. Orange points are significant in the traditional approach but not ours. Blue points are significant in both. Right: Estimates of the association effect between our method (y-axis) and traditional method (x-axis). Note that the relative values between studies are comparable but the absolute numbers are not.

Finally, our collaborators at Harvard Medical School and Brigham and Women’s Hospital further examined the plausibility of these findings by providing insights into the possible biological role of the novel variants in development and progression of COPD (you can see more discussion on these insights in the paper).


We demonstrated that our earlier methods for phenotyping with ML can be expanded to a wide range of diseases and can provide novel and valuable insights. We made two key observations by using this to predict COPD from spirograms and discovering new genetic insights. First, domain knowledge was not necessary to make predictions from raw medical data. Interestingly, we showed the raw medical data is probably underutilized and the ML model can find patterns in it that are not captured by expert-defined measurements. Second, we do not need medically graded labels; instead, noisy labels defined from widely available medical records can be used to generate clinically predictive and genetically informative risk scores. We hope that this work will broadly expand the ability of the field to use noisy labels and will improve our collective understanding of lung function and disease.


This work is the combined output of multiple contributors and institutions. We thank all contributors: Justin Cosentino, Babak Alipanahi, Zachary R. McCaw, Cory Y. McLean, Farhad Hormozdiari (Google), Davin Hill (Northeastern University), Tae-Hwi Schwantes-An and Dongbing Lai (Indiana University), Brian D. Hobbs and Michael H. Cho (Brigham and Women’s Hospital, and Harvard Medical School). We also thank Ted Yun and Nick Furlotte for reviewing the manuscript, Greg Corrado and Shravya Shetty for support, and Howard Yang, Kavita Kulkarni, and Tammi Huynh for helping with publication logistics.

UniPi: Learning universal policies via text-guided video generation

Building models that solve a diverse set of tasks has become a dominant paradigm in the domains of vision and language. In natural language processing, large pre-trained models, such as PaLM, GPT-3 and Gopher, have demonstrated remarkable zero-shot learning of new language tasks. Similarly, in computer vision, models like CLIP and Flamingo have shown robust performance on zero-shot classification and object recognition. A natural next step is to use such tools to construct agents that can complete different decision-making tasks across many environments.

However, training such agents faces the inherent challenge of environmental diversity, since different environments operate with distinct state-action spaces (e.g., the joint space and continuous controls in MuJoCo are fundamentally different from the image space and discrete actions in Atari). This environmental diversity hampers knowledge sharing, learning, and generalization across tasks and environments. Furthermore, it is difficult to construct reward functions across environments, as different tasks generally have different notions of success.

In “Learning Universal Policies via Text-Guided Video Generation”, we propose a Universal Policy (UniPi) that addresses environmental diversity and reward specification challenges. UniPi leverages text for expressing task descriptions and video (i.e., image sequences) as a universal interface for conveying action and observation behavior in different environments. Given an input image frame paired with text describing a current goal (i.e., the next high-level step), UniPi uses a novel video generator (trajectory planner) to generate video with snippets of what an agent’s trajectory should look like to achieve that goal. The generated video is fed into an inverse dynamics model that extracts underlying low-level control actions, which are then executed in simulation or by a real robot agent. We demonstrate that UniPi enables the use of language and video as a universal control interface for generalizing to novel goals and tasks across diverse environments.

Video policies generated by UniPi.
UniPi may be applied to downstream multi-task settings that require combinatorial language generalization, long-horizon planning, or internet-scale knowledge. In the bottom example, UniPi takes the image of the white robot arm from the internet and generates video snippets according to the text description of the goal.

UniPi implementation

To generate a valid and executable plan, a text-to-video model must synthesize a constrained video plan starting at the current observed image. We found it more effective to explicitly constrain a video synthesis model during training (as opposed to only constraining videos at sampling time) by providing the first frame of each video as explicit conditioning context.

At a high level, UniPi has four major components: 1) consistent video generation with first-frame tiling, 2) hierarchical planning through temporal super resolution, 3) flexible behavior synthesis, and 4) task-specific action adaptation. We explain the implementation and benefit of each component in detail below.

Video generation through tiling

Existing text-to-video models like Imagen typically generate videos where the underlying environment state changes significantly throughout the duration. To construct an accurate trajectory planner, it is important that the environment remains consistent across all time points. We enforce environment consistency in conditional video synthesis by providing the observed image as additional context when denoising each frame in the synthesized video. To achieve context conditioning, UniPi directly concatenates each intermediate frame sampled from noise with the conditioned observed image across sampling steps, which serves as a strong signal to maintain the underlying environment state across time.
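The conditioning step can be pictured as tiling the observed image across time and attaching it to every noisy frame before denoising. The sketch below uses nested Python lists in place of tensors and assumes the concatenation happens along the channel axis; the shapes and names are illustrative, not taken from the UniPi implementation.

```python
import random

def tile_first_frame(noisy_frames, observed):
    """Attach the observed image to every frame along the channel axis.

    noisy_frames: T x H x W x C nested lists (frames sampled from noise).
    observed:     H x W x C nested lists (the current observation).
    Returns:      T x H x W x 2C nested lists fed to the denoiser.
    """
    return [
        [
            [noisy_px + obs_px for noisy_px, obs_px in zip(noisy_row, obs_row)]
            for noisy_row, obs_row in zip(frame, observed)
        ]
        for frame in noisy_frames
    ]

rng = random.Random(0)
T, H, W, C = 3, 2, 2, 1
observed = [[[0.5] * C for _ in range(W)] for _ in range(H)]
noisy_frames = [
    [[[rng.gauss(0.0, 1.0) for _ in range(C)] for _ in range(W)]
     for _ in range(H)]
    for _ in range(T)
]
conditioned = tile_first_frame(noisy_frames, observed)
```

Because every frame sees the same observation at every denoising step, the generator has a constant reference for what the environment should look like.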

Text-conditional video generation enables UniPi to train general purpose policies on a wide range of data sources (simulated, real robots and YouTube).

Hierarchical planning

When constructing plans in high-dimensional environments with long time horizons, directly generating a set of actions to reach a goal state quickly becomes intractable due to the exponential growth of the underlying search space as the plan gets longer. Planning methods often circumvent this issue by leveraging a natural hierarchy in planning. Specifically, planning methods first construct coarse plans (the intermediate key frames spread out across time) operating on low-dimensional states and actions, which are then refined into plans in the underlying state and action spaces.

Similar to planning, our conditional video generation procedure exhibits a natural temporal hierarchy. UniPi first generates videos at a coarse level by sparsely sampling videos (“abstractions”) of desired agent behavior along the time axis. UniPi then refines the videos to represent valid behavior in the environment by super-resolving videos across time. Meanwhile, coarse-to-fine super-resolution further improves consistency via interpolation between frames.
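To make the coarse-to-fine idea concrete, here is a minimal sketch in which sparse key frames are densified along the time axis. UniPi uses a learned temporal super-resolution model for this step; the linear interpolation below is a stand-in chosen only to show the data flow, and `refine_keyframes` is a hypothetical name.

```python
def refine_keyframes(keyframes, factor):
    """Insert (factor - 1) in-between frames per pair of consecutive key frames.

    Frames are flat lists of floats. Linear interpolation stands in for
    the learned temporal super-resolution model.
    """
    if factor < 1:
        raise ValueError("factor must be at least 1")
    if not keyframes:
        return []
    fine = []
    for start, end in zip(keyframes, keyframes[1:]):
        for step in range(factor):
            w = step / factor  # 0 at the start key frame, approaching 1 at the end
            fine.append([(1 - w) * a + w * b for a, b in zip(start, end)])
    fine.append(list(keyframes[-1]))  # keep the final key frame
    return fine


# Three coarse key frames become nine frames at 4x temporal resolution.
coarse_plan = [[0.0, 0.0], [1.0, 2.0], [2.0, 0.0]]
fine_plan = refine_keyframes(coarse_plan, factor=4)
```

The key frames survive unchanged in the refined plan, which is the property that lets the coarse plan fix the overall trajectory while the refinement fills in intermediate behavior.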

Given an input observation and text instruction, we plan a set of images representing agent behavior. Images are converted to actions using an inverse dynamics model.

Flexible behavioral modulation

When planning a sequence of actions for a given sub-goal, one can readily incorporate external constraints to modulate a generated plan. UniPi supports this kind of test-time adaptability: a probabilistic prior that encodes properties of the desired plan can be composed with the video model to impose constraints across the synthesized action trajectory. In particular, the prior can be specified using a learned classifier on images to optimize a particular task, or as a Dirac delta distribution on a particular image to guide a plan towards a particular set of states. To train the text-conditioned video generation model, we use the video diffusion algorithm, conditioning on encoded language features from the pre-trained Text-To-Text Transfer Transformer (T5).

Task-specific action adaptation

Given a set of synthesized videos, we train a small task-specific inverse dynamics model to translate frames into a set of low-level control actions. This is independent from the planner and can be done on a separate, smaller and potentially suboptimal dataset generated by a simulator.

Given the input frame and text description of the current goal, UniPi synthesizes image frames, and the inverse dynamics model generates a control action sequence that predicts the corresponding future actions. An agent then executes inferred low-level control actions via closed-loop control.

Capabilities and evaluation of UniPi

We measure the task success rate on novel language-based goals, and find that UniPi generalizes well to both seen and novel combinations of language prompts, compared to baselines such as Transformer BC, Trajectory Transformer (TT), and Diffuser.

UniPi generalizes well to both seen and novel combinations of language prompts in Place (e.g., “place X in Y”) and Relation (e.g., “place X to the left of Y”) tasks.

Below, we illustrate generated videos on unseen combinations of goals. UniPi is able to synthesize a diverse set of behaviors that satisfy unseen language subgoals:

Generated videos for unseen language goals at test time.

Multi-environment transfer

We measure the task success rate of UniPi and baselines on novel tasks not seen during training. UniPi again outperforms the baselines by a large margin:

UniPi generalizes well to new environments when trained on a set of different multi-task environments.

Below, we illustrate generated videos on unseen tasks. UniPi is further able to synthesize a diverse set of behaviors that satisfy unseen language tasks:

Generated video plans on different new test tasks in the multitask setting.

Real world transfer

Below, we further illustrate generated videos given language instructions on unseen real images. Our approach is able to synthesize a diverse set of different behaviors which satisfy language instructions:

Using internet pre-training enables UniPi to synthesize videos of tasks not seen during training. In contrast, a model trained from scratch incorrectly generates plans of different tasks:

To evaluate the quality of videos generated by UniPi when pre-trained on non-robot data, we use the Fréchet Inception Distance (FID) and Fréchet Video Distance (FVD) metrics. We use Contrastive Language-Image Pre-training scores (CLIPScores) to measure the language-image alignment. We demonstrate that pre-trained UniPi achieves significantly lower (better) FID and FVD scores and a higher (better) CLIPScore compared to UniPi without pre-training, suggesting that pre-training on non-robot data helps with generating plans for robots. We report the CLIPScore, FID, and FVD scores for UniPi trained on Bridge data, with and without pre-training:

Model (24x40)      CLIPScore ↑      FID ↓           FVD ↓
No pre-training    24.43 ± 0.04     17.75 ± 0.56    288.02 ± 10.45
Pre-trained        24.54 ± 0.03     14.54 ± 0.57    264.66 ± 13.64

Using existing internet data improves video plan predictions under all metrics considered.
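As a quick reading aid for the table, the relative change of each metric can be computed from the reported means. The snippet below uses only the mean values from the table and ignores the reported error bars; `relative_change` is a helper introduced here for illustration.

```python
def relative_change(before, after):
    """Fractional change from `before` to `after` (negative means a decrease)."""
    return (after - before) / before


# Mean values copied from the table: (no pre-training, pre-trained).
reported = {
    "CLIPScore": (24.43, 24.54),  # higher is better
    "FID": (17.75, 14.54),        # lower is better
    "FVD": (288.02, 264.66),      # lower is better
}

for metric, (no_pretrain, pretrained) in reported.items():
    change = 100 * relative_change(no_pretrain, pretrained)
    print(f"{metric}: {change:+.1f}%")
```

On these means, pre-training lowers FID by roughly 18% and FVD by roughly 8%, while CLIPScore rises by less than 1%.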

The future of large-scale generative models for decision making

The positive results of UniPi point to the broader direction of using generative models and the wealth of data on the internet as powerful tools to learn general-purpose decision making systems. UniPi is only one step towards what generative models can bring to decision making. Other examples include using generative foundation models to provide photorealistic or linguistic simulators of the world in which artificial agents can be trained indefinitely. Generative models as agents can also learn to interact with complex environments such as the internet, so that much broader and more complex tasks can eventually be automated. We look forward to future research in applying internet-scale foundation models to multi-environment and multi-embodiment settings.


We’d like to thank all remaining authors of the paper including Bo Dai, Hanjun Dai, Ofir Nachum, Joshua B. Tenenbaum, Dale Schuurmans, and Pieter Abbeel. We would like to thank George Tucker, Douglas Eck, and Vincent Vanhoucke for the feedback on this post and on the original paper.

Leveraging transfer learning for large scale differentially private image classification

Large deep learning models are becoming the workhorse of a variety of critical machine learning (ML) tasks. However, it has been shown that without any protection it is plausible for bad actors to attack a variety of models, across modalities, to reveal information from individual training examples. As such, it’s essential to protect against this sort of information leakage.

Differential privacy (DP) provides formal protection against an attacker who aims to extract information about the training data. The most popular method for DP training in deep learning is differentially private stochastic gradient descent (DP-SGD). The core recipe implements a common theme in DP: “fuzzing” an algorithm’s outputs with noise to obscure the contributions of any individual input.
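The core DP-SGD recipe can be sketched as follows: bound each example's influence by clipping its gradient, then add Gaussian noise to the sum. This is a simplified illustration with hypothetical names (`dp_sgd_step`, `noise_multiplier`); it does not perform privacy accounting, so it does not by itself tell you which ε a training run achieves.

```python
import math
import random


def clip_to_norm(vector, max_norm):
    """Scale `vector` down so its L2 norm is at most `max_norm`."""
    norm = math.sqrt(sum(v * v for v in vector))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [v * scale for v in vector]


def dp_sgd_step(params, per_example_grads, lr, clip_norm, noise_multiplier, rng):
    """One simplified DP-SGD update on a flat parameter list."""
    summed = [0.0] * len(params)
    for grad in per_example_grads:
        clipped = clip_to_norm(grad, clip_norm)  # bound each example's contribution
        summed = [s + c for s, c in zip(summed, clipped)]
    sigma = noise_multiplier * clip_norm  # noise scale tied to the clipping bound
    noisy = [s + rng.gauss(0.0, sigma) for s in summed]
    batch_size = len(per_example_grads)
    return [p - lr * g / batch_size for p, g in zip(params, noisy)]


rng = random.Random(0)
grads = [[3.0, 4.0], [0.3, 0.4]]  # the first gradient has norm 5 and gets clipped
new_params = dp_sgd_step([0.0, 0.0], grads, lr=1.0, clip_norm=1.0,
                         noise_multiplier=1.0, rng=rng)
```

Clipping is what makes the noise scale meaningful: once no single example can move the sum by more than `clip_norm`, noise proportional to that bound can obscure any one example's contribution.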

In practice, DP training can be very expensive or even ineffective for very large models. Not only does the computational cost typically increase when requiring privacy guarantees, but the noise also increases proportionally. Given these challenges, there has recently been much interest in developing methods that enable efficient DP training. The goal is to develop simple and practical methods for producing high-quality large-scale private models.

The ImageNet classification benchmark is an effective test bed for this goal because 1) it is a challenging task even in the non-private setting, requiring sufficiently large models to successfully classify large numbers of varied images, and 2) it is a public, open-source dataset, which other researchers can access and use for collaboration. With this approach, researchers may simulate a practical situation where a large model is required to train on private data with DP guarantees.

To that end, today we discuss improvements we’ve made in training high-utility, large-scale private models. First, in “Large-Scale Transfer Learning for Differentially Private Image Classification”, we share strong results on the challenging task of image classification on the ImageNet-1k dataset with DP constraints. We show that with a combination of large-scale transfer learning and carefully chosen hyperparameters it is indeed possible to significantly reduce the gap between private and non-private performance even on challenging tasks and high-dimensional models. Then in “Differentially Private Image Classification from Features”, we further show that privately fine-tuning just the last layer of a pre-trained model with more advanced optimization algorithms improves the performance even further, leading to new state-of-the-art DP results across a variety of popular image classification benchmarks, including ImageNet-1k. To encourage further development in this direction and enable other researchers to verify our findings, we are also releasing the associated source code.

Transfer learning and differential privacy

The main idea behind transfer learning is to reuse the knowledge gained from solving one problem and then apply it to a related problem. This is especially useful when there is limited or low-quality data available for the target problem as it allows us to leverage the knowledge gained from a larger and more diverse public dataset.

In the context of DP, transfer learning has emerged as a promising technique to improve the accuracy of private models, by leveraging knowledge learned from pre-training tasks. For example, if a model has already been trained on a large public dataset for a similar privacy-sensitive task, it can be fine-tuned on a smaller and more specific dataset for the target DP task. More specifically, one first pre-trains a model on a large dataset with no privacy concerns, and then privately fine-tunes the model on the sensitive dataset. In our work, we improve the effectiveness of DP transfer learning and illustrate it by simulating private training on publicly available datasets, namely ImageNet-1k, CIFAR-100, and CIFAR-10.

Better pre-training improves DP performance

To start exploring how transfer learning can be effective for differentially private image classification tasks, we carefully examined hyperparameters affecting DP performance. Surprisingly, we found that with carefully chosen hyperparameters (e.g., initializing the last layer to zero and choosing large batch sizes), privately fine-tuning just the last layer of a pre-trained model yields significant improvements over the baseline. Training just the last layer also significantly improves the cost-utility ratio of training a high-quality image classification model with DP.

As shown below, we compare the performance on ImageNet of the best hyperparameter recommendations both with and without privacy and across a variety of model and pre-training dataset sizes. We find that scaling the model and using a larger pre-training dataset decreases the gap in accuracy coming from the addition of the privacy guarantee. Typically, privacy guarantees of a system are characterized by a positive parameter ε, with smaller ε corresponding to better privacy. In the following figure, we use the privacy guarantee of ε = 10.

Comparing our best models with and without privacy on ImageNet across model and pre-training dataset sizes. The X-axis shows the different Vision Transformer models we used for this study in ascending order of model size from left to right. We used JFT-300M to pretrain B/16, L/16 and H/14 models, JFT-4B (a larger version of JFT-3B) to pretrain H/14-4b and JFT-3B to pretrain G/14-3b. We do this in order to study the effectiveness of jointly scaling the model and pre-training dataset (JFT-3B or 4B). The Y-axis shows the Top-1 accuracy on ImageNet-1k test set once the model is finetuned (in the private or non-private way) with the ImageNet-1k training set. We consistently see that the scaling of the model and the pre-training dataset size decreases the gap in accuracy coming from the addition of the privacy guarantee of ε = 10.

Better optimizers improve DP performance

Somewhat surprisingly, we found that privately training just the last layer of a pre-trained model provides the best utility with DP. While past studies [1, 2, 3] largely relied on using first-order differentially private training algorithms like DP-SGD for training large models, in the specific case of privately learning just the last layer from features, we observe that computational burden is often low enough to allow for more sophisticated optimization schemes, including second-order methods (e.g., Newton or Quasi-Newton methods), which can be more accurate but also more computationally expensive.

In “Differentially Private Image Classification from Features”, we systematically explore the effect of loss functions and optimization algorithms. We find that while the commonly used logistic regression performs better than linear regression in the non-private setting, the situation is reversed in the private setting: least-squares linear regression is much more effective than logistic regression from both a privacy and computational standpoint for a typical range of ε values ([1, 10]), and even more effective for stricter epsilon values (ε < 1).

We further explore using DP Newton’s method to solve logistic regression. We find that this is still outperformed by DP linear regression in the high privacy regime. Indeed, Newton's method involves computing a Hessian (a matrix that captures second-order information), and making this matrix differentially private requires adding far more noise in logistic regression than in linear regression, which has a highly structured Hessian.

Building on this observation, we introduce a method that we call differentially private SGD with feature covariance (DP-FC), where we simply replace the Hessian in logistic regression with privatized feature covariance. Since feature covariance only depends on the inputs (and neither on model parameters nor class labels), we are able to share it across classes and training iterations, thus greatly reducing the amount of noise that needs to be added to protect it. This allows us to combine the benefits of using logistic regression with the efficient privacy protection of linear regression, leading to improved privacy-utility trade-off.
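The reason feature covariance is cheap to protect is visible in a small sketch: it is computed from inputs alone, so one noisy copy can be reused for every class and every iteration. The code below is an illustration under stated assumptions, not the paper's mechanism: `noisy_feature_covariance` is a hypothetical name, features are clipped to a fixed norm, and `noise_std` is a free parameter whose calibration to a target ε is not shown.

```python
import math
import random


def clip_to_norm(vector, max_norm):
    """Scale `vector` down so its L2 norm is at most `max_norm`."""
    norm = math.sqrt(sum(v * v for v in vector))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [v * scale for v in vector]


def noisy_feature_covariance(features, clip_norm, noise_std, rng):
    """Sum of outer products of clipped features plus symmetric Gaussian noise.

    No labels or model parameters appear here, so the result can be
    computed once and shared across classes and training iterations.
    """
    dim = len(features[0])
    cov = [[0.0] * dim for _ in range(dim)]
    for x in features:
        x = clip_to_norm(x, clip_norm)  # bound each example's contribution
        for i in range(dim):
            for j in range(dim):
                cov[i][j] += x[i] * x[j]
    for i in range(dim):
        for j in range(i, dim):
            noise = rng.gauss(0.0, noise_std)  # one draw per upper-triangle entry
            cov[i][j] += noise
            if i != j:
                cov[j][i] += noise  # mirror it to keep the matrix symmetric
    return cov


features = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
private_cov = noisy_feature_covariance(features, clip_norm=10.0,
                                       noise_std=0.1, rng=random.Random(0))
```

In a DP-FC style update, a matrix like `private_cov` would take the place of the Hessian when preconditioning the logistic regression gradient; that step is omitted here.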

With DP-FC, we surpass previous state-of-the-art results considerably on three private image classification benchmarks, namely ImageNet-1k, CIFAR-10 and CIFAR-100, just by performing DP fine-tuning on features extracted from a powerful pre-trained model.

Comparison of top-1 accuracies (Y-axis) with private fine-tuning using DP-FC method on all three datasets across a range of ε (X-axis). We observe that better pre-training helps even more for lower values of ε (stricter privacy guarantee).


We demonstrate that large-scale pre-training on a public dataset is an effective strategy for obtaining good results when fine-tuned privately. Moreover, scaling both model size and pre-training dataset improves performance of the private model and narrows the quality gap compared to the non-private model. We further provide strategies to effectively use transfer learning for DP. Note that this work has several limitations worth considering — most importantly our approach relies on the availability of a large and trustworthy public dataset, which can be challenging to source and vet. We hope that our work is useful for training large models with meaningful privacy guarantees!


In addition to the authors of this blogpost, this research was conducted by Abhradeep Thakurta, Alex Kurakin and Ashok Cutkosky. We are also grateful to the developers of Jax, Flax, and Scenic libraries. Specifically, we would like to thank Mostafa Dehghani for helping us with Scenic and high-performance vision baselines and Lucas Beyer for help with deduping the JFT data. We are also grateful to Li Zhang, Emil Praun, Andreas Terzis, Shuang Song, Pierre Tholoniat, Roxana Geambasu, and Steve Chien for stimulating discussions on differential privacy throughout the project. Additionally, we thank anonymous reviewers, Gautam Kamath and Varun Kanade for helpful feedback throughout the publication process. Finally, we would like to thank John Anderson and Corinna Cortes from Google Research, Borja Balle, Soham De, Sam Smith, Leonard Berrada, and Jamie Hayes from DeepMind for generous feedback.

Training Generalist Agents with Multi-Game Decision Transformers

Current deep reinforcement learning (RL) methods can train specialist artificial agents that excel at decision-making on various individual tasks in specific environments, such as Go or StarCraft. However, little progress has been made to extend these results to generalist agents that would not only be capable of performing many different tasks, but would also do so across a variety of environments with potentially distinct embodiments.

Looking across recent progress in the fields of natural language processing, vision, and generative models (such as PaLM, Imagen, and Flamingo), we see that breakthroughs in making general-purpose models are often achieved by scaling up Transformer-based models and training them on large and semantically diverse datasets. It is natural to wonder, can a similar strategy be used in building generalist agents for sequential decision making? Can such models also enable fast adaptation to new tasks, similar to PaLM and Flamingo?

As an initial step to answer these questions, in our recent paper “Multi-Game Decision Transformers” we explore how to build a generalist agent to play many video games simultaneously. Our model trains an agent that can play 41 Atari games simultaneously at close-to-human performance and that can also be quickly adapted to new games via fine-tuning. This approach significantly improves upon the few existing alternatives to learning multi-game agents, such as temporal difference (TD) learning or behavioral cloning (BC).

A Multi-Game Decision Transformer (MGDT) can play multiple games at a desired level of competency from training on a range of trajectories spanning all levels of expertise.

Don’t Optimize for Return, Just Ask for Optimality
In reinforcement learning, reward refers to the incentive signals that are relevant to completing a task, and return refers to cumulative rewards in a course of interactions between an agent and its surrounding environment. Traditional deep reinforcement learning agents (DQN, SimPLe, Dreamer, etc) are trained to optimize decisions to achieve the optimal return. At every time step, an agent observes the environment (some also consider the interactions that happened in the past) and decides what action to take to help itself achieve a higher return magnitude in future interactions.
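The distinction between reward and return is easy to see in code. The sketch below computes, for each time step, the return still to be collected from that step onward; treating the return as an undiscounted sum is an assumption of this example, and `returns_to_go` is a name introduced here.

```python
def returns_to_go(rewards):
    """Cumulative future reward at each time step (undiscounted)."""
    total = 0.0
    remaining = []
    for reward in reversed(rewards):
        total += reward
        remaining.append(total)
    remaining.reverse()
    return remaining


# Rewards observed over a short episode.
episode_rewards = [1.0, 0.0, 2.0, 0.0]
episode_returns = returns_to_go(episode_rewards)  # [3.0, 2.0, 2.0, 0.0]
```

The first entry is the return of the whole episode, and each later entry is what an agent could still earn from that point on.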

In this work, we use Decision Transformers as our backbone approach to training an RL agent. A Decision Transformer is a sequence model that predicts future actions by considering past interactions between an agent and the surrounding environment, and (most importantly) a desired return to be achieved in future interactions. Instead of learning a policy to achieve high return magnitude as in traditional reinforcement learning, Decision Transformers map diverse experiences, ranging from expert-level to beginner-level, to their corresponding return magnitude during training. The idea is that training an agent on a range of experiences (from beginner to expert level) exposes the model to a wider range of variations in gameplay, which in turn helps it extract useful rules of gameplay that allow it to succeed under any circumstance. So during inference, the Decision Transformer can achieve any return value in the range it has seen during training, including the optimal return.

But, how do you know if a return is both optimal and stably achievable in a given environment? Previous applications of Decision Transformers relied on customized definitions of the desired return for each individual task, which required manually defining a plausible and informative range of scalar values that are appropriately interpretable signals for each specific game — a task that is non-trivial and rather unscalable. To address this issue, we instead model a distribution of return magnitudes based on past interactions with the environment during training. At inference time, we simply add an optimality bias that increases the probability of generating actions that are associated with higher returns.
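One way to picture the optimality bias is as a tilt applied to the model's predicted return distribution before sampling. The sketch below assumes returns are discretized into bins and that the bias is a term proportional to the normalized return added to the log-probabilities; the name `bias_toward_high_returns` and the strength parameter `kappa` are illustrative choices, not taken from the text above.

```python
import math


def bias_toward_high_returns(log_probs, return_bins, kappa):
    """Re-weight a predicted return distribution toward higher returns.

    Adds kappa * (normalized return) to each log-probability, then
    renormalizes with a numerically stable softmax.
    """
    low, high = min(return_bins), max(return_bins)
    span = (high - low) or 1.0  # avoid dividing by zero for a single bin
    scores = [lp + kappa * (r - low) / span
              for lp, r in zip(log_probs, return_bins)]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


bins = [0.0, 1.0, 2.0, 3.0]
uniform = [math.log(0.25)] * 4
unbiased = bias_toward_high_returns(uniform, bins, kappa=0.0)
biased = bias_toward_high_returns(uniform, bins, kappa=5.0)
```

With `kappa=0` the predicted distribution is unchanged; larger values shift probability mass toward the highest returns the model considers plausible, without requiring a hand-picked target return per game.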

To more comprehensively capture spatial-temporal patterns of agent-environment interactions, we also modified the Decision Transformer architecture to consider image patches instead of a global image representation. Patches allow the model to focus on local dynamics, which helps model game specific information in further detail.

These pieces together give us the backbone of Multi-Game Decision Transformers:

Each observation image is divided into a set of M patches of pixels, which are denoted O. Return R, action a, and reward r follow these image patches in each input causal sequence. A Decision Transformer is trained to predict the next input (except for the image patches) to establish causality.
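Dividing an observation into patches can be sketched as below. This is a toy illustration on a nested-list image; the patch size, the row-major ordering of patches, and the function name `split_into_patches` are assumptions of this example.

```python
def split_into_patches(image, patch_height, patch_width):
    """Split a 2D image (list of rows) into non-overlapping patches, row-major."""
    height, width = len(image), len(image[0])
    if height % patch_height or width % patch_width:
        raise ValueError("image size must be a multiple of the patch size")
    patches = []
    for top in range(0, height, patch_height):
        for left in range(0, width, patch_width):
            patch = [row[left:left + patch_width]
                     for row in image[top:top + patch_height]]
            patches.append(patch)
    return patches


# A 4x4 image with pixel values 0..15 becomes M = 4 patches of size 2x2.
image = [[row * 4 + col for col in range(4)] for row in range(4)]
patches = split_into_patches(image, 2, 2)
```

Each patch then becomes one token in the input sequence, followed by the return, action, and reward tokens for that time step.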

Training a Multi-Game Decision Transformer to Play 41 Games at Once
We train one Decision Transformer agent on a large (~1B) and broad set of gameplay experiences from 41 Atari games. In our experiments, this agent, which we call the Multi-Game Decision Transformer (MGDT), clearly outperforms existing reinforcement learning and behavioral cloning methods — by almost 2 times — on learning to play 41 games simultaneously and performs near human-level competency (100% in the following figure corresponds to the level of human gameplay). These results hold when comparing across training methods in both settings where a policy must be learned from static datasets (offline) as well as those where new data can be gathered from interacting with the environment (online).

Each bar is a combined score across 41 games, where 100% indicates human-level performance. Each blue bar is from a model trained on 41 games simultaneously, whereas each gray bar is from 41 specialist agents. Multi-Game Decision Transformer achieves human-level performance, significantly better than other multi-game agents, even comparable to specialist agents.

This result indicates that Decision Transformers are well-suited for multi-task, multi-environment, and multi-embodiment agents.

A concurrent work, “A Generalist Agent”, shows a similar result, demonstrating that large transformer-based sequence models can memorize expert behaviors very well across many more environments. In addition, their work and our work have nicely complementary findings: They show it’s possible to train across a wide range of environments beyond Atari games, while we show it’s possible and useful to train across a wide range of experiences.

In addition to the performance shown above, empirically we found that MGDT trained on a wide variety of experience is better than MGDT trained only on expert-level demonstrations or simply cloning demonstration behaviors.

Scaling Up Multi-Game Model Size to Achieve Better Performance
Arguably, scale has become the main driving force in many recent machine learning breakthroughs, and it is usually achieved by increasing the number of parameters in a transformer-based model. Our observation on Multi-Game Decision Transformers is similar: the performance increases predictably with larger model size. In particular, its performance appears to have not yet hit a ceiling, and compared to other learning systems performance gains are more significant with increases in model size.

Performance of Multi-Game Decision Transformer (shown by the blue line) increases predictably with larger model size, whereas other models do not.

Pre-trained Multi-Game Decision Transformers Are Fast Learners
Another benefit of MGDTs is that they can learn how to play a new game from very few gameplay demonstrations (which don’t need to all be expert-level). In that sense, MGDTs can be considered pre-trained models capable of being fine-tuned rapidly on small new gameplay data. Compared with other popular pre-training methods, MGDT pre-training shows consistent advantages in obtaining higher scores.

Multi-Game Decision Transformer pre-training (DT pre-training, shown in light blue) demonstrates consistent advantages over other popular models in adaptation to new tasks.

Where Is the Agent Looking?
In addition to the quantitative evaluation, it’s insightful (and fun) to visualize the agent’s behavior. By probing the attention heads, we find that the MGDT model consistently places weight in its field of view to areas of the observed images that contain meaningful game entities. We visualize the model’s attention when predicting the next action for various games and find it consistently attends to entities such as the agent’s on screen avatar, agent’s free movement space, non-agent objects, and key environment features. For example, in an interactive setting, having an accurate world model requires knowing how and when to focus on known objects (e.g., currently present obstacles) as well as expecting and/or planning over future unknowns (e.g., negative space). This diverse allocation of attention to many key components of each environment ultimately improves performance.

Here we can see the amount of weight the model places on each key asset of the game scene. Brighter red indicates more emphasis on that patch of pixels.

The Future of Large-Scale Generalist Agents
This work is an important step in demonstrating the possibility of training general-purpose agents across many environments, embodiments, and behavior styles. We have shown the benefit of increased scale on performance and the potential with further scaling. These findings seem to point to a generalization narrative similar to other domains like vision and language — we look forward to exploring the great potential of scaling data and learning from diverse experiences.

We look forward to future research towards developing performant agents for multi-environment and multi-embodiment settings. Our code and model checkpoints can soon be accessed here.

We’d like to thank all remaining authors of the paper including Igor Mordatch, Ofir Nachum, Menjiao Yang, Lisa Lee, Daniel Freeman, Sergio Guadarrama, Ian Fischer, Eric Jang, and Henryk Michalewski.

Reproducibility in Deep Learning and Smooth Activations

Ever queried a recommender system and found that the same search only a few moments later or on a different device yields very different results? This is not uncommon and can be frustrating if a person is looking for something specific. As a designer of such a system, it is also not uncommon for the metrics measured to change from design and testing to deployment, bringing into question the utility of the experimental testing phase. Some level of such irreproducibility can be expected as the world changes and new models are deployed. However, this also happens regularly as requests hit duplicates of the same model or models are being refreshed.

Lack of replicability, where researchers are unable to reproduce published results with a given model, has been identified as a challenge in the field of machine learning (ML). Irreproducibility is a related but more elusive problem, where multiple instances of a given model are trained on the same data under identical training conditions, but yield different results. Only recently has irreproducibility been identified as a difficult problem, but due to its complexity, theoretical studies to understand this problem are extremely rare.

In practice, deep network models are trained in highly parallelized and distributed environments. Nondeterminism in training from random initialization, parallelism, distributed training, data shuffling, quantization errors, hardware types, and more, combined with objectives with multiple local optima contribute to the problem of irreproducibility. Some of these factors, such as initialization, can be controlled, but it is impractical to control others. Optimization trajectories can diverge early in training by following training examples in the order seen, leading to very different models. Several recently published solutions [1, 2, 3] based on advanced combinations of ensembling, self-ensembling, and distillation can mitigate the problem, but usually at the cost of accuracy and increased complexity, maintenance and improvement costs.

In “Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations”, we consider a different practical solution to this problem that does not incur the costs of other solutions, while still improving reproducibility and yielding higher model accuracy. We discover that the Rectified Linear Unit (ReLU), which is very popular as the nonlinearity function (i.e., activation function) used to transform values in neural networks, exacerbates the irreproducibility problem. On the other hand, we demonstrate that smooth activation functions, which have derivatives that are continuous for the whole domain, unlike those of ReLU, are able to substantially reduce irreproducibility levels. We then propose the Smooth reLU (SmeLU) activation function, which gives comparable reproducibility and accuracy benefits to other smooth activations but is much simpler.

The ReLU function (left) as function of the input signal, and its gradient (right) as function of the input.

Smooth Activations
An ML model attempts to learn the best model parameters that fit the training data by minimizing a loss, which can be imagined as a landscape with peaks and valleys, where the lowest point attains an optimal solution. For deep models, the landscape may consist of many such peaks and valleys. The activation function used by the model governs the shape of this landscape and how the model navigates it.

ReLU, which is not a smooth function, imposes an objective whose landscape is partitioned into many regions with multiple local minima, each providing different model predictions. With this landscape, the order in which updates are applied is a dominant factor in determining the optimization trajectory, providing a recipe for irreproducibility. Because of its non-continuous gradient, functions expressed by a ReLU network will contain sudden jumps in the gradient, which can occur internally in different layers of the deep network, affecting updates of different internal units, and are likely strong contributors to irreproducibility.

Suppose a sequence of model updates attempts to push the activation of some unit down from a positive value. The gradient of the ReLU function is 1 for positive unit values, so with every update it pushes the unit to become smaller and smaller (to the left in the panel above). At the point the activation of this unit crosses the threshold from a positive value to a negative one, the gradient suddenly changes from magnitude 1 to magnitude 0. Training attempts to keep moving the unit leftwards, but due to the 0 gradient, the unit cannot move further in that direction. Therefore, the model must resort to updating other units that can move.
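The scenario above can be traced numerically. In this toy sketch a single unit's value is pushed down by repeated updates scaled by the ReLU gradient; the starting value and step size are arbitrary choices for illustration.

```python
def relu(x):
    """Rectified Linear Unit."""
    return max(0.0, x)


def relu_grad(x):
    """Gradient of ReLU: 1 for positive inputs, 0 otherwise."""
    return 1.0 if x > 0 else 0.0


unit_value = 0.25  # arbitrary positive starting point
step_size = 0.1    # arbitrary update size
trace = []
for _ in range(5):
    gradient = relu_grad(unit_value)
    trace.append((unit_value, gradient))
    # The update tries to push the unit down, scaled by its local gradient.
    unit_value -= step_size * gradient

for value, gradient in trace:
    print(f"value={value:+.2f} gradient={gradient}")
```

The gradient stays at 1 while the unit is positive and drops to 0 in a single step once it crosses zero, after which the unit stops moving.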

We find that networks with smooth activations (e.g., GELU, Swish and Softplus) can be substantially more reproducible. They may exhibit a similar objective landscape, but with fewer regions, giving a model fewer opportunities to diverge. Unlike the sudden jumps with ReLU, for a unit with decreasing activations, the gradient gradually reduces to 0, which gives other units opportunities to adjust to the changing behavior. With equal initialization, moderate shuffling of training examples, and normalization of hidden layer outputs, smooth activations are able to increase the chances of converging to the same minimum. Very aggressive data shuffling, however, loses this advantage.

The rate at which a smooth activation function transitions between output levels, i.e., its “smoothness”, can be adjusted. Sufficient smoothness leads to improved accuracy and reproducibility. Too much smoothness, though, approaches linear models with a corresponding degradation of model accuracy, thus losing the advantages of using a deep network.

Smooth activations (top) and their gradients (bottom) for different smoothness parameter values β as a function of the input values. β determines the width of the transition region between 0 and 1 gradients. For Swish and Softplus, a greater β gives a narrower region; for SmeLU, a greater β gives a wider region.

Smooth reLU (SmeLU)
Activations like GELU and Swish require complex hardware implementations to support exponential and logarithmic functions. Further, GELU must be computed numerically or approximated. These properties can make deployment error-prone, expensive, or slow. GELU and Swish are not monotonic (they start by slightly decreasing and then switch to increasing), which may interfere with interpretability (or identifiability), nor do they have a full stop or a clean slope 1 region, properties that simplify implementation and may aid in reproducibility. 

The Smooth reLU (SmeLU) activation function is designed as a simple function that addresses the concerns with other smooth activations. It connects a 0 slope on the left with a slope 1 line on the right through a quadratic middle region, constraining continuous gradients at the connection points (as an asymmetric version of a Huber loss function).
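The description above pins down a piecewise form: if the transition region is [-β, β], matching values and slopes at both ends gives a middle piece of (x + β)² / (4β). The sketch below implements that form in NumPy. The function names and the default β = 1 are choices made for this illustration, not an official implementation.

```python
import numpy as np

def smelu(x, beta=1.0):
    """SmeLU: 0 for x <= -beta, x for x >= beta, quadratic in between."""
    x = np.asarray(x, dtype=float)
    quadratic = (x + beta) ** 2 / (4.0 * beta)
    return np.where(x <= -beta, 0.0, np.where(x >= beta, x, quadratic))

def smelu_grad(x, beta=1.0):
    """Gradient of SmeLU: rises linearly from 0 at -beta to 1 at +beta."""
    x = np.asarray(x, dtype=float)
    return np.clip((x + beta) / (2.0 * beta), 0.0, 1.0)

xs = np.linspace(-2.0, 2.0, 9)
print(smelu(xs, beta=1.0))
print(smelu_grad(xs, beta=1.0))
```

Only additions, multiplications, and comparisons are needed, with no exponentials or logarithms.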

SmeLU can be viewed as a convolution of ReLU with a box. It provides a cheap and simple smooth solution that is comparable in reproducibility-accuracy tradeoffs to more computationally expensive and complex smooth activations. The figure below illustrates the transition of the loss (objective) surface as we gradually transition from a non-smooth ReLU to a smoother SmeLU. A transition of width 0 is the basic ReLU function for which the loss objective has many local minima. As the transition region widens (SmeLU), the loss surface becomes smoother. If the transition is too wide, i.e., too smooth, the benefit of using a deep network wanes and we approach the linear model solution — the objective surface flattens, potentially losing the ability of the network to express much information.
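The convolution view can be checked numerically. The sketch below averages ReLU over a box of half-width β centered on the input and compares the result with the piecewise-quadratic closed form. Both helper names, the grid size, and the closed form (derived from the slope-matching description, not quoted from the paper) are assumptions of this illustration.

```python
import numpy as np

def box_smoothed_relu(x, beta, num=200001):
    # Average of ReLU(x - t) for t uniformly spaced over the box [-beta, beta].
    t = np.linspace(-beta, beta, num)
    return float(np.mean(np.maximum(x - t, 0.0)))

def smelu_closed_form(x, beta):
    # 0 on the left, identity on the right, quadratic in the transition region.
    if x <= -beta:
        return 0.0
    if x >= beta:
        return float(x)
    return (x + beta) ** 2 / (4.0 * beta)

for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(x, box_smoothed_relu(x, beta=1.0), smelu_closed_form(x, beta=1.0))
```

The two columns agree up to discretization error, and shrinking β toward 0 recovers plain ReLU.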

Loss surfaces (as functions of a 2D input) for two sample loss functions (middle and right) as the activation function’s transition region widens, going from ReLU to an increasingly smoother SmeLU (left). The loss surface becomes smoother as the smoothness of the SmeLU function increases.

SmeLU has benefited multiple systems, specifically recommendation systems, increasing their reproducibility by reducing, for example, recommendation swap rates. While the use of SmeLU results in accuracy improvements over ReLU, it also replaces other costly methods to address irreproducibility, such as ensembles, which mitigate irreproducibility at the cost of accuracy. Moreover, replacing ensembles in sparse recommendation systems reduces the need for multiple lookups of model parameters that are needed to generate an inference for each of the ensemble components. This substantially improves training and inference efficiency.

To illustrate the benefits of smooth activations, we plot the relative prediction difference (PD) as a function of change in some loss for the different activations. We define relative PD as the ratio between the absolute difference in predictions of two models and their expected prediction, averaged over all evaluation examples. We have observed that in large scale systems, it is sufficient, and inexpensive, to consider only two models for very consistent results.
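One way to read this definition is sketched below. It takes the "expected prediction" of each example to be the mean of the two models' predictions and averages the per-example ratio; that normalization, the `eps` guard, and the function name are assumptions made for this illustration, and the paper's exact formula may differ.

```python
import numpy as np

def relative_pd(preds_a, preds_b, eps=1e-12):
    """Relative prediction difference between two models on the same examples."""
    a = np.asarray(preds_a, dtype=float)
    b = np.asarray(preds_b, dtype=float)
    expected = 0.5 * (a + b)  # assumed stand-in for the expected prediction
    return float(np.mean(np.abs(a - b) / (expected + eps)))

# Two retrained copies of a model that disagree slightly on three examples.
model_1 = [0.10, 0.20, 0.05]
model_2 = [0.12, 0.18, 0.05]
print(relative_pd(model_1, model_2))
```

Identical predictions give a relative PD of 0, and larger values mean the two models diverge more.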

The figure below shows curves on the PD-accuracy loss plane. For reproducibility, being lower on the curve is better, and for accuracy, being on the left is better. Smooth activations can yield a ballpark 50% reduction in PD relative to ReLU, while still potentially resulting in improved accuracy. SmeLU yields accuracy comparable to other smooth activations, but is more reproducible (lower PD) while still outperforming ReLU in accuracy.

Relative PD as a function of percentage change in the evaluation ranking loss, which measures how accurately items are ranked in a recommendation system (higher values indicate worse accuracy), for different activations.

Conclusion and Future Work
We demonstrated the problem of irreproducibility in real world practical systems, and how it affects users as well as system and model designers. While this particular issue has been given very little attention when trying to address the lack of replicability of research results, irreproducibility can be a critical problem. We demonstrated that a simple solution of using smooth activations can substantially reduce the problem without degrading other critical metrics like model accuracy. We also introduced a new smooth activation function, SmeLU, which has the added benefits of mathematical simplicity and ease of implementation, and can be cheap and less error-prone.

Understanding reproducibility, especially in deep networks, where objectives are not convex, is an open problem. An initial theoretical framework for the simpler convex case has recently been proposed, but more research must be done to gain a better understanding of this problem which will apply to practical systems that rely on deep networks.

We would like to thank Sergey Ioffe for early discussions about SmeLU; Lorenzo Coviello and Angel Yu for help in early adoptions of SmeLU; Shiv Venkataraman for sponsorship of the work; Claire Cui for discussion and support from the very beginning; Jeremiah Willcock, Tom Jablin, and Cliff Young for substantial implementation support; Yuyan Wang, Mahesh Sathiamoorthy, Myles Sussman, Li Wei, Kevin Regan, Steven Okamoto, Qiqi Yan, Todd Phillips, Ed Chi, Sunita Verna, and many many others for many discussions, and for integrations in many different systems; Matt Streeter and Yonghui Wu for feedback on the paper and this post; Tom Small for help with the illustrations in this post.

Extending Contrastive Learning to the Supervised Setting

In recent years, self-supervised representation learning, which is used in a variety of image and video tasks, has significantly advanced due to the application of contrastive learning. These contrastive learning approaches typically teach a model to pull together the representations of a target image (a.k.a., the “anchor”) and a matching (“positive”) image in embedding space, while also pushing apart the anchor from many non-matching (“negative”) images. Because labels are assumed to be unavailable in self-supervised learning, the positive is often an augmentation of the anchor, and the negatives are chosen to be the other samples from the training minibatch. However, because of this random sampling, false negatives, i.e., negatives generated from samples of the same class as the anchor, can cause a degradation in the representation quality. Furthermore, determining the optimal method to generate positives is still an area of active research.

In contrast to the self-supervised approach, a fully-supervised approach could use labeled data to generate positives from existing same-class examples, providing more variability in pretraining than could typically be achieved by simply augmenting the anchor. However, very little work has been done to successfully apply contrastive learning in the fully-supervised domain.

In “Supervised Contrastive Learning”, presented at NeurIPS 2020, we propose a novel loss function, called SupCon, that bridges the gap between self-supervised learning and fully supervised learning and enables contrastive learning to be applied in the supervised setting. Leveraging labeled data, SupCon encourages normalized embeddings from the same class to be pulled closer together, while embeddings from different classes are pushed apart. This simplifies the process of positive selection, while avoiding potential false negatives. Because it accommodates multiple positives per anchor, this approach results in an improved selection of positive examples that are more varied, while still containing semantically relevant information. SupCon also allows label information to play an active role in representation learning rather than restricting it to be used only in downstream training, as is the case for conventional contrastive learning. To the best of our knowledge, this is the first contrastive loss to consistently perform better on large-scale image classification problems than the common approach of using cross-entropy loss to train the model directly. Importantly, SupCon is straightforward to implement and stable to train, provides consistent improvement to top-1 accuracy for a number of datasets and architectures (including Transformer architectures), and is robust to image corruptions and hyperparameter variations.

Self-supervised (left) vs supervised (right) contrastive losses: The self-supervised contrastive loss contrasts a single positive for each anchor (i.e., an augmented version of the same image) against a set of negatives consisting of the entire remainder of the minibatch. The supervised contrastive loss considered in this paper, however, contrasts the set of all samples from the same class as positives against the negatives from the remainder of the batch.

The Supervised Contrastive Learning Framework
SupCon can be seen as a generalization of both the SimCLR and N-pair losses — the former uses positives generated from the same sample as that of the anchor, and the latter uses positives generated from different samples by exploiting known class labels. The use of many positives and many negatives for each anchor allows SupCon to achieve state-of-the-art performance without the need for hard negative mining (i.e., searching for negatives similar to the anchor), which can be difficult to tune properly.

SupCon subsumes multiple losses from the literature and is a generalization of the SimCLR and N-Pair losses.

This method is structurally similar to those used in self-supervised contrastive learning, with modifications for supervised classification. Given an input batch of data, we first apply data augmentation twice to obtain two copies, or “views,” of each sample in the batch (though one could create and use any number of augmented views). Both copies are forward propagated through an encoder network, and the resulting embedding is then L2-normalized. Following standard practice, the representation is further propagated through an optional projection network to help identify meaningful features. The supervised contrastive loss is computed on the normalized outputs of the projection network. Positives for an anchor consist of the representations originating from the same batch instance as the anchor or from other instances with the same label as the anchor; the negatives are then all remaining instances. To measure performance on downstream tasks, we train a linear classifier on top of the frozen representations.
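The loss computation described above can be sketched in NumPy as follows. This is a simplified illustration, not the released TensorFlow code: it assumes all augmented views have already been stacked into one array with their labels repeated, and the function name and default temperature are choices made here.

```python
import numpy as np

def supcon_loss(embeddings, labels, temperature=0.1):
    """Supervised contrastive loss over a batch of (already stacked) views."""
    z = np.asarray(embeddings, dtype=float)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # L2-normalize
    labels = np.asarray(labels)
    n = z.shape[0]

    not_self = ~np.eye(n, dtype=bool)
    logits = z @ z.T / temperature
    logits = np.where(not_self, logits, -np.inf)        # never contrast with itself
    logits = logits - logits.max(axis=1, keepdims=True) # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

    # Positives: every other sample that shares the anchor's label.
    positives = (labels[:, None] == labels[None, :]) & not_self
    n_pos = positives.sum(axis=1)
    has_pos = n_pos > 0
    pos_log_prob = np.where(positives, log_prob, 0.0).sum(axis=1)
    return float(-(pos_log_prob[has_pos] / n_pos[has_pos]).mean())

views = np.array([[1.0, 0.1], [0.9, 0.0], [-1.0, 0.2], [-0.8, -0.1]])
print(supcon_loss(views, labels=[0, 0, 1, 1]))
```

Anchors with no positive in the batch are skipped, and the loss is averaged over the remaining anchors.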

Cross-entropy, self-supervised contrastive loss and supervised contrastive loss Left: The cross-entropy loss uses labels and a softmax loss to train a classifier. Middle: The self-supervised contrastive loss uses a contrastive loss and data augmentations to learn representations. Right: The supervised contrastive loss also learns representations using a contrastive loss, but uses label information to sample positives in addition to augmentations of the same image.

Key Findings
SupCon consistently boosts top-1 accuracy compared to cross-entropy, margin classifiers (with use of labels), and self-supervised contrastive learning techniques on the CIFAR-10, CIFAR-100, and ImageNet datasets. With SupCon, we achieve excellent top-1 accuracy on the ImageNet dataset with the ResNet-50 and ResNet-200 architectures. On ResNet-200, we achieve a top-1 accuracy of 81.4%, which is a 0.8% improvement over the state-of-the-art cross-entropy loss using the same architecture (which represents a significant advance for ImageNet). We also compared cross-entropy and SupCon on a Transformer-based ViT-B/16 model and found a consistent improvement over cross-entropy (77.8% versus 76% for ImageNet; 92.6% versus 91.6% for CIFAR-10) under the same data augmentation regime (without any higher-resolution fine-tuning).

The SupCon loss consistently outperforms cross-entropy with standard data augmentation strategies (AutoAugment, RandAugment and CutMix). We show top-1 accuracy for ImageNet, on ResNet-50, ResNet-101 and ResNet-200.

We also demonstrate analytically that the gradient of our loss function encourages learning from hard positives and hard negatives. The gradient contributions from hard positives/negatives are large while those for easy positives/negatives are small. This implicit property allows the contrastive loss to sidestep the need for explicit hard mining, which is a delicate but critical part of many losses, such as triplet loss. See the supplementary material of our paper for a full derivation.

SupCon is also more robust to natural corruptions, such as noise, blur and JPEG compression. The mean Corruption Error (mCE) measures the average degradation in performance compared to the benchmark ImageNet-C dataset. The SupCon models have lower mCE values across different corruptions compared to cross-entropy models, showing increased robustness.

We show empirically that the SupCon loss is less sensitive than cross-entropy to a range of hyperparameters. Across changes in augmentations, optimizers, and learning rates, we observe significantly lower variance in the output of the contrastive loss. Moreover, applying different batch sizes while holding all other hyperparameters constant results in consistently better top-1 accuracy for SupCon than for cross-entropy at each batch size.

Accuracy of cross-entropy and supervised contrastive loss as a function of hyperparameters and training data size, measured on ImageNet with a ResNet-50 encoder. Left: Boxplot showing Top-1 accuracy vs changes in augmentation, optimizer and learning rates. SupCon yields more consistent results across variations in each, which is useful when the best strategies are unknown a priori. Right: Top-1 accuracy as a function of batch size shows both losses benefit from larger batch sizes while SupCon has higher Top-1 accuracy, even when trained with small batch sizes.
Accuracy of supervised contrastive loss as a function of training duration and the temperature hyperparameter, measured on ImageNet with a ResNet-50 encoder. Left: Top-1 accuracy as a function of SupCon pre-training epochs. Right: Top-1 accuracy as a function of temperature during the pre-training stage for SupCon. Temperature is an important hyperparameter in contrastive learning and reducing sensitivity to temperature is desirable.

Broader Impact and Next Steps
This work provides a technical advancement in the field of supervised classification. Supervised contrastive learning can improve both the accuracy and robustness of classifiers with minimal complexity. The classic cross-entropy loss can be seen as a special case of SupCon where the views correspond to the images and the learned embeddings in the final linear layer correspond to the labels. We note that SupCon benefits from large batch sizes, and being able to train the models on smaller batches is an important topic for future research.

Our GitHub repository includes TensorFlow code to train the models in the paper. Our pre-trained models are also released on TF-Hub.

The NeurIPS paper was co-authored with Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan. Special thanks to Jenny Huang for leading the writing process for this blog post.

Learning to Smell: Using Deep Learning to Predict the Olfactory Properties of Molecules

Smell is a sense shared by an incredible range of living organisms, and plays a critical role in how they analyze and react to the world. For humans, our sense of smell is tied to our ability to enjoy food and can also trigger vivid memories. Smell allows us to appreciate all of the fragrances that abound in our everyday lives, be they the proverbial roses, a batch of freshly baked cookies, or a favorite perfume. Yet despite its importance, smell has not received the same level of attention from machine learning researchers as have vision and hearing.

Odor perception in humans is the result of the activation of 400 different types of olfactory receptors (ORs), expressed in 1 million olfactory sensory neurons (OSNs), in a small patch of tissue called the olfactory epithelium. These OSNs send signals to the olfactory bulb, and then to further structures in the brain. Based on analogous advances in deep learning for sight and sound, it should be possible to directly predict the end sensory result of an input molecule, even without knowing the intricate details of all the systems involved. Solving the odor prediction problem would aid in discovering new synthetic odorants, thereby reducing the ecological impact of harvesting natural products. Inspection of the resulting olfactory models may even lead to new insights into the biology of smell.

Small odorant molecules are the most basic building blocks of flavors and fragrances, and therefore represent the simplest version of the odor prediction problem. Yet each molecule can have multiple odor descriptors. Vanillin, for example, has descriptors such as sweet, vanilla, creamy, and chocolate, with some notes being more apparent than others. So odor prediction is also a multi-label classification problem.

In “Machine Learning for Scent: Learning Generalizable Perceptual Representations of Small Molecules”, we leverage graph neural networks (GNNs), a kind of deep neural network designed to operate on graphs as input, to directly predict the odor descriptors for individual molecules, without using any handcrafted rules. We demonstrate that this approach yields significantly improved performance in odor prediction compared to current state-of-the-art and is a promising direction for future research.

Graph Neural Networks for Odor Prediction
Since molecules are analogous to graphs, with atoms forming the vertices and bonds forming the edges, GNNs are the natural model of choice for their understanding. But how does one translate the structure of a molecule into a graph representation? Initially, every node in the graph is represented as a vector, using any preferred featurization — atom identity, atom charge, etc. Then, in a series of message passing steps, every node broadcasts its current vector value to each of its neighbors. An update function then takes the collection of vectors sent to it, and generates an updated vector value. This process can be repeated many times, until finally all of the nodes in the graph are summarized into a single vector via summing or averaging. That single vector, representing the entire molecule, can then be passed into a fully connected network as a learned molecular featurization. This network outputs a prediction for odor descriptors, as provided by perfume experts.
Each node is represented as a vector, and each entry in the vector initially encodes some atomic-level information.
For each node we look at adjacent nodes and collect their information, which is then transformed with a neural network into new information for the centered node. This procedure is performed iteratively. Other variants of GNNs utilize edge and graph-level information.
Illustration of a GNN for odor prediction. We translate the structure of molecules into graphs that are fed into GNN layers to learn a better representation of the nodes. These nodes are reduced into a single vector and passed into a neural network that is used to predict multiple odor descriptors.
This representation doesn’t know anything about spatial positions of atoms, and so it can’t distinguish stereoisomers, molecules made of the same atoms but in slightly different configurations that can smell different, such as (R)- and (S)-carvone. Nevertheless, we have found that even without distinguishing stereoisomers, in practice it is still possible to predict odor quite well.
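The message passing and readout steps described above can be sketched in a few lines of NumPy. This is a toy illustration, not the architecture from the paper: the sum aggregation, the single shared weight matrix with a tanh update, and all names are assumptions made for this sketch.

```python
import numpy as np

def message_passing_readout(node_features, adjacency, weights, num_steps=2):
    """Run a few message passing steps, then sum the nodes into one vector."""
    h = np.asarray(node_features, dtype=float)   # one row per atom
    a = np.asarray(adjacency, dtype=float)       # 1 where two atoms share a bond
    for _ in range(num_steps):
        messages = a @ h                         # each node sums its neighbors' vectors
        h = np.tanh((h + messages) @ weights)    # shared update function
    return h.sum(axis=0)                         # whole-molecule vector

# Toy three-atom chain with 4-dimensional atom features.
rng = np.random.default_rng(0)
features = rng.normal(size=(3, 4))
bonds = np.array([[0, 1, 0],
                  [1, 0, 1],
                  [0, 1, 0]])
weights = rng.normal(size=(4, 4))
print(message_passing_readout(features, bonds, weights))
```

Because only the bond structure enters the computation, relabeling the atoms leaves the molecule vector unchanged, and, as noted above, two stereoisomers with the same bonds receive the same vector.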

For odor prediction, GNNs consistently demonstrate improved performance compared to previous state-of-the-art methods, such as random forests, which do not directly encode graph structure. The magnitude of the improvement depends on which odor one tries to predict.
Example of the performance of a GNN on odor descriptors against a strong baseline, as measured by the AUROC score. Example odor descriptors are picked randomly. Closer to 1.0 means better. In the majority of cases GNNs outperform the field-standard baseline substantially, with similar performance seen against other metrics (e.g., AUPRC, recall, precision).
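For readers unfamiliar with the metric, AUROC for one odor descriptor is the probability that a randomly chosen molecule with that descriptor receives a higher score than a randomly chosen molecule without it, with ties counted as one half. The sketch below computes it per descriptor for a multi-label problem; the function names and the toy data are hypothetical and not taken from our evaluation code.

```python
import numpy as np

def auroc(scores, labels):
    """AUROC for one descriptor; requires at least one positive and one negative."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    pos, neg = scores[labels], scores[~labels]
    diff = pos[:, None] - neg[None, :]           # every positive/negative pair
    return float(((diff > 0) + 0.5 * (diff == 0)).mean())

def per_descriptor_auroc(score_matrix, label_matrix):
    """One AUROC per column (descriptor) of a multi-label problem."""
    score_matrix = np.asarray(score_matrix, dtype=float)
    label_matrix = np.asarray(label_matrix, dtype=bool)
    return [auroc(score_matrix[:, j], label_matrix[:, j])
            for j in range(score_matrix.shape[1])]

# Four molecules, two hypothetical descriptors.
scores = [[0.9, 0.2], [0.8, 0.7], [0.3, 0.6], [0.1, 0.1]]
labels = [[1, 0], [1, 1], [0, 1], [0, 0]]
print(per_descriptor_auroc(scores, labels))
```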
Learning from the Model, and Extending It to Other Tasks
In addition to predicting odor descriptors, GNNs can be applied to other olfaction tasks. For example, take the case of classifying new or refined odor descriptors using only limited data. For each molecule, we extract a learned representation from an intermediate layer of the model that is optimized for our odor descriptors, which we call an “odor embedding”. One can think of this as an olfaction version of a color space, like RGB or CMYK. To see if this odor embedding is useful for predicting related but different tasks, we designed experiments that test our learned embedding on related tasks for which it was not originally designed. We then compared the performance of our odor embedding representation to a common chemoinformatic representation that encodes structural information of a molecule but is agnostic to odor. We found that the odor embedding generalized to several challenging new tasks, even matching state-of-the-art on some.
2D snapshot of our embedding space with some example odors highlighted. Left: Each odor is clustered in its own space. Right: The hierarchical nature of the odor descriptor. Shaded and contoured areas are computed with a kernel-density estimate of the embeddings.
Future Work
Within the realm of machine learning, smell remains the most elusive of the senses, and we’re excited to continue doing a small part to shed light on it through further fundamental research. The possibilities for future research are numerous, and touch on everything from designing new olfactory molecules that are cheaper and more sustainably produced, to digitizing scent, or even one day giving those without a sense of smell access to roses (and, unfortunately, also rotten eggs). We hope to also bring this problem to the attention of more of the machine learning world through the eventual creation and sharing of high-quality, open datasets.

This early research is the result of the work and advisement of a team of talented researchers and engineers in Google Brain — Benjamin Sanchez-Lengeling, Jennifer Wei, Brian Lee, Emily Reif, Carey Radebaugh, Max Bileschi, Yoni Halpern, and D. Sculley. We are delighted to have collaborated on this work with Richard Gerkin at ASU and Alán Aspuru-Guzik at the University of Toronto. We are of course building on an enormous amount of prior work, and have benefitted particularly from work by Justin Gilmer, George Dahl and others on fundamental methodology in GNNs, among many other works in neuroscience, statistics and chemistry. We are also grateful for helpful comments from Steven Kearnes, David Belanger, Joel Mainland, and Emily Mayhew.

Accurate Online Speaker Diarization with Supervised Learning

Speaker diarization, the process of partitioning an audio stream with multiple people into homogeneous segments associated with each individual, is an important part of speech recognition systems. By solving the problem of “who spoke when”, speaker diarization has applications in many important scenarios, such as understanding medical conversations, video captioning and more. However, training these systems with supervised learning methods is challenging: unlike standard supervised classification tasks, a robust diarization model must be able to associate distinct speech segments with new individuals who weren't involved in training. Importantly, this limits the quality of both online and offline diarization systems. Online systems usually suffer more, since they require diarization results in real time.
Online speaker diarization on streaming audio input. Different colors in the bottom axis indicate different speakers.
In “Fully Supervised Speaker Diarization”, we describe a new model that seeks to make use of supervised speaker labels in a more effective manner. Here “fully” implies that all components in the speaker diarization system, including the estimation of the number of speakers, are trained in supervised ways, so that they can benefit from increasing the amount of labeled data available. On the NIST SRE 2000 CALLHOME benchmark, our diarization error rate (DER) is as low as 7.6%, compared to 8.8% DER from our previous clustering-based method, and 9.9% from deep neural network embedding methods. Moreover, our method achieves this lower error rate based on online decoding, making it specifically suitable for real-time applications. As such we are open sourcing the core algorithms in our paper to accelerate more research along this direction.

Clustering versus Interleaved-state RNN
Modern speaker diarization systems are usually based on clustering algorithms such as k-means or spectral clustering. Since these clustering methods are unsupervised, they cannot make good use of the supervised speaker labels available in data. Moreover, online clustering algorithms usually have worse quality in real-time diarization applications with streaming audio inputs. The key difference between our model and common clustering algorithms is that in our method, all speakers’ embeddings are modeled by a parameter-sharing recurrent neural network (RNN), and we distinguish different speakers using different RNN states, interleaved in the time domain.

To understand how this works, consider the example below in which there are four possible speakers: blue, yellow, pink and green (this is arbitrary, and in fact there may be more — our model uses the Chinese restaurant process to accommodate the unknown number of speakers). Each speaker starts with its own RNN instance (with a common initial state shared among all speakers) and keeps updating the RNN state given the new embeddings from this speaker. In the example below, the blue speaker keeps updating its RNN state until a different speaker, yellow, comes in. If blue speaks again later, it resumes updating its RNN state. (This is just one of the possibilities for speech segment y7 in the figure below. If new speaker green enters, it will start with a new RNN instance.)
The generative process of our model. Colors indicate labels for speaker segments.
Representing speakers as RNN states enables us to learn the high-level knowledge shared across different speakers and utterances using RNN parameters, which means the model can benefit from more labeled data. In contrast, common clustering algorithms almost always work with each single utterance independently, making it difficult to benefit from a large amount of labeled data.
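The bookkeeping behind the interleaved states can be sketched as follows. This toy version assumes the speaker of each segment is already known and uses a plain tanh recurrent cell as a stand-in; the actual recurrent cell, the Chinese restaurant process for speaker assignment, and the decoding procedure are described in the paper and are not reproduced here. All names are hypothetical.

```python
import numpy as np

def interleaved_states(embeddings, speaker_ids, w_in, w_rec):
    """Update one RNN state per speaker, using parameters shared by all speakers."""
    init = np.zeros(w_rec.shape[0])   # common initial state for every new speaker
    states = {}                       # speaker id -> that speaker's latest state
    history = []
    for x, speaker in zip(embeddings, speaker_ids):
        h_prev = states.get(speaker, init)           # resume, or start a new instance
        h = np.tanh(x @ w_in + h_prev @ w_rec)       # same weights for all speakers
        states[speaker] = h
        history.append((speaker, h))
    return states, history

rng = np.random.default_rng(0)
w_in, w_rec = rng.normal(size=(3, 4)), rng.normal(size=(4, 4))
segments = rng.normal(size=(4, 3))    # stand-ins for per-segment speaker embeddings
states, _ = interleaved_states(segments, ["blue", "blue", "yellow", "blue"], w_in, w_rec)
print(sorted(states))
```

When blue speaks again after yellow, its state continues from where it left off, unaffected by yellow's segment.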

The upshot of all this is that given time-stamped speaker labels (i.e. we know who spoke when), we can train the model with standard stochastic gradient descent algorithms. A trained model can be used for speaker diarization on new utterances from unheard speakers. Furthermore, the use of online decoding makes it more suitable for latency-sensitive applications.

Future Work
Although we've already achieved impressive diarization performance with this system, there are still many exciting directions we are currently exploring. First, we are refining our model so it can easily integrate contextual information to perform offline decoding. Offline decoding will likely further reduce the DER and is better suited to latency-insensitive applications. Second, we would like to model acoustic features directly instead of using d-vectors. In this way, the entire speaker diarization system can be trained in an end-to-end way.

To learn more about this work, please see our paper. To download the core algorithm of this system, please visit the GitHub page.

This work was done as a close collaboration between Google AI and Speech & Assistant teams. Contributors include Aonan Zhang (intern), Quan Wang, Zhengyao Zhu and Chong Wang.

