Tag Archives: Neural Networks

AutoBNN: Probabilistic time series forecasting with compositional Bayesian neural networks

Time series problems are ubiquitous, from forecasting weather and traffic patterns to understanding economic trends. Bayesian approaches start with an assumption about the data's patterns (the prior probability), collect evidence (e.g., new time series data), and continuously update that assumption to form a posterior probability distribution. Traditional Bayesian approaches like Gaussian processes (GPs) and Structural Time Series are extensively used for modeling time series data, e.g., the commonly used Mauna Loa CO2 dataset. However, they often rely on domain experts to painstakingly select appropriate model components and may be computationally expensive. Alternatives such as neural networks lack interpretability, making it difficult to understand how they generate forecasts, and they don't produce reliable confidence intervals.

To that end, we introduce AutoBNN, a new open-source package written in JAX. AutoBNN automates the discovery of interpretable time series forecasting models, provides high-quality uncertainty estimates, and scales effectively for use on large datasets. We describe how AutoBNN combines the interpretability of traditional probabilistic approaches with the scalability and flexibility of neural networks.


AutoBNN is based on a line of research that over the past decade has yielded improved predictive accuracy by modeling time series using GPs with learned kernel structures. The kernel function of a GP encodes assumptions about the function being modeled, such as the presence of trends, periodicity or noise. With learned GP kernels, the kernel function is defined compositionally: it is either a base kernel (such as Linear, Quadratic, Periodic, Matérn or ExponentiatedQuadratic) or a composite that combines two or more kernel functions using operators such as Addition, Multiplication, or ChangePoint. This compositional kernel structure serves two related purposes. First, it is simple enough that a user who is an expert about their data, but not necessarily about GPs, can construct a reasonable prior for their time series. Second, techniques like Sequential Monte Carlo can be used for discrete searches over small structures and can output interpretable results.
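To make the compositional idea concrete, here is a minimal plain-Python sketch of a few base kernels combined with Addition and Multiplication operators. The kernel formulas are the standard textbook forms; the function names (`exponentiated_quadratic`, `linear`, `periodic`, `add`, `multiply`, `gram`) and parameter values are illustrative assumptions, not AutoBNN's API.

```python
import math


# Base kernels on scalar inputs (standard textbook forms; names are illustrative).
def exponentiated_quadratic(length_scale=1.0):
    return lambda x, y: math.exp(-((x - y) ** 2) / (2.0 * length_scale ** 2))


def linear(bias=0.0):
    return lambda x, y: bias + x * y


def periodic(period=1.0, length_scale=1.0):
    return lambda x, y: math.exp(
        -2.0 * math.sin(math.pi * abs(x - y) / period) ** 2 / length_scale ** 2)


# Operators: sums and products of valid kernels are again valid kernels.
def add(*kernels):
    return lambda x, y: sum(k(x, y) for k in kernels)


def multiply(*kernels):
    return lambda x, y: math.prod(k(x, y) for k in kernels)


def gram(kernel, xs):
    """Covariance (Gram) matrix of the kernel evaluated at all pairs of inputs."""
    return [[kernel(a, b) for b in xs] for a in xs]


# A linear trend plus a locally periodic component (Periodic times ExponentiatedQuadratic).
k = add(linear(), multiply(periodic(period=12.0), exponentiated_quadratic(24.0)))
K = gram(k, [0.0, 6.0, 12.0, 24.0])
```

Each composite is just another kernel, so it can itself be fed to `add` or `multiply`, which is what makes the structure searchable.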

AutoBNN improves upon these ideas, replacing the GP with Bayesian neural networks (BNNs) while retaining the compositional kernel structure. A BNN is a neural network with a probability distribution over weights rather than a fixed set of weights. This induces a distribution over outputs, capturing uncertainty in the predictions. BNNs bring the following advantages over GPs: First, training large GPs is computationally expensive, and traditional training algorithms scale as the cube of the number of data points in the time series. In contrast, for a fixed width, training a BNN will often be approximately linear in the number of data points. Second, BNNs lend themselves better to GPU and TPU hardware acceleration than GP training operations. Third, compositional BNNs can be easily combined with traditional deep BNNs, which have the ability to do feature discovery. One could imagine "hybrid" architectures, in which users specify a top-level structure of Add(Linear, Periodic, Deep), and the deep BNN is left to learn the contributions from potentially high-dimensional covariate information.
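The statement that a distribution over weights induces a distribution over outputs can be illustrated in a few lines. This is a minimal sketch assuming a single-hidden-layer tanh network with independent Gaussian priors; the prior scales and the helper names `sample_bnn` and `predictive_summary` are hypothetical and are not AutoBNN's priors.

```python
import math
import random


def sample_bnn(width, rng, weight_scale=1.0):
    """Draw one network from a Gaussian prior over the weights of a
    single-hidden-layer tanh network; returns it as a function of x."""
    w_in = [rng.gauss(0.0, weight_scale) for _ in range(width)]
    b_in = [rng.gauss(0.0, weight_scale) for _ in range(width)]
    # Output weights scaled by 1/sqrt(width) so the output variance stays bounded as width grows.
    w_out = [rng.gauss(0.0, 1.0 / math.sqrt(width)) for _ in range(width)]
    return lambda x: sum(v * math.tanh(w * x + b) for w, b, v in zip(w_in, b_in, w_out))


def predictive_summary(x, width=10, num_samples=2000, seed=0):
    """Mean and standard deviation of f(x) across networks drawn from the prior."""
    rng = random.Random(seed)
    outputs = [sample_bnn(width, rng)(x) for _ in range(num_samples)]
    mean = sum(outputs) / num_samples
    variance = sum((y - mean) ** 2 for y in outputs) / (num_samples - 1)
    return mean, math.sqrt(variance)


mean, std = predictive_summary(0.5)
```

Training replaces this prior with a posterior over weights, but the mechanism is the same: uncertainty about the weights becomes a spread of possible outputs at each input.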

How, then, might one translate a GP with compositional kernels into a BNN? A neural network with a single hidden layer will typically converge to a GP as the number of neurons (or "width") goes to infinity. More recently, researchers have discovered a correspondence in the other direction: many popular GP kernels (such as Matern, ExponentiatedQuadratic, Polynomial or Periodic) can be obtained as infinite-width BNNs with appropriately chosen activation functions and weight distributions. Furthermore, these BNNs remain close to the corresponding GP even when the width is small. For example, the figures below compare the true GPs with their width-10 neural network versions, both in the covariance between pairs of observations (the Gram matrices) and in regression results.

Comparison of Gram matrices between true GP kernels (top row) and their width 10 neural network approximations (bottom row).
Comparison of regression results between true GP kernels (top row) and their width 10 neural network approximations (bottom row).
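One classical way to watch a finite-width network approximate a GP kernel is random Fourier features: a layer of cosine units with Gaussian frequencies and uniform phases induces a kernel that equals the ExponentiatedQuadratic (RBF) kernel in expectation, with an error that shrinks as the width grows. This is a swapped-in illustration assuming that construction; it is not necessarily the activation and weight distribution AutoBNN uses, and the helper names are hypothetical.

```python
import math
import random


def rff_kernel(width, length_scale, rng):
    """Kernel induced by one draw of a width-`width` layer of cosine features.

    Frequencies w ~ N(0, 1/length_scale^2) and phases b ~ U(0, 2*pi); in
    expectation this equals the exponentiated quadratic (RBF) kernel.
    """
    ws = [rng.gauss(0.0, 1.0 / length_scale) for _ in range(width)]
    bs = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(width)]

    def features(x):
        return [math.sqrt(2.0 / width) * math.cos(w * x + b) for w, b in zip(ws, bs)]

    return lambda x, y: sum(a * c for a, c in zip(features(x), features(y)))


def rbf(x, y, length_scale):
    return math.exp(-((x - y) ** 2) / (2.0 * length_scale ** 2))


def max_gram_error(width, xs, length_scale=1.0, seed=0):
    """Largest absolute difference between the approximate and exact Gram matrices."""
    k_hat = rff_kernel(width, length_scale, random.Random(seed))
    return max(abs(k_hat(a, b) - rbf(a, b, length_scale)) for a in xs for b in xs)


xs = [0.0, 0.5, 1.0, 1.5, 2.0]
err_narrow = max_gram_error(10, xs)
err_wide = max_gram_error(5000, xs)
```

The error behaves like a Monte Carlo estimate, shrinking roughly as 1/sqrt(width), which is why a modest width can already produce a Gram matrix with the right overall shape.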

Finally, the translation is completed with BNN analogues of the Addition and Multiplication operators over GPs, and with input warping to produce periodic kernels. BNN addition is straightforward: the outputs of the component BNNs are added. BNN multiplication is achieved by multiplying the activations of the hidden layers of the BNNs and then applying a shared dense layer. As a result, only BNNs with the same hidden width can be multiplied.
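The two operators can be sketched for single-hidden-layer networks as follows. The `OneLayerNet` class, its Gaussian weight draws and the chosen activations are assumptions for illustration; AutoBNN's own classes, built on flax.linen, differ.

```python
import math
import random


class OneLayerNet:
    """Single-hidden-layer network with weights drawn once from a Gaussian prior (hypothetical)."""

    def __init__(self, width, activation, rng):
        self.width = width
        self.activation = activation
        self.w = [rng.gauss(0.0, 1.0) for _ in range(width)]
        self.b = [rng.gauss(0.0, 1.0) for _ in range(width)]
        self.out = [rng.gauss(0.0, 1.0 / math.sqrt(width)) for _ in range(width)]

    def hidden(self, x):
        return [self.activation(w * x + b) for w, b in zip(self.w, self.b)]

    def __call__(self, x):
        return sum(o * h for o, h in zip(self.out, self.hidden(x)))


def add(*nets):
    """BNN addition: sum the outputs of the component networks."""
    return lambda x: sum(net(x) for net in nets)


def multiply(net_a, net_b, rng):
    """BNN multiplication: multiply hidden activations elementwise, then apply a shared dense layer."""
    if net_a.width != net_b.width:
        raise ValueError("only networks with the same hidden width can be multiplied")
    shared_out = [rng.gauss(0.0, 1.0 / math.sqrt(net_a.width)) for _ in range(net_a.width)]
    return lambda x: sum(
        o * ha * hb for o, ha, hb in zip(shared_out, net_a.hidden(x), net_b.hidden(x)))


rng = random.Random(0)
a = OneLayerNet(10, math.tanh, rng)
b = OneLayerNet(10, math.cos, rng)
f_sum = add(a, b)
f_prod = multiply(a, b, rng)
```

The width check in `multiply` is exactly the restriction stated above: the elementwise product of hidden activations only makes sense when both hidden layers have the same size.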

Using AutoBNN

The AutoBNN package is available within TensorFlow Probability. It is implemented in JAX and uses the flax.linen neural network library. It implements all of the base kernels and operators discussed so far (Linear, Quadratic, Matern, ExponentiatedQuadratic, Periodic, Addition, Multiplication) plus one new kernel and three new operators:

  • a OneLayer kernel, a single hidden layer ReLU BNN,
  • a ChangePoint operator that allows smoothly switching between two kernels,
  • a LearnableChangePoint operator which is the same as ChangePoint except position and slope are given prior distributions and can be learnt from the data, and
  • a WeightedSum operator.
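A common way to implement a smooth change point is to blend two functions with a sigmoid centered at the change position. The sketch below assumes that parameterization (a `position` and a `slope`); the exact form AutoBNN uses, and the priors that a LearnableChangePoint places on these parameters, are not shown.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def change_point(f_before, f_after, position, slope):
    """Smoothly switch from f_before to f_after around `position`.

    `slope` controls how abrupt the switch is; a learnable variant would
    place priors on `position` and `slope` and infer them from data.
    """
    def f(x):
        s = sigmoid(slope * (x - position))  # ~0 well before the change, ~1 well after
        return (1.0 - s) * f_before(x) + s * f_after(x)
    return f


# Constant level that switches to a linear trend around x = 5.
f = change_point(lambda x: 1.0, lambda x: 2.0 * x, position=5.0, slope=4.0)
```

At the change position itself the two functions are weighted equally, so `f(5.0)` is the average of 1.0 and 10.0.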

WeightedSum combines two or more BNNs with learnable mixing weights, where the learnable weights follow a Dirichlet prior. By default, a flat Dirichlet distribution with concentration 1.0 is used.
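A WeightedSum can be sketched as drawing mixing weights from a Dirichlet prior and combining component outputs with them. The sketch samples the Dirichlet via normalized Gamma variates, a standard construction; the component functions and helper names are hypothetical stand-ins for trained BNNs.

```python
import random


def sample_dirichlet(concentration, rng):
    """Draw mixing weights from a Dirichlet distribution via normalized Gamma variates."""
    draws = [rng.gammavariate(alpha, 1.0) for alpha in concentration]
    total = sum(draws)
    return [d / total for d in draws]


def weighted_sum(components, weights):
    """Combine component functions with non-negative weights that sum to one."""
    return lambda x: sum(w * f(x) for w, f in zip(weights, components))


rng = random.Random(0)
components = [lambda x: x, lambda x: x ** 2, lambda x: 1.0]
# Flat Dirichlet: concentration 1.0 for every component.
weights = sample_dirichlet([1.0] * len(components), rng)
model = weighted_sum(components, weights)
```

With concentration 1.0 every point on the probability simplex is equally likely a priori, so no component is favored before seeing data.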

WeightedSums allow a "soft" version of structure discovery, i.e., training a linear combination of many possible models at once. In contrast to structure discovery with discrete structures, such as in AutoGP, this allows us to use standard gradient methods to learn structures, rather than using expensive discrete optimization. Instead of evaluating potential combinatorial structures in series, WeightedSum allows us to evaluate them in parallel.
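The following toy sketch shows why a WeightedSum makes structure discovery amenable to gradient methods: with fixed candidate components and softmax-parameterized mixing weights, plain gradient descent on the squared error moves weight onto the components the data actually uses. The fixed components, the softmax parameterization, the synthetic target and the learning rate are all assumptions for illustration; AutoBNN instead trains the component BNNs jointly and uses a Dirichlet prior on the weights.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(t - m) for t in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Three fixed candidate components on a grid (hypothetical stand-ins for trained BNNs).
xs = [10.0 * i / 49 for i in range(50)]
components = [
    [math.sin(x) for x in xs],         # periodic
    [x / 10.0 for x in xs],            # linear trend
    [math.cos(3.0 * x) for x in xs],   # a second periodic that the data does not use
]
# Synthetic target built from the first two components only.
ys = [0.7 * a + 0.3 * b for a, b in zip(components[0], components[1])]

logits = [0.0, 0.0, 0.0]
learning_rate = 2.0
for _ in range(5000):
    w = softmax(logits)
    residual = [sum(wk * comp[i] for wk, comp in zip(w, components)) - ys[i]
                for i in range(len(xs))]
    # Gradient of the mean squared error with respect to each mixing weight.
    grad_w = [2.0 * sum(r * c for r, c in zip(residual, comp)) / len(xs)
              for comp in components]
    # Chain rule through the softmax: dL/dlogit_k = w_k * (g_k - sum_j w_j * g_j).
    mean_grad = sum(wk * gk for wk, gk in zip(w, grad_w))
    logits = [t - learning_rate * wk * (gk - mean_grad)
              for t, wk, gk in zip(logits, w, grad_w)]

weights = softmax(logits)
```

Every candidate stays in the model throughout, so all of them are scored in a single training run; an unused component is driven toward zero weight rather than being removed by a discrete search step.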

To easily enable exploration, AutoBNN defines a number of model structures that contain either top-level or internal WeightedSums. The names of these models can be used as the first parameter in any of the estimator constructors, and include things like sum_of_stumps (the WeightedSum over all the base kernels) and sum_of_shallow (which adds all possible combinations of base kernels with all operators).

Illustration of the sum_of_stumps model. The bars in the top row show the amount by which each base kernel contributes, and the bottom row shows the function represented by the base kernel. The resulting weighted sum is shown on the right.

The figure below demonstrates structure discovery on N374 (a time series of yearly financial data starting in 1949) from the M3 dataset. The six base structures were the ExponentiatedQuadratic (which is the same as the Radial Basis Function kernel, or RBF for short), Matern, Linear, Quadratic, OneLayer and Periodic kernels. The figure shows the MAP estimates of their weights over an ensemble of 32 particles. All of the high-likelihood particles gave a large weight to the Periodic component, low weights to Linear, Quadratic and OneLayer, and a large weight to either RBF or Matern.

Parallel coordinates plot of the MAP estimates of the base kernel weights over 32 particles. The sum_of_stumps model was trained on the N374 series from the M3 dataset (inset in blue). Darker lines correspond to particles with higher likelihoods.

By using WeightedSums as the inputs to other operators, it is possible to express rich combinatorial structures while keeping models compact and the number of learnable weights small. As an example, we include the sum_of_products model (illustrated in the figure below), which first forms two pairwise products, each of two WeightedSums, and then adds the two products. By setting some of the weights to zero, we can create many different discrete structures. The total number of possible structures in this model is 2^16 (65,536), since there are 16 base kernels that can be turned on or off. All these structures are explored implicitly by training just this one model.

Illustration of the "sum_of_products" model. Each of the four WeightedSums has the same structure as the "sum_of_stumps" model.

We have found, however, that certain combinations of kernels (e.g., the product of Periodic and either the Matern or ExponentiatedQuadratic) lead to overfitting on many datasets. To prevent this, we have defined model classes like sum_of_safe_shallow that exclude such products when performing structure discovery with WeightedSums.

For training, AutoBNN provides AutoBnnMapEstimator and AutoBnnMCMCEstimator to perform MAP and MCMC inference, respectively. Either estimator can be combined with any of the six likelihood functions, including four based on normal distributions with different noise characteristics for continuous data and two based on the negative binomial distribution for count data.
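For intuition about the two families of likelihoods, here is a plain-Python sketch of a normal log-density for continuous data and a mean-parameterized negative binomial log-pmf for counts. These are textbook formulas; the parameterization (a mean plus a `total_count` dispersion parameter) is an assumption and need not match the six likelihood functions AutoBNN ships.

```python
import math


def normal_log_likelihood(y, mean, scale):
    """Log-density of a continuous observation y under N(mean, scale^2)."""
    return -0.5 * math.log(2.0 * math.pi * scale ** 2) - 0.5 * ((y - mean) / scale) ** 2


def negative_binomial_log_likelihood(k, mean, total_count):
    """Log-pmf of a non-negative integer count k, parameterized by its mean.

    The variance is mean + mean**2 / total_count, so a small total_count
    allows more overdispersion than a Poisson model would.
    """
    p = mean / (mean + total_count)  # per-trial success probability implied by the mean
    return (math.lgamma(k + total_count) - math.lgamma(total_count) - math.lgamma(k + 1)
            + total_count * math.log1p(-p) + k * math.log(p))
```

The negative binomial assigns probability only to non-negative integers, which is why it is the natural choice for count data, while its extra dispersion parameter absorbs variance that a Poisson model could not.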

Result from running AutoBNN on the Mauna Loa CO2 dataset in our example colab. The model captures the trend and seasonal component in the data. Extrapolating into the future, the mean prediction slightly underestimates the actual trend, while the 95% confidence interval gradually widens.

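To fit a model like the one in the figure above, all it takes is about ten lines of code, using the scikit-learn–inspired estimator interface:

import jax
import autobnn as ab

# A sum of periodic, linear and Matern components, each a width-50 BNN.
model = ab.operators.Add(
    bnns=(ab.kernels.PeriodicBNN(width=50),
          ab.kernels.LinearBNN(width=50),
          ab.kernels.MaternBNN(width=50)))

# MAP inference with the 'normal_likelihood_logistic_noise' likelihood;
# periods=[12] gives a yearly period, assuming monthly observations.
estimator = ab.estimators.AutoBnnMapEstimator(
    model, 'normal_likelihood_logistic_noise', jax.random.PRNGKey(42),
    periods=[12])

# my_training_data_xs / my_training_data_ys are your own training arrays.
estimator.fit(my_training_data_xs, my_training_data_ys)
low, mid, high = estimator.predict_quantiles(my_training_data_xs)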


AutoBNN provides a powerful and flexible framework for building sophisticated time series prediction models. By combining the strengths of BNNs and GPs with compositional kernels, AutoBNN opens a world of possibilities for understanding and forecasting complex data. We invite the community to try the colab, and leverage this library to innovate and solve real-world challenges.


AutoBNN was written by Colin Carroll, Thomas Colthurst, Urs Köster and Srinivas Vasudevan. We would like to thank Kevin Murphy, Brian Patton and Feras Saad for their advice and feedback.

Source: Google AI Blog

Google Research embarks on effort to map a mouse brain

The human brain is perhaps the most computationally complex machine in existence, consisting of networks of billions of cells. Researchers currently don’t understand the full picture of how glitches in its network machinery contribute to mental illnesses and other diseases, such as dementia. However, the emerging connectomics field, which aims to precisely map the connections between every cell in the brain, could help solve that problem. While maps have only been created for simpler organisms, technological advances for mapping even larger brains can enable us to understand how the human brain works, and how to treat brain diseases.

Today, we're excited to announce that the Connectomics team at Google Research and our collaborators are launching a $33 million project to expand the frontiers of connectomics over the next five years. Supported by the Brain Research Through Advancing Innovative Neurotechnologies (BRAIN) Initiative at the National Institutes of Health (NIH) and led by researchers at Harvard University, we'll be working alongside a multidisciplinary team of experts from the Allen Institute, MIT, Cambridge University, Princeton University and Johns Hopkins University, with advisers from HHMI’s Janelia Research Campus. Our project goal is to tackle an immense challenge in neuroscience: mapping a tiny fraction (2-3%) of the mouse brain. We will specifically target the hippocampal region, which is responsible for encoding memories, attention and spatial navigation. This project is one of 11 funded by the NIH's $150 million BRAIN Initiative Connectivity Across Scales (BRAIN CONNECTS) program. Google Research is contributing computational and analytical resources to this effort, and will not receive any funding from the NIH. Our project asks a critical question: Can we scale and speed up our technologies enough to map the whole connectome of a mouse brain?

The modern era of connectomics

This effort to map the connectome of a small part of the mouse brain builds on a decade of innovation in the field, including many advances initiated by the Connectomics team at Google Research. We hope to accomplish something similar to the early days of the Human Genome Project, when scientists worked for years to sequence a small portion of the human genome as they refined technologies that would enable them to complete the rest of the genome.

In 2021, we and collaborators at Harvard successfully mapped one cubic millimeter of the human brain, which we released as the H01 dataset, a resource for studying the human brain and scaling connectomics technologies. But mapping the entire human brain connectome would require gathering and analyzing as much as a zettabyte of data (one billion terabytes), which is beyond the current capabilities of existing technologies.

Analyzing a mouse connectome is the next best thing. It is small enough to be technically feasible and could potentially deliver insights relevant to our own minds; neuroscientists already use mice to study human brain function and dysfunction. By working together to map 10–15 cubic mm of the mouse brain, we hope to develop new approaches that will allow us to map the entire remainder of the mouse brain, and the human brain thereafter.

Neuroscientists have been working for decades to map increasingly larger and more complicated connectomes.

One of biology’s largest datasets

In this connectomics project, we will map the connectome of the hippocampal formation of the mouse brain, which converts short-term memories into long-term memories and helps the mouse navigate in space. The mouse hippocampal formation is the largest area of any brain we’ve attempted to understand in this way. Through mapping this region of the mouse brain, we will create one of the largest datasets in biology, comprising about 25,000 terabytes, or 25 petabytes, of brain data. For reference, there are about 250 billion stars in our Milky Way Galaxy. If each of those stars were a single byte, it would take 100,000 Milky Way Galaxies to match the 25 petabytes of data that the project will collect when mapping a small region of the mouse brain.
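The Milky Way comparison is simple arithmetic. A quick check, assuming decimal (SI) units, where one petabyte is 10^15 bytes and one terabyte is 10^12 bytes:

```python
petabyte = 10 ** 15                  # bytes (decimal units assumed)
dataset_bytes = 25 * petabyte        # 25 PB, i.e., 25,000 TB
stars_per_galaxy = 250 * 10 ** 9     # ~250 billion stars, one byte each
galaxies_needed = dataset_bytes // stars_per_galaxy
```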

To illustrate the hippocampal project’s scale, we calculated the number of Pixel phones (shown as stacks of Pixels below) needed to store the image data from the completed connectome projects that mapped the roundworm and fruit fly brains, as well as for the mouse hippocampal region and entire mouse brain projects, which are just getting started.

Then, we compared the heights of each Pixel stack to familiar objects and landmarks. It would take a stack of 100 Pixels, as tall as a four-year-old girl, to store the image data for the fruit fly brain, the largest completed project thus far. In contrast, the mouse hippocampal connectome effort will require storage equivalent to more than 48,800 Pixels, reaching as high as the Empire State Building. The animation below shows how the mouse hippocampal project will surpass the scale of previous connectome projects.

We are partnering with several collaborators to build a connectome (a map of the connections between brain cells) for the hippocampal region of a mouse brain. This project will create the largest connectomic dataset ever, surpassing the scale of previous projects that mapped the smaller roundworm and fruit fly brains. We hope this effort will lead to the development of new approaches that will allow us to later map an entire mouse brain. This animation shows how the field of connectomics is scaling up by calculating the number of Pixel phones needed to store the data from various projects. It would take just two Pixels, the height of an olive, to store the roundworm connectome data, while it would take a stack of Pixels the size of Mount Everest to store the data from an entire mouse connectome.

Understanding the connectome of the mouse hippocampal formation could help illuminate the way our own brains work. For instance, we may find common features between this circuitry in the mouse brain and human brains that explain how we know where we are, how our brains associate memories with specific locations, and what goes wrong in people who can’t properly form new spatial memories.

Opening the petabyte pipeline

Over the last decade, our team has worked to develop tools for managing massive connectomic datasets, and extracting scientific value from them. But a mouse brain has 1,000 times more neurons than the brain of the Drosophila fruit fly, an organism for which we helped build a connectome for a large part of the brain. Starting the mouse brain connectome will challenge us to improve existing technologies to enable us to map more data faster than ever before.

We’ll continue to refine our flood-filling networks, which use deep learning to trace, or “segment”, each neuron’s path through three-dimensional brain volumes made from electron microscope data. We’ll also extend the capabilities of our self-supervised learning technology, SegCLR, which allows us to automatically extract key insights from segmented volumes, such as identifying cell type (e.g., pyramidal neuron, basket neuron, etc.) and parts of each neuron (e.g., axon, dendrite, etc.).

A flood filling network traces a neuron through three-dimensional brain space.

We will also continue to enhance the scalability and performance of our core connectomics infrastructure, such as TensorStore for storage and Neuroglancer for visualization, in order to enable all of our computational pipelines and human analysis workflows to operate at these new scales of data. We’re eager to get to work to discover what peering into a mouse’s mind might tell us about our own.


The mouse connectomics project described in this blog post will be supported in part by the NIH BRAIN Initiative under award number 1UM1NS132250. Google Research is contributing computational and analytical resources to the mouse connectome project, and will not receive funding from the NIH. Many people were involved in the development of the technologies that make this project possible. We thank our long-term academic collaborators in the Lichtman Lab (Harvard University), HHMI Janelia, and the Denk Lab (Max Planck Institute for Biological Intelligence), and acknowledge core contributions from the Connectomics Team at Google. We also thank John Guilyard for creating the illustrative animation in this post, and Elise Kleeman and Erika Check Hayden for their support. Thanks to Lizzie Dorfman, Michael Brenner, Jay Yagnik and Jeff Dean for their support, coordination and leadership.

Source: Google AI Blog

LayerNAS: Neural Architecture Search in Polynomial Complexity

Every byte and every operation matters when trying to build a faster model, especially if the model is to run on-device. Neural architecture search (NAS) algorithms design sophisticated model architectures by searching through a larger model space than is possible manually. Different NAS algorithms, such as MNasNet and TuNAS, have been proposed and have discovered several efficient model architectures, including MobileNetV3 and EfficientNet.

Here we present LayerNAS, an approach that reformulates the multi-objective NAS problem within the framework of combinatorial optimization to greatly reduce the complexity, which results in an order of magnitude reduction in the number of model candidates that must be searched, less computation required for multi-trial searches, and the discovery of model architectures that perform better overall. Using a search space built on backbones taken from MobileNetV2 and MobileNetV3, we find models with top-1 accuracy on ImageNet up to 4.9% better than current state-of-the-art alternatives.

Problem formulation

NAS tackles a variety of different problems on different search spaces. To understand what LayerNAS is solving, let's start with a simple example: You are the owner of GBurger and are designing the flagship burger, which is made up of three layers, each of which has four options with different costs. Burgers taste different depending on the mix of options. You want to make the most delicious burger you can that comes in under a certain budget.

Make up your burger with different options available for each layer, each of which has different costs and provides different benefits.

Just like the architecture of a neural network, the search space for the perfect burger follows a layerwise pattern, where each layer has several options that change cost and performance in different ways. This simplified model illustrates a common approach for setting up search spaces. For example, for models based on convolutional neural networks (CNNs), like MobileNet, the NAS algorithm can choose among options for each convolution layer, such as the number of filters, the stride or the kernel size.


We base our approach on search spaces that satisfy two conditions:

  • An optimal model can be constructed using one of the model candidates generated from searching the previous layer and applying those search options to the current layer.
  • If we set a FLOP constraint on the current layer, we can set constraints on the previous layer by reducing the FLOPs of the current layer.

Under these conditions it is possible to search linearly, from layer 1 to layer n, knowing that when searching for the best option for layer i, a change in any previous layer will not improve the performance of the model. We can then bucket candidates by their cost, so that only a limited number of candidates are stored per layer. If two models have the same FLOPs but one has better accuracy, we keep only the better one and assume this won’t affect the architecture of the following layers. Whereas the search space of a full treatment would expand exponentially with the number of layers, since the full range of options is available at each layer, our layerwise cost-based approach allows us to significantly reduce the search space while still being able to rigorously reason about the polynomial complexity of the algorithm. Our experimental evaluation shows that within these constraints we are able to discover top-performing models.
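The gap between exhaustive and layerwise search can be made concrete by counting candidates. This sketch assumes the burger setup (three layers, four options each, four cost buckets) and, for the exhaustive count only, the NATS-Bench size search space described in the results (five layers, eight options each); the layerwise figure is an upper bound that ignores candidates pruned for exceeding the budget.

```python
def exhaustive_candidates(num_layers, options_per_layer):
    """Complete architectures a brute-force search would have to consider."""
    return options_per_layer ** num_layers


def layerwise_candidates(num_layers, options_per_layer, num_buckets):
    """Upper bound on candidates evaluated when only the best model per cost
    bucket is kept after each layer (the first layer has no predecessors)."""
    return options_per_layer + (num_layers - 1) * num_buckets * options_per_layer


burger_exhaustive = exhaustive_candidates(3, 4)     # 4 ** 3 = 64
burger_layerwise = layerwise_candidates(3, 4, 4)    # 4 + 16 + 16 = 36
nats_exhaustive = exhaustive_candidates(5, 8)       # 8 ** 5 = 32768
```

The exhaustive count grows exponentially in the number of layers, while the layerwise bound grows only linearly in layers for a fixed number of buckets and options.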

NAS as a combinatorial optimization problem

By applying a layerwise cost approach, we reduce NAS to a combinatorial optimization problem. That is, for layer i, we can compute the cost and reward after training with a given component S_i. This implies the following combinatorial problem: how can we get the best reward if we select one choice per layer within a cost budget? This problem can be solved with many different methods, one of the most straightforward of which is dynamic programming, as described in the following pseudocode:

while True:
    # Select a candidate to search in layer i.
    candidate = select_candidate(layer[i])
    if searchable(candidate):
        # Use the layerwise structural information to generate the children.
        children = generate_children(candidate)
        for child in children:
            reward = train(child)
            bucket = bucketize(child)  # cost bucket, e.g., by FLOPs
            # Keep only the highest-reward model seen in each cost bucket.
            if best_reward[i][bucket] < reward:
                best_reward[i][bucket] = reward
                memorial_table[i][bucket] = child
        # Move to the next layer.
        i += 1
Pseudocode of LayerNAS.
Illustration of the LayerNAS approach for the example of trying to create the best burger within a budget of $7–$9. We have four options for the first layer, which results in four burger candidates. By applying four options on the second layer, we have 16 candidates in total. We then bucket them into ranges from $1–$2, $3–$4, $5–$6, and $7–$8, and only keep the most delicious burger within each of the buckets, i.e., four candidates. Then, for those four candidates, we build 16 candidates using the pre-selected options for the first two layers and four options for each candidate for the third layer. We bucket them again, select the burgers within the budget range, and keep the best one.
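The burger example can be turned into a small runnable sketch of the layerwise search. It assumes, for illustration only, a made-up menu and rewards that simply add up across layers (in real NAS the reward comes from training). With one bucket per exact cost this reduces to a textbook dynamic program for choosing one option per layer under a budget, so it matches brute force; coarser buckets (`bucket_width=2`, like the $1–$2 ranges in the figure) trade that guarantee for fewer stored candidates.

```python
import itertools

# Hypothetical menu: for each layer, a list of (option, cost in dollars, tastiness).
LAYERS = [
    [("bun", 1, 2.0), ("brioche", 2, 3.5), ("pretzel", 2, 3.0), ("lettuce wrap", 1, 1.0)],
    [("beef", 3, 6.0), ("wagyu", 5, 9.0), ("chicken", 2, 4.0), ("veggie", 2, 3.5)],
    [("cheddar", 1, 1.5), ("truffle", 3, 3.0), ("pickles", 1, 1.0), ("no topping", 0, 0.0)],
]


def layerwise_search(layers, budget_low, budget_high, bucket_width=1):
    """Extend the kept candidates one layer at a time, keeping only the
    highest-reward candidate in each cost bucket (LayerNAS-style pruning)."""
    kept = [(0.0, 0, ())]  # (reward, cost, choices) of the empty burger
    for options in layers:
        table = {}  # cost bucket -> best (reward, cost, choices) so far
        for reward, cost, choices in kept:
            for name, option_cost, option_reward in options:
                new_cost = cost + option_cost
                if new_cost > budget_high:
                    continue  # costs only grow, so this can never fit the budget
                candidate = (reward + option_reward, new_cost, choices + (name,))
                bucket = new_cost // bucket_width
                if bucket not in table or table[bucket][0] < candidate[0]:
                    table[bucket] = candidate
        kept = list(table.values())
    in_budget = [c for c in kept if budget_low <= c[1] <= budget_high]
    return max(in_budget, key=lambda c: c[0])


def brute_force(layers, budget_low, budget_high):
    """Score every full combination; only feasible for tiny search spaces."""
    best = None
    for combo in itertools.product(*layers):
        cost = sum(option[1] for option in combo)
        reward = sum(option[2] for option in combo)
        if budget_low <= cost <= budget_high and (best is None or reward > best[0]):
            best = (reward, cost, tuple(option[0] for option in combo))
    return best


best = layerwise_search(LAYERS, 7, 9)
```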

Experimental results

When comparing NAS algorithms, we evaluate the following metrics:

  • Quality: What is the most accurate model that the algorithm can find?
  • Stability: How stable is the selection of a good model? Can high-accuracy models be consistently discovered in consecutive trials of the algorithm?
  • Efficiency: How long does it take for the algorithm to find a high-accuracy model?

We evaluate our algorithm on the standard benchmark NATS-Bench using 100 NAS runs, and we compare against other NAS algorithms, previously described in the NATS-Bench paper: random search, regularized evolution, and proximal policy optimization. Below, we visualize the differences between these search algorithms for the metrics described above. For each comparison, we record the average accuracy and variation in accuracy (variation is noted by a shaded region corresponding to the 25% to 75% interquartile range).

NATS-Bench size search defines a 5-layer CNN model, where each layer can choose from eight different options, each with different channels on the convolution layers. Our goal is to find the best model with 50% of the FLOPs required by the largest model. LayerNAS performance stands apart because it formulates the problem in a different way, separating the cost and reward to avoid searching a significant number of irrelevant model architectures. We found that model candidates with fewer channels in earlier layers tend to yield better performance, which explains how LayerNAS discovers better models much faster than other algorithms, as it avoids spending time on models outside the desired cost range. Note that the accuracy curve drops slightly after searching longer due to the lack of correlation between validation accuracy and test accuracy, i.e., some model architectures with higher validation accuracy have a lower test accuracy in NATS-Bench size search.

Top: NATS-Bench size search test accuracy on Cifar10; Middle: On Cifar100; Bottom: On ImageNet16-120. Average on 100 runs compared with random search (random), Regularized Evolution (evolution), and Proximal Policy Optimization (PPO).

We construct search spaces based on MobileNetV2, MobileNetV2 1.4x, MobileNetV3 Small, and MobileNetV3 Large and search for an optimal model architecture under different #MAdds (number of multiply-adds per image) constraints. In all settings, LayerNAS finds a model with better accuracy on ImageNet. See the paper for details.

Comparison on models under different #MAdds.


In this post, we demonstrated how to reformulate NAS into a combinatorial optimization problem, and proposed LayerNAS as a solution that requires only polynomial search complexity. We compared LayerNAS with existing popular NAS algorithms and showed that it can find improved models on NATS-Bench. We also used the method to find better architectures based on MobileNetV2 and MobileNetV3.


We would like to thank Jingyue Shen, Keshav Kumar, Daiyi Peng, Mingxing Tan, Esteban Real, Peter Young, Weijun Wang, Qifei Wang, Xuanyi Dong, Xin Wang, Yingjie Miao, Yun Long, Zhuo Wang, Da-Cheng Juan, Deqiang Chen, Fotis Iliopoulos, Han-Byul Kim, Rino Lee, Andrew Howard, Erik Vee, Rina Panigrahy, Ravi Kumar and Andrew Tomkins for their contribution, collaboration and advice.

Source: Google AI Blog

Distributed differential privacy for federated learning

Federated learning is a distributed way of training machine learning (ML) models where data is locally processed and only focused model updates and metrics that are intended for immediate aggregation are shared with a server that orchestrates training. This allows the training of models on locally available signals without exposing raw data to servers, increasing user privacy. In 2021, we announced that we are using federated learning to train Smart Text Selection models, an Android feature that helps users select and copy text easily by predicting what text they want to select and then automatically expanding the selection for them.

Since that launch, we have worked to improve the privacy guarantees of this technology by carefully combining secure aggregation (SecAgg) and a distributed version of differential privacy. In this post, we describe how we built and deployed the first federated learning system that provides formal privacy guarantees to all user data before it becomes visible to an honest-but-curious server, meaning a server that follows the protocol but could try to gain insights about users from data it receives. The Smart Text Selection models trained with this system have reduced memorization by more than two-fold, as measured by standard empirical testing methods.

Scaling secure aggregation

Data minimization is an important privacy principle behind federated learning. It refers to focused data collection, early aggregation, and minimal data retention required during training. While every device participating in a federated learning round computes a model update, the orchestrating server is only interested in their average. Therefore, in a world that optimizes for data minimization, the server would learn nothing about individual updates and only receive an aggregate model update. This is precisely what the SecAgg protocol achieves, under rigorous cryptographic guarantees.
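The core trick behind secure aggregation can be sketched with pairwise additive masks: every pair of clients shares a random mask that one adds and the other subtracts, so each masked vector looks random on its own while the masks cancel in the sum. This is a simplified illustration of the original SecAgg idea under stated assumptions: a single seeded random generator stands in for pairwise key agreement (so this sketch is not secure), updates are assumed to be integers modulo 2^32, and dropout recovery and the improved logarithmic-cost protocol are omitted.

```python
import random

MODULUS = 2 ** 32  # updates are assumed to be already encoded as integers mod MODULUS


def masked_updates(updates, seed=0):
    """Each pair of clients (i, j), i < j, shares a random mask; client i adds it
    and client j subtracts it, so every mask cancels in the sum."""
    rng = random.Random(seed)  # stands in for pairwise key agreement
    n, dim = len(updates), len(updates[0])
    masked = [list(u) for u in updates]
    for i in range(n):
        for j in range(i + 1, n):
            mask = [rng.randrange(MODULUS) for _ in range(dim)]
            masked[i] = [(a + m) % MODULUS for a, m in zip(masked[i], mask)]
            masked[j] = [(a - m) % MODULUS for a, m in zip(masked[j], mask)]
    return masked


def server_sum(masked):
    """All the server can compute: the coordinate-wise sum modulo MODULUS."""
    return [sum(column) % MODULUS for column in zip(*masked)]


updates = [[3, 1, 4], [1, 5, 9], [2, 6, 5]]
masked = masked_updates(updates)
```

The server recovers the aggregate exactly while each individual masked vector it receives is uniformly random in isolation.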

Important to this work, two recent advancements have improved the efficiency and scalability of SecAgg at Google:

  • An improved cryptographic protocol: Until recently, a significant bottleneck in SecAgg was client computation, as the work required on each device scaled linearly with the total number of clients (N) participating in the round. In the new protocol, client computation now scales logarithmically in N. This, along with similar gains in server costs, results in a protocol able to handle larger rounds. Having more users participate in each round improves privacy, both empirically and formally.
  • Optimized client orchestration: SecAgg is an interactive protocol, where participating devices progress together. An important feature of the protocol is that it is robust to some devices dropping out. If a client does not send a response in a predefined time window, then the protocol can continue without that client’s contribution. We have deployed statistical methods to effectively auto-tune such a time window in an adaptive way, resulting in improved protocol throughput.

The above improvements made it easier and faster to train Smart Text Selection with stronger data minimization guarantees.

Aggregating everything via secure aggregation

A typical federated training system not only involves aggregating model updates but also metrics that describe the performance of the local training. These are important for understanding model behavior and debugging potential training issues. In federated training for Smart Text Selection, all model updates and metrics are aggregated via SecAgg. This behavior is statically asserted using TensorFlow Federated, and locally enforced in Android’s Private Compute Core secure environment. As a result, this enhances privacy even more for users training Smart Text Selection, because unaggregated model updates and metrics are not visible to any part of the server infrastructure.

Differential privacy

SecAgg helps minimize data exposure, but it does not necessarily produce aggregates that guarantee against revealing anything unique to an individual. This is where differential privacy (DP) comes in. DP is a mathematical framework that sets a limit on an individual's influence on the outcome of a computation, such as the parameters of an ML model. This is accomplished by bounding the contribution of any individual user and adding noise during the training process to produce a probability distribution over output models. DP comes with a parameter (ε) that quantifies how much the distribution could change when adding or removing the training examples of any individual user (the smaller the better).
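The two ingredients, bounding each user's contribution and adding noise, can be sketched for a sum of model updates. This minimal sketch assumes L2-norm clipping and Gaussian noise whose standard deviation is a multiple (`noise_multiplier`) of the clipping norm; the privacy accounting that turns these parameters into an ε is omitted, and the helper names are hypothetical.

```python
import math
import random


def clip(update, clip_norm):
    """Scale an update down so that its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(u * u for u in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [u * scale for u in update]


def dp_sum(updates, clip_norm, noise_multiplier, rng):
    """Sum of clipped updates plus Gaussian noise with std noise_multiplier * clip_norm.

    Clipping bounds any single user's influence on the sum; the noise level
    relative to that bound is what a privacy accountant converts into an epsilon.
    """
    clipped = [clip(u, clip_norm) for u in updates]
    total = [sum(column) for column in zip(*clipped)]
    return [t + rng.gauss(0.0, noise_multiplier * clip_norm) for t in total]


rng = random.Random(0)
updates = [[3.0, 4.0], [0.3, 0.4], [-6.0, 8.0]]
noisy_total = dp_sum(updates, clip_norm=1.0, noise_multiplier=1.0, rng=rng)
```

In this centralized form a trusted server performs both steps, which is exactly the trust assumption the distributed variant below removes.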

Recently, we announced a new method of federated training that enforces formal and meaningfully strong DP guarantees in a centralized manner, where a trusted server controls the training process. This protects against external attackers who may attempt to analyze the model. However, this approach still relies on trust in the central server. To provide even greater privacy protections, we have created a system that uses distributed differential privacy (DDP) to enforce DP in a distributed manner, integrated within the SecAgg protocol.

Distributed differential privacy

DDP is a technology that offers DP guarantees with respect to an honest-but-curious server coordinating training. It works by having each participating device clip and noise its update locally, and then aggregating these noisy clipped updates through the new SecAgg protocol described above. As a result, the server only sees the noisy sum of the clipped updates.
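The DDP flow can be sketched as follows, under simplifying assumptions: each client clips its update, scales and rounds it to integers, adds Skellam noise (the difference of two Poisson draws, one of the discrete mechanisms named in this post), and reduces the result modulo a 12-bit group. A plain modular sum stands in for SecAgg, so in this toy the decoding code can see individual messages. The random rotation and the calibration of `mu` to a target ε are omitted, and all names are hypothetical.

```python
import math
import random

MODULUS = 2**12  # assumed group size: 12 bits per model parameter


def poisson(rng, lam):
    """Knuth's Poisson sampler; adequate for the small rates used in this sketch."""
    threshold = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            return k
        k += 1


def skellam(rng, mu):
    """Integer noise: difference of two independent Poisson(mu) draws (variance 2 * mu)."""
    return poisson(rng, mu) - poisson(rng, mu)


def client_message(update, clip_norm, scale, mu, rng):
    """Clip, discretize, and noise one client's update locally, then map it into the group."""
    norm = math.sqrt(sum(x * x for x in update))
    factor = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    ints = [round(x * factor * scale) for x in update]
    return [(v + skellam(rng, mu)) % MODULUS for v in ints]


def decode_sum(messages, scale):
    """Stand-in for SecAgg: modular sum, then map back to signed floats."""
    total = [sum(col) % MODULUS for col in zip(*messages)]
    signed = [t - MODULUS if t >= MODULUS // 2 else t for t in total]
    return [s / scale for s in signed]
```

Because every client adds its own noise before aggregation, the decoded sum carries the combined noise of all participants, which is what provides the DP guarantee against the server.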

However, the combination of local noise addition and use of SecAgg presents significant challenges in practice:

  • An improved discretization method: One challenge is properly representing model parameters as integers in SecAgg's finite group with integer modular arithmetic, which can inflate the norm of the discretized model and require more noise for the same privacy level. For example, randomized rounding to the nearest integers could inflate the user's contribution by a factor equal to the number of model parameters. We addressed this by scaling the model parameters, applying a random rotation, and rounding to nearest integers. We also developed an approach for auto-tuning the discretization scale during training. This led to an even more efficient and accurate integration between DP and SecAgg.
  • Optimized discrete noise addition: Another challenge is devising a scheme for choosing an arbitrary number of bits per model parameter without sacrificing end-to-end privacy guarantees, which depend on how the model updates are clipped and noised. To address this, we added integer noise in the discretized domain and analyzed the DP properties of sums of integer noise vectors using the distributed discrete Gaussian and distributed Skellam mechanisms.
An overview of federated learning with distributed differential privacy.
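The discretization step can be sketched with a randomized Hadamard transform, one common way to implement a random rotation (an assumption here; the post does not say which rotation is used). Random sign flips followed by a normalized Walsh-Hadamard transform spread a spiky vector's mass evenly across coordinates, so rounding error is not concentrated in a few large entries. The sketch rounds to the nearest integer, the vector length must be a power of two, and all function names are hypothetical.

```python
import math
import random


def hadamard(vec):
    """Normalized fast Walsh-Hadamard transform; len(vec) must be a power of two."""
    v = list(vec)
    n = len(v)
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = v[j], v[j + h]
                v[j], v[j + h] = a + b, a - b
        h *= 2
    norm = math.sqrt(n)
    return [x / norm for x in v]


def random_signs(dim, seed):
    """Shared random diagonal of +/-1 entries (clients and server use the same seed)."""
    rng = random.Random(seed)
    return [rng.choice((-1, 1)) for _ in range(dim)]


def rotate(vec, signs):
    return hadamard([s * x for s, x in zip(signs, vec)])


def unrotate(vec, signs):
    # The normalized transform is its own inverse; then undo the sign flips.
    return [s * x for s, x in zip(signs, hadamard(vec))]


def discretize(vec, scale, signs):
    """Scale, rotate, then round each coordinate to the nearest integer."""
    return [round(scale * x) for x in rotate(vec, signs)]
```

Because the rotation is linear and shared by all clients, the server can apply `unrotate` once to the decoded sum instead of to each contribution.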

We tested our DDP solution on a variety of benchmark datasets and in production and validated that we can match the accuracy of central DP with a SecAgg finite group of size 12 bits per model parameter. This meant that we were able to achieve added privacy advantages while also reducing memory and communication bandwidth. To demonstrate this, we applied this technology to train and launch Smart Text Selection models. This was done with an appropriate amount of noise chosen to maintain model quality. All Smart Text Selection models trained with federated learning now come with DDP guarantees that apply to both the model updates and metrics seen by the server during training. We have also open sourced the implementation in TensorFlow Federated.

Empirical privacy testing

While DDP adds formal privacy guarantees to Smart Text Selection, those formal guarantees are relatively weak (a finite but large ε, in the hundreds). However, any finite ε is an improvement over a model with no formal privacy guarantee for two reasons: 1) a finite ε moves the model into a regime where further privacy improvements can be quantified; and 2) even large ε’s can indicate a substantial decrease in the ability to reconstruct training data from the trained model. To get a more concrete understanding of the empirical privacy advantages, we performed thorough analyses by applying the Secret Sharer framework to Smart Text Selection models. Secret Sharer is a model auditing technique that can be used to measure the degree to which models unintentionally memorize their training data.

To perform Secret Sharer analyses for Smart Text Selection, we set up control experiments which collect gradients using SecAgg. The treatment experiments use distributed differential privacy aggregators with different amounts of noise.

We found that even low amounts of noise reduce memorization meaningfully, more than doubling the Secret Sharer rank metric for relevant canaries compared to the baseline. This means that even though the DP ε is large, we empirically verified that these amounts of noise already help reduce memorization for this model. However, to further improve on this and to get stronger formal guarantees, we aim to use even larger noise multipliers in the future.
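The rank metric can be made concrete. In the Secret Sharer framework, a canary is scored by the model (for example, by its log-perplexity), its rank is its position among a set of candidate sequences (1 = most likely), and exposure is log2(number of candidates) - log2(rank). The sketch assumes a lower score means more likely; the candidate sets and scoring used for Smart Text Selection are not described here, and the function names are hypothetical.

```python
import math


def canary_rank(canary_score, candidate_scores):
    """1-based rank of the canary among candidates; lower score = more likely under the model."""
    return 1 + sum(1 for s in candidate_scores if s < canary_score)


def exposure(canary_score, candidate_scores):
    """Exposure in bits: log2(#candidates) - log2(rank). Higher means more memorization."""
    rank = canary_rank(canary_score, candidate_scores)
    return math.log2(len(candidate_scores)) - math.log2(rank)
```

Because exposure depends on log2(rank), more than doubling a canary's rank lowers its exposure by more than one bit.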

Next steps

We developed and deployed the first federated learning and distributed differential privacy system that comes with formal DP guarantees with respect to an honest-but-curious server. While offering substantial additional protections, a fully malicious server might still be able to get around the DDP guarantees either by manipulating the public key exchange of SecAgg or by injecting a sufficient number of "fake" malicious clients that don’t add the prescribed noise into the aggregation pool. We are excited to address these challenges by continuing to strengthen the DP guarantee and its scope.


The authors would like to thank Adria Gascon for significant impact on the blog post itself, as well as the people who helped develop these ideas and bring them to practice: Ken Liu, Jakub Konečný, Brendan McMahan, Naman Agarwal, Thomas Steinke, Christopher Choquette, Adria Gascon, James Bell, Zheng Xu, Asela Gunawardana, Kallista Bonawitz, Mariana Raykova, Stanislav Chiknavaryan, Tancrède Lepoint, Shanshan Wu, Yu Xiao, Zachary Charles, Chunxiang Zheng, Daniel Ramage, Galen Andrew, Hugo Song, Chang Li, Sofia Neata, Ananda Theertha Suresh, Timon Van Overveldt, Zachary Garrett, Wennan Zhu, and Lukas Zilka. We’d also like to thank Tom Small for creating the animated figure.

Source: Google AI Blog

Teaching old labels new tricks in heterogeneous graphs

Industrial applications of machine learning are commonly composed of various items that have differing data modalities or feature distributions. Heterogeneous graphs (HGs) offer a unified view of these multimodal data systems by defining multiple types of nodes (for each data type) and edges (for the relation between data items). For instance, e-commerce networks might have [user, product, review] nodes or video platforms might have [channel, user, video, comment] nodes. Heterogeneous graph neural networks (HGNNs) learn node embeddings summarizing each node’s relationships into a vector. However, in real world HGs, there is often a label imbalance issue between different node types. This means that label-scarce node types cannot exploit HGNNs, which hampers the broader applicability of HGNNs.

In “Zero-shot Transfer Learning within a Heterogeneous Graph via Knowledge Transfer Networks”, presented at NeurIPS 2022, we propose a model called a Knowledge Transfer Network (KTN), which transfers knowledge from label-abundant node types to zero-labeled node types using the rich relational information given in a HG. We describe how a HGNN model pre-trained on a label-abundant node type can be applied to zero-labeled node types without the need for fine-tuning. KTNs outperform state-of-the-art transfer learning baselines by up to 140% on zero-shot learning tasks, and can be used to improve many existing HGNN models on these tasks by 24% (or more).

KTNs transform labels from one type of information (squares) through a graph to another type (stars).

What is a heterogeneous graph?

A HG is composed of multiple node and edge types. The figure below shows an e-commerce network presented as a HG. In e-commerce, “users” purchase “products” and write “reviews”. A HG presents this ecosystem using three node types [user, product, review] and three edge types [user-buy-product, user-write-review, review-on-product]. Individual products, users, and reviews are then presented as nodes and their relationships as edges in the HG with the corresponding node and edge types.

E-commerce heterogeneous graph.
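A heterogeneous graph can be stored as per-type node tables plus per-relation edge lists. The sketch below encodes the e-commerce schema described above (three node types and three edge types) and lets attributes differ in modality by node type. It is a plain-Python illustration with assumed names (`HeteroGraph`, `EDGE_TYPES`), not the API of any graph library.

```python
from collections import defaultdict

# Edge types of the e-commerce example: (source type, relation, destination type)
EDGE_TYPES = {
    ("user", "buy", "product"),
    ("user", "write", "review"),
    ("review", "on", "product"),
}


class HeteroGraph:
    def __init__(self):
        self.nodes = defaultdict(dict)  # node type -> {node id: attribute dict}
        self.edges = defaultdict(list)  # edge type -> [(src id, dst id), ...]

    def add_node(self, ntype, nid, attrs=None):
        # Attributes may differ by type, e.g. images for products, text for reviews.
        self.nodes[ntype][nid] = attrs or {}

    def add_edge(self, etype, src, dst):
        if etype not in EDGE_TYPES:
            raise ValueError(f"unknown edge type {etype}")
        src_type, _, dst_type = etype
        if src not in self.nodes[src_type] or dst not in self.nodes[dst_type]:
            raise KeyError("both endpoints must exist with the declared node types")
        self.edges[etype].append((src, dst))

    def neighbors(self, etype, src):
        return [d for s, d in self.edges[etype] if s == src]
```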

In addition to all connectivity information, HGs commonly come with input node attributes that summarize each node’s information. Input node attributes could have different modalities across different node types. For instance, images of products could be given as input node attributes for the product nodes, while text can be given as input attributes to review nodes. Node labels (e.g., the category of each product or the category that most interests each user) are what we want to predict on each node.

HGNNs and label scarcity issues

HGNNs compute node embeddings that summarize each node’s local structures (including the node and its neighbors’ information). These node embeddings are utilized by a classifier to predict each node’s label. To train a HGNN model and a classifier to predict labels for a specific node type, we require a sufficient number of labels for that type.

A common issue in industrial applications of deep learning is label scarcity, and with their diverse node types, HGNNs are even more likely to face this challenge. For instance, publicly available content node types (e.g., product nodes) are abundantly labeled, whereas labels for user or account nodes may not be available due to privacy restrictions. This means that in most standard training settings, HGNN models can only learn to make good inferences for a few label-abundant node types and usually cannot make inferences for the remaining node types (given the absence of any labels for them).

Transfer learning on heterogeneous graphs

Zero-shot transfer learning is a technique used to improve the performance of a model on a target domain with no labels by using the knowledge learned by the model from another related source domain with adequately labeled data. To apply transfer learning to solve this label scarcity issue for certain node types in HGs, the target domain would be the zero-labeled node types. Then what would be the source domain? Previous work commonly sets the source domain as the same type of nodes located in a different HG, assuming those nodes are abundantly labeled. This graph-to-graph transfer learning approach pre-trains a HGNN model on the external HG and then runs the model on the original (label-scarce) HG.

However, these approaches are not applicable in many real-world scenarios for three reasons. First, any external HG that could be used in a graph-to-graph transfer learning setting would almost surely be proprietary and thus likely unavailable. Second, even if practitioners could obtain access to an external HG, it is unlikely the distribution of that source HG would match their target HG well enough to apply transfer learning. Finally, node types suffering from label scarcity are likely to suffer the same issue on other HGs (e.g., privacy issues on user nodes).

Our approach: Transfer learning between node types within a heterogeneous graph

Here, we shed light on a more practical source domain: other node types with abundant labels located on the same HG. Instead of using extra HGs, we transfer knowledge within a single HG (assumed to be fully owned by the practitioners) across different types of nodes. More specifically, we pre-train a HGNN model and a classifier on a label-abundant (source) node type, then reuse the models on the zero-labeled (target) node types located in the same HG without additional fine-tuning. The one requirement is that the source and target node types share the same label set (e.g., in the e-commerce HG, product nodes have a label set describing product categories, and user nodes share the same label set describing their favorite shopping categories).

Why is it challenging?

Unfortunately, we cannot directly reuse the pre-trained HGNN and classifier on the target node type. One crucial characteristic of HGNN architectures is that they are composed of modules specialized to each node type to fully learn the multiplicity of HGs. HGNNs use distinct sets of modules to compute embeddings for each node type. In the figure below, blue- and red-colored modules are used to compute node embeddings for the source and target node types, respectively.

HGNNs are composed of modules specialized to each node type and use distinct sets of modules to compute embeddings of different node types. More details can be found in the paper.

While pre-training HGNNs on the source node type, the source-specific modules in the HGNNs are well trained, but the target-specific modules are under-trained because only a small amount of gradient flows into them. This is shown below, where the L2 norm of the gradients for target node types (i.e., M_tt) is much lower than for source types (i.e., M_ss). As a result, the HGNN model outputs poor node embeddings for the target node type, which leads to poor task performance.

In HGNNs, target type-specific modules receive zero or only a small amount of gradients during pre-training on the source node type, leading to poor performance on the target node type.

KTN: Trainable cross-type transfer learning for HGNNs

Our work focuses on transforming the (poor) target node embeddings computed by a pre-trained HGNN model to follow the distribution of the source node embeddings. Then the classifier, pre-trained on the source node type, can be reused for the target node type. How can we map the target node embeddings to the source domain? To answer this question, we investigate how HGNNs compute node embeddings to learn the relationship between source and target distributions.

HGNNs aggregate connected node embeddings to augment a target node’s embeddings in each layer. In other words, the node embeddings for both source and target node types are updated using the same input: the previous layer’s node embeddings of any connected node types. This means that each can be expressed in terms of the other. We prove this relationship theoretically and find there is a mapping matrix (defined by HGNN parameters) from the target domain to the source domain (more details in Theorem 1 in the paper). Based on this theorem, we introduce an auxiliary neural network, which we refer to as a Knowledge Transfer Network (KTN), that receives the target node embeddings and then transforms them by multiplying them with a (trainable) mapping matrix. We then define a regularizer that is minimized along with the performance loss in the pre-training phase to train the KTN. At test time, we map the target embeddings computed from the pre-trained HGNN to the source domain using the trained KTN for classification.

In HGNNs, the final node embeddings of both source and target types are computed from different mathematical functions (f(): source, g(): target) which use the same input — the previous layer’s node embeddings.
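The core KTN operation, multiplying target embeddings by a trainable mapping matrix and training that matrix with a regularizer, can be sketched with a toy least-squares version. Here we simply assume paired target and source embeddings and a mean-squared-error regularizer fitted by gradient descent. The actual KTN regularizer is derived from the HGNN structure (Theorem 1 in the paper) and is trained jointly with the task loss, which this sketch does not reproduce. All names are hypothetical.

```python
def matmul(a, b):
    """Multiply matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def fit_mapping(h_target, h_source, lr=0.1, steps=500):
    """Fit T so that h_target @ T approximates h_source under mean squared error."""
    d_in, d_out = len(h_target[0]), len(h_source[0])
    t = [[0.0] * d_out for _ in range(d_in)]
    n = len(h_target)
    for _ in range(steps):
        pred = matmul(h_target, t)
        resid = [[p - s for p, s in zip(pr, sr)] for pr, sr in zip(pred, h_source)]
        # Gradient of mean ||h T - s||^2 with respect to T is (2 / n) * h^T (h T - s).
        for i in range(d_in):
            for j in range(d_out):
                grad = 2.0 / n * sum(h_target[k][i] * resid[k][j] for k in range(n))
                t[i][j] -= lr * grad
    return t
```

At test time, the mapped embeddings `matmul(new_target_embeddings, t)` would be passed to the classifier that was pre-trained on the source node type.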

Experimental results

To examine the effectiveness of KTNs, we ran 18 different zero-shot transfer learning tasks on two public heterogeneous graphs, Open Academic Graph and Pubmed. We compare KTN with eight state-of-the-art transfer learning methods (DAN, JAN, DANN, CDAN, CDAN-E, WDGRL, LP, EP). As shown below, KTN consistently outperforms all baselines on all tasks, beating transfer learning baselines by up to 140% (as measured by Normalized Discounted Cumulative Gain, a ranking metric).

Zero-shot transfer learning on Open Academic Graph (OAG-CS) and Pubmed datasets. The colors represent different categories of transfer learning baselines against which the results are compared. Yellow: Use statistical properties (e.g., mean, variance) of distributions. Green: Use adversarial models to transfer knowledge. Orange: Transfer knowledge directly via graph structure using label propagation.
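For reference, Normalized Discounted Cumulative Gain compares the discounted gain of a predicted ranking with that of the ideal ranking. The sketch below uses linear gains and a log2(rank + 1) discount, which is one standard variant; the exact variant and any rank cutoff used in the paper are not stated here and may differ.

```python
import math


def dcg(relevances):
    """Discounted cumulative gain: sum of rel_i / log2(i + 1) for 1-based rank i."""
    return sum(rel / math.log2(rank + 1) for rank, rel in enumerate(relevances, start=1))


def ndcg(relevances_in_predicted_order):
    """DCG of the predicted order divided by DCG of the ideal (sorted) order."""
    ideal = dcg(sorted(relevances_in_predicted_order, reverse=True))
    return dcg(relevances_in_predicted_order) / ideal if ideal > 0 else 0.0
```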

Most importantly, KTN can be applied to almost all HGNN models that have node and edge type-specific parameters and improve their zero-shot performance on target domains. As shown below, KTN improves accuracy on zero-labeled node types across six different HGNN models (R-GCN, HAN, HGT, MAGNN, MPNN, H-MPNN) by up to 190%.

KTN can be applied to six different HGNN models and improve their zero-shot performance on target domains.


Various ecosystems in industry can be represented as heterogeneous graphs. HGNNs summarize heterogeneous graph information into effective representations. However, label scarcity issues on certain types of nodes prevent the wider application of HGNNs. In this post, we introduced KTN, the first cross-type transfer learning method designed for HGNNs. With KTN, we can fully exploit the richness of heterogeneous graphs via HGNNs regardless of label scarcity. See the paper for more details.


This paper is joint work with our co-authors John Palowitch (Google Research), Dustin Zelle (Google Research), Ziniu Hu (Intern, Google Research), and Russ Salakhutdinov (CMU). We thank Tom Small for creating the animated figure in this blog post.

Source: Google AI Blog

Accelerating Text Generation with Confident Adaptive Language Modeling (CALM)

Language models (LMs) are the driving force behind many recent breakthroughs in natural language processing. Models like T5, LaMDA, GPT-3, and PaLM have demonstrated impressive performance on various language tasks. While multiple factors can contribute to improving the performance of LMs, some recent studies suggest that scaling up the model’s size is crucial for revealing emergent capabilities. In other words, some instances can be solved by small models, while others seem to benefit from increased scale.

Despite recent efforts that enabled the efficient training of LMs over large amounts of data, trained models can still be slow and costly for practical use. When generating text at inference time, most autoregressive LMs output content similar to how we speak and write (word after word), predicting each new word based on the preceding words. This process cannot be parallelized since LMs need to complete the prediction of one word before starting to compute the next one. Moreover, predicting each word requires significant computation given the model’s billions of parameters.

In “Confident Adaptive Language Modeling”, presented at NeurIPS 2022, we introduce a new method for accelerating the text generation of LMs by improving efficiency at inference time. Our method, named CALM, is motivated by the intuition that some next-word predictions are easier than others. When writing a sentence, some continuations are trivial, while others might require more effort. Current LMs devote the same amount of compute power to all predictions. Instead, CALM dynamically distributes the computational effort across generation timesteps. By selectively allocating more computational resources only to harder predictions, CALM generates text faster while preserving output quality.

Confident Adaptive Language Modeling

When possible, CALM skips some compute effort for certain predictions. To demonstrate this, we use the popular encoder-decoder T5 architecture. The encoder reads the input text (e.g., a news article to summarize) and converts the text to dense representations. Then, the decoder outputs the summary by predicting it word by word. Both the encoder and decoder include a long sequence of Transformer layers. Each layer includes attention and feedforward modules with many matrix multiplications. These layers gradually modify the hidden representation that is ultimately used for predicting the next word.

Instead of waiting for all decoder layers to complete, CALM attempts to predict the next word earlier, after some intermediate layer. To decide whether to commit to a certain prediction or to postpone the prediction to a later layer, we measure the model’s confidence in its intermediate prediction. The rest of the computation is skipped only when the model is confident enough that the prediction won’t change. For quantifying what is “confident enough”, we calibrate a threshold that statistically satisfies arbitrary quality guarantees over the full output sequence.

Text generation with a regular language model (top) and with CALM (bottom). CALM attempts to make early predictions. Once confident enough (darker blue tones), it skips ahead and saves time.
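The per-token early-exit decision can be sketched as a loop over decoder layers that stops as soon as the intermediate prediction is confident enough. This toy assumes softmax response as the confidence measure, a single output head shared by all layers, and plain callables standing in for Transformer layers. The names are hypothetical, and the handling of skipped layers' hidden states is omitted.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def early_exit_step(hidden, layers, head, threshold):
    """Run decoder layers for one token, stopping once the prediction is confident.

    layers: non-empty list of callables hidden -> hidden (stand-ins for Transformer layers)
    head: callable hidden -> logits (output head shared by all layers)
    Returns (predicted token id, number of layers used).
    """
    for depth, layer in enumerate(layers, start=1):
        hidden = layer(hidden)
        probs = softmax(head(hidden))
        confidence = max(probs)
        # Exit when confident enough, or when the last layer has been reached.
        if confidence >= threshold or depth == len(layers):
            return probs.index(confidence), depth
```

A lower threshold exits earlier and saves more compute; choosing it safely is the calibration problem discussed below.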

Language Models with Early Exits

Enabling this early exit strategy for LMs requires minimal modifications to the training and inference processes. During training, we encourage the model to produce meaningful representations in intermediate layers. Instead of predicting only using the top layer, our learning loss function is a weighted average over the predictions of all layers, assigning higher weight to top layers. Our experiments demonstrate that this significantly improves the intermediate layer predictions while preserving the full model’s performance. In one model variant, we also include a small early-exit classifier trained to classify whether the local intermediate layer prediction is consistent with the top layer. We train this classifier in a second quick step where we freeze the rest of the model.
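The weighted training objective can be sketched in one function. The text only says that top layers receive higher weight; here we assume weights proportional to the (1-based) layer index, normalized to sum to one. The function name is hypothetical.

```python
def early_exit_training_loss(per_layer_losses):
    """Weighted average of per-layer losses; layer i (1-based) gets weight i / sum(1..L).

    Deeper layers receive larger weights, so the top layer still dominates training.
    """
    num_layers = len(per_layer_losses)
    normalizer = num_layers * (num_layers + 1) / 2
    return sum(i * loss for i, loss in enumerate(per_layer_losses, start=1)) / normalizer
```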

Once the model is trained, we need a method to allow early-exiting. First, we define a local confidence measure for capturing the model’s confidence in its intermediate prediction. We explore three confidence measures (described in the results section below): (1) softmax response, taking the maximum predicted probability out of the softmax distribution; (2) state propagation, the cosine distance between the current hidden representation and the one from the previous layer; and (3) early-exit classifier, the output of a classifier specifically trained for predicting local consistency. We find the softmax response to be statistically strong while being simple and fast to compute. The other two alternatives require fewer floating point operations (FLOPs).
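The three confidence measures can be sketched as small functions. For state propagation, the code returns the cosine similarity (one minus the cosine distance) so that, like the other two measures, a higher value means more confidence. The early-exit classifier is shown as a hypothetical linear probe with a sigmoid; its weights would come from the second training step. All names are illustrative assumptions.

```python
import math


def softmax_response(logits):
    """Maximum predicted probability of the softmax distribution (computed stably)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    return max(exps) / sum(exps)


def state_saturation(h_current, h_previous):
    """Cosine similarity of consecutive layers' hidden states; near 1 means little change."""
    dot = sum(a * b for a, b in zip(h_current, h_previous))
    norms = math.sqrt(sum(a * a for a in h_current)) * math.sqrt(sum(b * b for b in h_previous))
    return dot / norms if norms > 0 else 0.0


def early_exit_classifier(hidden, weights, bias):
    """Hypothetical linear probe: estimated probability that the local prediction is consistent."""
    z = sum(w * h for w, h in zip(weights, hidden)) + bias
    return 1.0 / (1.0 + math.exp(-z))
```

The FLOPs difference is visible here: softmax response needs logits over the whole vocabulary, while the other two measures only touch the hidden vector.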

Another challenge is that the self-attention of each layer depends on hidden states from previous words. If we exit early for some word predictions, these hidden states might be missing. In that case, we attend back to the hidden state of the last computed layer for that word.

Finally, we set up the local confidence threshold for exiting early. In the next section, we describe our controlled process for finding good threshold values. As a first step, we simplify this infinite search space by building on a useful observation: mistakes that are made at the beginning of the generation process are more detrimental since they can affect all of the following outputs. Therefore, we start with a higher (more conservative) threshold, and gradually reduce it with time. We use a negative exponent with user-defined temperature to control this decay rate. We find this allows better control over the performance-efficiency tradeoff (the obtained speedup per quality level).
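A decaying threshold schedule can be sketched as follows. We assume a simple form, floor + (start - floor) * exp(-temperature * t / N), that starts conservative and relaxes over the N generated tokens; the exact formula used in the paper may differ, and the function and parameter names are hypothetical.

```python
import math


def decayed_threshold(t, num_tokens, start, floor, temperature):
    """Confidence threshold for timestep t (0-based): starts at `start`, decays toward `floor`."""
    decay = math.exp(-temperature * t / num_tokens)
    return floor + (start - floor) * decay
```

A larger `temperature` relaxes the threshold faster, trading quality in later tokens for more speedup; `temperature=0` recovers a constant threshold.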

Reliably Controlling the Quality of the Accelerated Model

Early exit decisions have to be local; they need to happen when predicting each word. In practice, however, the final output should be globally consistent or comparable to the original model. For example, if the original full model generated “the concert was wonderful and long”, one would accept CALM switching the order of the adjectives and outputting “the concert was long and wonderful”. However, at the local level, the word “wonderful” was replaced with “long”. Therefore, the two outputs are globally consistent, but include some local inconsistencies. We build on the Learn then Test (LTT) framework to connect local confidence-based decisions to globally consistent outputs.

In CALM, local per-timestep confidence thresholds for early exiting decisions are derived, via LTT calibration, from user-defined consistency constraints over the full output text. Red boxes indicate that CALM used most of the decoder’s layers for that specific prediction. Green boxes indicate that CALM saved time by using only a few Transformer layers. Full sentence shown in the last example of this post.

First, we define and formulate two types of consistency constraints from which to choose:

  1. Textual consistency: We bound the expected textual distance between the outputs of CALM and the outputs of the full model. This doesn’t require any labeled data.
  2. Risk consistency: We bound the expected increase in loss that we allow for CALM compared to the full model. This requires reference outputs against which to compare.

For each of these constraints, we can set the tolerance that we allow and calibrate the confidence threshold to allow early exits while reliably satisfying our defined constraint with an arbitrarily high probability.
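The calibration step can be sketched in the spirit of Learn then Test, using fixed-sequence testing and a Hoeffding p-value. We assume per-example losses bounded in [0, 1] (for example, a normalized textual distance between CALM and the full model). Candidate thresholds are tested from most to least conservative, and testing stops at the first one whose risk cannot be certified at level `epsilon`. The p-value and search procedure used in the paper may differ, and the names are hypothetical.

```python
import math


def hoeffding_p_value(empirical_risk, n, tolerance):
    """p-value for H0: true risk > tolerance, for losses bounded in [0, 1]."""
    gap = tolerance - empirical_risk
    if gap <= 0:
        return 1.0
    return math.exp(-2.0 * n * gap * gap)


def calibrate_threshold(candidate_thresholds, risk_fn, n, tolerance, epsilon):
    """Fixed-sequence testing from the most conservative threshold downward.

    risk_fn(threshold) -> mean loss in [0, 1] over n calibration examples.
    Returns the least conservative certified threshold, or None if none is certified.
    """
    chosen = None
    for threshold in sorted(candidate_thresholds, reverse=True):
        if hoeffding_p_value(risk_fn(threshold), n, tolerance) > epsilon:
            break
        chosen = threshold
    return chosen
```

Under the Hoeffding assumptions, with probability at least 1 - epsilon over the calibration sample, the returned threshold's true risk is at most `tolerance`.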

CALM Saves Inference Time

We run experiments on three popular generation datasets: CNN/DM for summarization, WMT for machine translation, and SQuAD for question answering. We evaluate each of the three confidence measures (softmax response, state propagation and early-exit classifier) using an 8-layer encoder-decoder model. To evaluate global sequence-level performance, we use the standard Rouge-L, BLEU, and Token-F1 scores that measure distances against human-written references. We show that one can maintain full model performance while using only a third or half of the layers on average. CALM achieves this by dynamically distributing the compute effort across the prediction timesteps.

As an approximate upper bound, we also compute the predictions using a local oracle confidence measure, which enables exiting at the first layer that leads to the same prediction as the top one. On all three tasks, the oracle measure can preserve full model performance when using only 1.5 decoder layers on average. In contrast to CALM, a static baseline uses the same number of layers for all predictions, requiring 3 to 7 layers (depending on the dataset) to preserve its performance. This demonstrates why the dynamic allocation of compute effort is important. Only a small fraction of the predictions require most of the model’s complexity, while for others much less should suffice.
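The oracle measure can be computed directly from per-layer predictions: a token's oracle exit depth is the first layer whose prediction already matches the top layer's. The toy predictions in the test are made up for illustration and are not from the experiments; the function names are hypothetical.

```python
def oracle_exit_layer(per_layer_predictions):
    """First (1-based) layer whose prediction equals the top layer's prediction."""
    final = per_layer_predictions[-1]
    for depth, pred in enumerate(per_layer_predictions, start=1):
        if pred == final:
            return depth


def average_oracle_layers(sequence_predictions):
    """Average oracle exit depth over the tokens of a generated sequence."""
    depths = [oracle_exit_layer(p) for p in sequence_predictions]
    return sum(depths) / len(depths)
```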

Performance per task against the average number of decoder layers used.

Finally, we also find that CALM enables practical speedups. When benchmarking on TPUs, we saved almost half of the compute time while maintaining the quality of the outputs.

Example of a generated news summary. The top cell presents the reference human-written summary. Below is the prediction of the full model (8 layers) followed by two different CALM output examples. The first CALM output is 2.9x faster and the second output is 3.6x faster than the full model, benchmarked on TPUs.


CALM allows faster text generation with LMs, without reducing the quality of the output text. This is achieved by dynamically modifying the amount of compute per generation timestep, allowing the model to exit the computational sequence early when confident enough.

As language models continue to grow in size, studying how to efficiently use them becomes crucial. CALM is orthogonal to, and can be combined with, many efficiency-related efforts, including model quantization, distillation, sparsity, effective partitioning, and distributed control flows.


It was an honor and privilege to work on this with Adam Fisch, Ionel Gog, Seungyeon Kim, Jai Gupta, Mostafa Dehghani, Dara Bahri, Vinh Q. Tran, Yi Tay, and Donald Metzler. We also thank Anselm Levskaya, Hyung Won Chung, Tao Wang, Paul Barham, Michael Isard, Orhan Firat, Carlos Riquelme, Aditya Menon, Zhifeng Chen, Sanjiv Kumar, and Jeff Dean for helpful discussions and feedback. Finally, we thank Tom Small for preparing the animation in this blog post.

Source: Google AI Blog

Enhancing Backpropagation via Local Loss Optimization

While model design and training data are key ingredients in a deep neural network’s (DNN’s) success, less-often discussed is the specific optimization method used for updating the model parameters (weights). Training DNNs involves minimizing a loss function that measures the discrepancy between the ground truth labels and the model’s predictions. Training is carried out by backpropagation, which adjusts the model weights via gradient descent steps. Gradient descent, in turn, updates the weights by using the gradient (i.e., derivative) of the loss with respect to the weights.

The simplest weight update corresponds to stochastic gradient descent, which, in every step, moves the weights in the direction of the negative gradient (with an appropriate step size, a.k.a. the learning rate). More advanced optimization methods modify the direction of the negative gradient before updating the weights by using information from the past steps and/or the local properties (such as the curvature information) of the loss function around the current weights. For instance, a momentum optimizer encourages moving along the average direction of past updates, and the AdaGrad optimizer scales each coordinate based on the past gradients. These optimizers are commonly known as first-order methods since they generally modify the update direction using only information from the first-order derivative (i.e., gradient). More importantly, the components of the weight parameters are treated independently from each other.
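The three first-order update rules mentioned above can be written in a few lines each. Every rule acts on each coordinate separately, which is exactly the independence between weight components that the text points out. The function names and defaults such as `beta=0.9` are common choices assumed for illustration.

```python
import math


def sgd_step(w, g, lr):
    """Plain SGD: step against the gradient."""
    return [wi - lr * gi for wi, gi in zip(w, g)]


def momentum_step(w, g, velocity, lr, beta=0.9):
    """Heavy-ball momentum: follow a running, exponentially weighted sum of past gradients."""
    velocity = [beta * v + gi for v, gi in zip(velocity, g)]
    return [wi - lr * v for wi, v in zip(w, velocity)], velocity


def adagrad_step(w, g, accum, lr, eps=1e-8):
    """AdaGrad: scale each coordinate by the root of its accumulated squared gradients."""
    accum = [a + gi * gi for a, gi in zip(accum, g)]
    return [wi - lr * gi / (math.sqrt(a) + eps) for wi, gi, a in zip(w, g, accum)], accum
```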

More advanced optimization methods, such as Shampoo and K-FAC, capture the correlations between gradients of parameters and have been shown to improve convergence, reducing the number of iterations and improving the quality of the solution. These methods capture information about the local changes of the derivatives of the loss, i.e., changes in gradients. Using this additional information, higher-order optimizers can discover much more efficient update directions for training models by taking into account the correlations between different groups of parameters. On the downside, calculating higher-order update directions is computationally more expensive than first-order updates. The operation uses more memory for storing statistics and involves matrix inversion, thus hindering the applicability of higher-order optimizers in practice.

In “LocoProp: Enhancing BackProp via Local Loss Optimization”, we introduce a new framework for training DNN models. Our new framework, LocoProp, conceives neural networks as a modular composition of layers. Generally, each layer in a neural network applies a linear transformation on its inputs, followed by a non-linear activation function. In the new construction, each layer is allotted its own weight regularizer, output target, and loss function. The loss function of each layer is designed to match the activation function of the layer. Using this formulation, training minimizes the local losses for a given mini-batch of examples, iteratively and in parallel across layers. Our method performs multiple local updates per batch of examples using a first-order optimizer (like RMSProp), which avoids computationally expensive operations such as the matrix inversions required for higher-order optimizers. However, we show that the combined local updates closely resemble a higher-order update. Empirically, we show that LocoProp outperforms first-order methods on a deep autoencoder benchmark and performs comparably to higher-order optimizers, such as Shampoo and K-FAC, without the high memory and computation requirements.

Neural networks are generally viewed as composite functions that transform model inputs into output representations, layer by layer. LocoProp adopts this view while decomposing the network into layers. In particular, instead of updating the weights of the layer to minimize the loss function at the output, LocoProp applies pre-defined local loss functions specific to each layer. For a given layer, the loss function is selected to match the activation function, e.g., a tanh loss would be selected for a layer with a tanh activation. Each layerwise loss measures the discrepancy between the layer's output (for a given mini-batch of examples) and a notion of a target output for that layer. Additionally, a regularizer term ensures that the updated weights do not drift too far from the current values. The combined layerwise loss function (with a local target) plus regularizer is used as the new objective function for each layer.

Similar to backpropagation, LocoProp applies a forward pass to compute the activations. In the backward pass, LocoProp sets per neuron "targets" for each layer. Finally, LocoProp splits model training into independent problems across layers where several local updates can be applied to each layer's weights in parallel.

Perhaps the simplest loss function one can think of for a layer is the squared loss. While the squared loss is a valid choice of a loss function, LocoProp takes into account the possible non-linearity of the activation functions of the layers and applies layerwise losses tailored to the activation function of each layer. This enables the model to emphasize regions of the input that are more important for the model prediction while de-emphasizing the regions that do not affect the output as much. Below we show examples of tailored losses for the tanh and ReLU activation functions.

Loss functions induced by the (left) tanh and (right) ReLU activation functions. Each loss is more sensitive to the regions affecting the output prediction. For instance, ReLU loss is zero as long as both the prediction (â) and the target (a) are negative. This is because the ReLU function applied to any negative number equals zero.
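The caption's ReLU observation can be checked with a small sketch. One standard way to build a loss matched to an activation f is the Bregman ("matching") loss F(â) - F(a) - f(a)(â - a), where F is a convex potential whose derivative is f: log cosh for tanh, and ½·max(z, 0)² for ReLU. We assume this form here; the function names are illustrative and not taken from the LocoProp code.

```python
import math


def log_cosh(z: float) -> float:
    """Numerically stable log(cosh(z)); its derivative is tanh(z)."""
    z = abs(z)
    return z + math.log1p(math.exp(-2.0 * z)) - math.log(2.0)


def half_sq_relu(z: float) -> float:
    """0.5 * max(z, 0)^2; its derivative is relu(z) = max(z, 0)."""
    return 0.5 * max(z, 0.0) ** 2


def matching_loss(F, f, pred: float, target: float) -> float:
    """Bregman divergence of the convex potential F (with F' = f) between a
    pre-activation prediction and a pre-activation target."""
    return F(pred) - F(target) - f(target) * (pred - target)


def tanh_loss(pred: float, target: float) -> float:
    return matching_loss(log_cosh, math.tanh, pred, target)


def relu_loss(pred: float, target: float) -> float:
    return matching_loss(half_sq_relu, lambda z: max(z, 0.0), pred, target)


print(relu_loss(-3.0, -0.5))  # 0.0: both negative, so both ReLU outputs are 0
print(tanh_loss(5.0, 6.0) < tanh_loss(0.0, 1.0))  # True: saturated region is cheap
```

Because both ReLU outputs are zero for negative inputs, the loss ignores disagreements there; the tanh loss similarly shrinks in the saturated region, where changing the input barely changes the output.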

After forming the objective in each layer, LocoProp updates the layer weights by repeatedly applying gradient descent steps on its objective. The update typically uses a first-order optimizer (like RMSProp). However, we show that the overall behavior of the combined updates closely resembles higher-order updates (shown below). Thus, LocoProp provides training performance close to what higher-order optimizers achieve without the high memory or computation needed for higher-order methods, such as matrix inverse operations. We show that LocoProp is a flexible framework that allows the recovery of well-known algorithms and enables the construction of new algorithms via different choices of losses, targets, and regularizers. LocoProp’s layerwise view of neural networks also allows updating the weights in parallel across layers.
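A minimal NumPy sketch of the local step for one tanh layer follows. It assumes a tanh matching (Bregman) loss with potential log cosh, a proximal regularizer ||W - W_current||²/(2η), and hypothetical target pre-activations `A_target` that would come from LocoProp's backward pass. It runs a few RMSProp steps on the layer's own objective. All names and hyperparameters are illustrative.

```python
import numpy as np


def log_cosh(z):
    # Stable elementwise log(cosh(z)); its derivative is tanh(z).
    return np.logaddexp(z, -z) - np.log(2.0)


def local_objective(W, X, A_target, W_current, eta):
    """Tanh matching loss between pre-activations X @ W and target
    pre-activations A_target, plus a proximal term keeping W near W_current."""
    Z = X @ W
    match = log_cosh(Z) - log_cosh(A_target) - np.tanh(A_target) * (Z - A_target)
    return match.sum() / X.shape[0] + np.sum((W - W_current) ** 2) / (2.0 * eta)


def local_grad(W, X, A_target, W_current, eta):
    # d(match)/dZ = tanh(Z) - tanh(A_target): the mismatch in output space.
    G = np.tanh(X @ W) - np.tanh(A_target)
    return X.T @ G / X.shape[0] + (W - W_current) / eta


def local_update(W_current, X, A_target, eta=1.0, lr=0.01, steps=20,
                 beta=0.9, eps=1e-8):
    """Several RMSProp steps on this layer's local objective only."""
    W = W_current.copy()
    v = np.zeros_like(W)
    for _ in range(steps):
        g = local_grad(W, X, A_target, W_current, eta)
        v = beta * v + (1.0 - beta) * g ** 2
        W = W - lr * g / (np.sqrt(v) + eps)
    return W


# Hypothetical layer: 32 examples, 5 inputs, 3 tanh units.
rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))
W0 = 0.1 * rng.normal(size=(5, 3))
A_target = X @ (W0 + 0.3 * rng.normal(size=(5, 3)))  # stand-in targets
W1 = local_update(W0, X, A_target)
print(local_objective(W0, X, A_target, W0, 1.0), local_objective(W1, X, A_target, W0, 1.0))
```

Each layer's objective depends only on that layer's inputs, current weights, and target. Once the targets are set, the calls for different layers are therefore independent and could run in parallel.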

In our paper, we describe experiments on the deep autoencoder model, which is a commonly used baseline for evaluating the performance of optimization algorithms. We perform extensive tuning on multiple commonly used first-order optimizers, including SGD, SGD with momentum, AdaGrad, RMSProp, and Adam, as well as the higher-order Shampoo and K-FAC optimizers, and compare the results with LocoProp. Our findings indicate that LocoProp performs significantly better than the first-order optimizers and comparably to the higher-order ones, while being significantly faster when run on a single GPU.

Train loss vs. number of epochs (left) and wall-clock time, i.e., the real time that passes during training, (right) for RMSProp, Shampoo, K-FAC, and LocoProp on the deep autoencoder model.

Summary and Future Directions
We introduced a new framework, called LocoProp, for optimizing deep neural networks more efficiently. LocoProp decomposes neural networks into separate layers with their own regularizer, output target, and loss function and applies local updates in parallel to minimize the local objectives. While using first-order updates for the local optimization problems, the combined updates closely resemble higher-order update directions, both theoretically and empirically.

LocoProp provides flexibility to choose the layerwise regularizers, targets, and loss functions. Thus, it allows the development of new update rules based on these choices. Our code for LocoProp is available online on GitHub. We are currently working on scaling up ideas induced by LocoProp to much larger scale models; stay tuned!

We would like to thank our co-author, Manfred K. Warmuth, for his critical contributions and inspiring vision. We would like to thank Sameer Agarwal for discussions looking at this work from a composite functions perspective, Vineet Gupta for discussions and development of Shampoo, Zachary Nado on K-FAC, Tom Small for development of the animation used in this blogpost and finally, Yonghui Wu and Zoubin Ghahramani for providing us with a nurturing research environment in the Google Brain Team.

Source: Google AI Blog

Using Deep Learning to Annotate the Protein Universe

Proteins are essential molecules found in all living things. They play a central role in our bodies’ structure and function, and they are also featured in many products that we encounter every day, from medications to household items like laundry detergent. Each protein is a chain of amino acid building blocks. Just as an image may include multiple objects, like a dog and a cat, a protein may also have multiple components, which are called protein domains. Understanding the relationship between a protein’s amino acid sequence (for example, its domains) and its structure or function is a long-standing challenge with far-reaching scientific implications.

An example of a protein with known structure, TrpCF from E. coli, for which areas used by a model to predict function are highlighted (green). This protein produces tryptophan, which is an essential part of a person’s diet.

Many are familiar with recent advances in computationally predicting protein structure from amino acid sequences, as seen with DeepMind’s AlphaFold. Similarly, the scientific community has a long history of using computational tools to infer protein function directly from sequences. For example, the widely-used protein family database Pfam contains numerous highly-detailed computational annotations that describe a protein domain's function, e.g., the globin and trypsin families. While existing approaches have been successful at predicting the function of hundreds of millions of proteins, there are still many more with unknown functions — for example, at least one-third of microbial proteins are not reliably annotated. As the volume and diversity of protein sequences in public databases continue to increase rapidly, the challenge of accurately predicting function for highly divergent sequences becomes increasingly pressing.

In “Using Deep Learning to Annotate the Protein Universe”, published in Nature Biotechnology, we describe a machine learning (ML) technique to reliably predict the function of proteins. This approach, which we call ProtENN, has enabled us to add about 6.8 million entries to Pfam’s well-known and trusted set of protein function annotations. That is roughly equivalent to all of the progress made over the last decade, and we are releasing these entries as Pfam-N. To encourage further research in this direction, we are releasing the ProtENN model and an interactive, Distill-style article where researchers can experiment with our techniques. This interactive tool allows the user to enter a sequence and get results for a predicted protein function in real time, in the browser, with no setup required. In this post, we’ll give an overview of this achievement and how we’re making progress toward revealing more of the protein universe.

The Pfam database is a large collection of protein families and their sequences. Our ML model ProtENN helped annotate 6.8 million more protein regions in the database.

Protein Function Prediction as a Classification Problem
In computer vision, it’s common to first train a model for image classification tasks, like CIFAR-100, before extending it to more specialized tasks, like object detection and localization. Similarly, we develop a protein domain classification model as a first step towards future models for classification of entire protein sequences. We frame the problem as a multi-class classification task in which we predict a single label out of 17,929 classes — all classes contained in the Pfam database — given a protein domain’s sequence of amino acids.

Models that Link Sequence to Function
While there are a number of models currently available for protein domain classification, one drawback of the current state-of-the-art methods is that they are based on the alignment of linear sequences and don’t consider interactions between amino acids in different parts of protein sequences. But proteins don’t just stay as a line of amino acids; they fold in on themselves, so that nonadjacent amino acids have strong effects on each other.

Aligning a new query sequence to one or more sequences with known function is a key step of current state-of-the-art methods. This reliance on sequences with known function makes it challenging to predict a new sequence’s function if it is highly dissimilar to any sequence with known function. Furthermore, alignment-based methods are computationally intensive, and applying them to large datasets, such as the metagenomic database MGnify, which contains >1 billion protein sequences, can be cost prohibitive.

To address these challenges, we propose to use dilated convolutional neural networks (CNNs), which should be well-suited to modeling non-local pairwise amino-acid interactions and can be run on modern ML hardware like GPUs. We train 1-dimensional CNNs, which we call ProtCNN, to predict the classification of protein sequences. We also train an ensemble of independently trained ProtCNN models, which we call ProtENN. Our goal for using this approach is to add knowledge to the scientific literature by developing a reliable ML approach that complements traditional alignment-based methods. To demonstrate this, we developed an evaluation method that accurately measures our approach's accuracy.
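Stacking dilated convolutions lets a CNN relate distant amino acids cheaply. When the dilation doubles at each layer, the receptive field grows exponentially with depth, while the parameter count grows only linearly. Below is a minimal NumPy sketch of a "same"-padded 1D dilated convolution and the resulting receptive field. The kernel size and dilation schedule are hypothetical and are not ProtCNN's actual configuration.

```python
import numpy as np


def dilated_conv1d(x, w, dilation):
    """'Same'-padded 1D dilated cross-correlation.
    x: (length, in_channels); w: (kernel_size, in_channels, out_channels)."""
    k = w.shape[0]
    span = (k - 1) * dilation          # distance covered by one filter
    left = span // 2
    xp = np.pad(x, ((left, span - left), (0, 0)))
    out = np.zeros((x.shape[0], w.shape[2]))
    for i in range(k):
        # Tap i reads the input `i * dilation` positions further along.
        out += xp[i * dilation : i * dilation + x.shape[0]] @ w[i]
    return out


def receptive_field(kernel_size, dilations):
    """Input positions that can influence one output after stacking layers."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)


# Hypothetical stack: kernel size 3, dilations doubling per layer.
print(receptive_field(3, [1, 2, 4, 8, 16]))  # 63
print(receptive_field(3, [1, 1, 1, 1, 1]))   # 11 without dilation
```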

Evaluation with Evolution in Mind
Similar to well-known classification problems in other fields, the challenge in protein function prediction is less in developing a completely new model for the task, and more in creating fair training and test sets to ensure that the models will make accurate predictions for unseen data. Because proteins have evolved from shared common ancestors, different proteins often share a substantial fraction of their amino acid sequence. Without proper care, the test set could be dominated by samples that are highly similar to the training data, which could lead to the models performing well by simply “memorizing” the training data, rather than learning to generalize more broadly from it.

We create a test set that requires ProtENN to generalize well on data far from its training set.

To guard against this, it is essential to evaluate model performance using multiple separate setups. For each evaluation, we stratify model accuracy as a function of similarity between each held-out test sequence and the nearest sequence in the train set.

The first evaluation uses a clustered split of the training and test sets, consistent with prior literature. Here, protein sequence samples are clustered by sequence similarity, and entire clusters are placed into either the train or test sets. As a result, every test example is at least 75% different from every training example. Strong performance on this task demonstrates that a model can generalize to make accurate predictions for out-of-distribution data.
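Here is a rough sketch of such a split. It links any two sequences whose identity exceeds 25% (that is, sequences less than 75% different), treats connected components as clusters, and assigns each whole cluster to one side. The `identity` function is a crude position-by-position stand-in for the alignment-based identity a real pipeline would use. The 20% test fraction and the example sequences are arbitrary.

```python
import random


def identity(a: str, b: str) -> float:
    """Crude stand-in for alignment-based identity: the fraction of matching
    positions over the shorter sequence, with no alignment performed."""
    n = min(len(a), len(b))
    return sum(x == y for x, y in zip(a, b)) / n if n else 0.0


def clustered_split(seqs, max_identity=0.25, test_frac=0.2, seed=0):
    """Link sequences more than `max_identity` identical (single linkage),
    then place each whole cluster in either the train or the test set."""
    parent = list(range(len(seqs)))

    def find(i):
        # Union-find root lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i in range(len(seqs)):
        for j in range(i + 1, len(seqs)):
            if identity(seqs[i], seqs[j]) > max_identity:
                parent[find(i)] = find(j)

    clusters = {}
    for i in range(len(seqs)):
        clusters.setdefault(find(i), []).append(i)
    groups = list(clusters.values())
    random.Random(seed).shuffle(groups)

    train, test = [], []
    for group in groups:
        # Fill the test set with whole clusters until it reaches test_frac.
        (test if len(test) < test_frac * len(seqs) else train).extend(group)
    return train, test


seqs = ["MKTAYIAKQR", "MKTAYIAKQL", "GGSWWPLLNE", "GGSWWPLLNQ", "HHDECAFYRV"]
print(clustered_split(seqs, test_frac=0.4))
```

Linked sequences always end up in the same cluster. As a result, no test sequence can exceed the identity threshold with any training sequence.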

For the second evaluation, we use a randomly split training and test set, where we stratify examples based on an estimate of how difficult they will be to classify. These measures of difficulty include: (1) the similarity between a test example and the nearest training example, and (2) the number of training examples from the true class (it is much more difficult to accurately predict function given just a handful of training examples).
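Stratified reporting can be sketched as binning each test example by a difficulty measure and computing accuracy per bin. In the sketch, the measure is the identity to the nearest training example. The bin edges and data are hypothetical. The same function covers the second measure if you pass the number of training examples in the true family, with suitable edges.

```python
def stratified_accuracy(difficulty, correct, edges):
    """Accuracy per bin [edges[i], edges[i+1]) of a per-example difficulty
    measure; the last bin also includes its upper edge."""
    hits, totals = {}, {}
    for value, ok in zip(difficulty, correct):
        for lo, hi in zip(edges[:-1], edges[1:]):
            if lo <= value < hi or (hi == edges[-1] and value == hi):
                totals[(lo, hi)] = totals.get((lo, hi), 0) + 1
                hits[(lo, hi)] = hits.get((lo, hi), 0) + int(ok)
                break
    return {b: hits[b] / totals[b] for b in totals}


# Hypothetical: identity of each test example to its nearest training example.
identity_to_train = [0.1, 0.2, 0.6, 1.0]
was_correct = [False, True, True, True]
print(stratified_accuracy(identity_to_train, was_correct, [0.0, 0.3, 0.5, 0.7, 1.0]))
```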

To place our work in context, we evaluate the most widely used baseline models under these evaluation setups, in particular: (1) BLAST, a nearest-neighbor method that uses sequence alignment to measure distance and infer function, and (2) profile hidden Markov models (TPHMM and phmmer). For each of these, we include the stratification of model performance based on sequence alignment similarity mentioned above. We compared these baselines against ProtCNN and the ensemble of CNNs, ProtENN.

We measure each model’s ability to generalize, from the hardest examples (left) to the easiest (right).

Reproducible and Interpretable Results
We also worked with the Pfam team to test whether our methodological proof of concept could be used to label real-world sequences. We demonstrated that ProtENN learns complementary information to alignment-based methods, and created an ensemble of the two approaches to label more sequences than either method could by itself. We publicly released the results of this effort, Pfam-N, a set of 6.8 million new protein sequence annotations.

After seeing the success of these methods on classification tasks, we inspected these networks to understand whether the embeddings were generally useful. We built a tool that enables users to explore the relation between the model predictions, embeddings, and input sequences, which we have made available through our interactive manuscript, and we found that similar sequences were clustered together in embedding space. Furthermore, the network architecture that we selected, a dilated CNN, allows us to employ previously-discovered interpretability methods like class activation mapping (CAM) and sufficient input subsets (SIS) to identify the sub-sequences responsible for the neural network predictions. With this approach, we find that our network generally focuses on the relevant elements of a sequence to predict its function.
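Class activation mapping is simple to state for a 1D CNN whose last convolutional features are globally average-pooled and fed to a linear classifier. The per-position score for class c is the feature vector at that position dotted with class c's weights. Averaging these scores over positions recovers the class logit (minus any bias). Below is a minimal NumPy sketch under that assumed architecture; the shapes and names are illustrative.

```python
import numpy as np


def class_activation_map_1d(features, class_weights, class_index):
    """features: (length, channels) activations of the last conv layer.
    class_weights: (channels, num_classes) weights of the linear head that
    follows global average pooling.
    Returns a (length,) relevance score per sequence position."""
    return features @ class_weights[:, class_index]


def top_positions(cam, k=5):
    """Indices of the k positions contributing most to the class score."""
    return np.argsort(cam)[::-1][:k]


rng = np.random.default_rng(0)
feats = rng.random((50, 8))          # hypothetical per-residue features
W = rng.normal(size=(8, 4))          # hypothetical 4-class head
cam = class_activation_map_1d(feats, W, class_index=2)
print(top_positions(cam, k=3))
```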

Conclusion and Future Work
We’re excited about the progress we’ve seen by applying ML to the understanding of protein structure and function over the last few years, which has been reflected in contributions from the broader research community, from AlphaFold and CAFA to the multitude of workshops and research presentations devoted to this topic at conferences. As we look to build on this work, we think that continuing to collaborate with scientists across the field who’ve shared their expertise and data, combined with advances in ML will help us further reveal the protein universe.

We’d like to thank all of the co-authors of the manuscripts, Maysam Moussalem, Jamie Smith, Eli Bixby, Babak Alipanahi, Shanqing Cai, Cory McLean, Abhinay Ramparasad, Steven Kearnes, Zack Nado, and Tom Small.

Source: Google AI Blog

Machine Learning for Mechanical Ventilation Control

Mechanical ventilators provide critical support for patients who have difficulty breathing or are unable to breathe on their own. They see frequent use in scenarios ranging from routine anesthesia to neonatal intensive care and life support during the COVID-19 pandemic. A typical ventilator consists of a compressed air source, valves to control the flow of air into and out of the lungs, and a "respiratory circuit" that connects the ventilator to the patient. In some cases, a sedated patient may be connected to the ventilator via a tube inserted through the trachea to their lungs, a process called invasive ventilation.

A mechanical ventilator takes breaths for patients who are not fully capable of doing so on their own. In invasive ventilation, a controllable, compressed air source is connected to a sedated patient via tubing called a respiratory circuit.

In both invasive and non-invasive ventilation, the ventilator follows a clinician-prescribed breathing waveform based on a respiratory measurement from the patient (e.g., airway pressure, tidal volume). In order to prevent harm, this demanding task requires both robustness to differences or changes in patients' lungs and adherence to the desired waveform. Consequently, ventilators require significant attention from highly-trained clinicians in order to ensure that their performance matches the patients’ needs and that they do not cause lung damage.

Example of a clinician-prescribed breathing waveform (orange) in units of airway pressure and the actual pressure (blue), given some controller algorithm.

In “Machine Learning for Mechanical Ventilation Control”, we present exploratory research into the design of a deep learning–based algorithm to improve medical ventilator control for invasive ventilation. Using signals from an artificial lung, we design a control algorithm that measures airway pressure and computes necessary adjustments to the airflow to better and more consistently match prescribed values. Compared to other approaches, we demonstrate improved robustness and better performance while requiring less manual intervention from clinicians, which suggests that this approach could reduce the likelihood of harm to a patient’s lungs.

Current Methods
Today, ventilators are controlled with methods belonging to the PID family (i.e., Proportional, Integral, Differential), which control a system based on the history of errors between the observed and desired states. A PID controller uses three characteristics for ventilator control:

- proportion (“P”): a comparison of the measured and target pressure;
- integral (“I”): the sum of previous errors between the two;
- differential (“D”): the difference between the two most recent errors.

Variants of PID have been used since the 17th century and today form the basis of many controllers in both industrial (e.g., controlling heat or fluids) and consumer (e.g., controlling espresso pressure) applications.
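The three terms translate directly into a discrete controller. Below is a minimal sketch that acts on the error between target and measured pressure. The gains, the time step, and the first-order toy "lung" in the usage lines are hypothetical, chosen only so that the loop settles.

```python
class PID:
    """Discrete PID controller acting on the error e = target - measured."""

    def __init__(self, kp: float, ki: float, kd: float, dt: float = 1.0):
        self.kp, self.ki, self.kd, self.dt = kp, ki, kd, dt
        self.integral = 0.0
        self.prev_error = None

    def step(self, target: float, measured: float) -> float:
        error = target - measured                  # P: current error
        self.integral += error * self.dt           # I: accumulated error
        if self.prev_error is None:
            deriv = 0.0
        else:
            deriv = (error - self.prev_error) / self.dt  # D: change in error
        self.prev_error = error
        return self.kp * error + self.ki * self.integral + self.kd * deriv


# Usage on a hypothetical first-order "lung": pressure relaxes toward the control.
pid = PID(kp=1.0, ki=0.1, kd=0.0)
pressure = 0.0
for _ in range(200):
    u = pid.step(target=10.0, measured=pressure)
    pressure += 0.1 * (u - pressure)
print(round(pressure, 3))  # 10.0: the I term removes the steady-state error
```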

PID control forms a solid baseline, relying on the sharp reactivity of P control to rapidly increase lung pressure when breathing in and the stability of I control to hold the breath in before exhaling. However, operators must tune the ventilator for specific patients, often repeatedly, to balance the “ringing” of overzealous P control against the ineffectually slow rise in lung pressure of dominant I control.

Current PID methods are prone to over- and then under-shooting their target (ringing). Because patients differ in their physiology and may even change during treatment, highly-trained clinicians must constantly monitor and adjust existing methods to ensure such violent ringing as in the above example does not occur.

To more effectively balance these characteristics, we propose a neural network–based controller to create a set of control signals that are broader and more adaptable than PID-generated controls.

A Machine-Learned Ventilator Controller
While one could tune the coefficients of a PID controller (either manually or via an exhaustive grid search) through a limited number of repeated trials, it is impossible to apply such a direct approach towards a deep controller, as deep neural networks (DNNs) are often parameter-rich and require significant training data. Similarly, popular model-free approaches, such as Q-Learning or Policy Gradient, are data-intensive and therefore unsuitable for the physical system at hand. Further, these approaches don't take into account the intrinsic differentiability of the ventilator dynamical system, which is deterministic, continuous and contact-free.

We therefore adopt a model-based approach, where we first learn a DNN-based simulator of the ventilator-patient dynamical system. An advantage of learning such a simulator is that it provides a more accurate data-driven alternative to physics-based models, and can be more widely distributed for controller research.

To train a faithful simulator, we built a dataset by exploring the space of controls and the resulting pressures, while balancing against physical safety, e.g., not over-inflating a test lung and causing damage. Though PID control can exhibit ringing behavior, it performs well enough to use as a baseline for generating training data. To safely explore and to faithfully capture the behavior of the system, we use PID controllers with varied control coefficients to generate the control-pressure trajectory data for simulator training. Further, we add random deviations to the PID controllers to capture the dynamics more robustly.

We collect data for training by running mechanical ventilation tasks on a physical test lung using an open-source ventilator designed by Princeton University's People's Ventilator Project. We built a ventilator farm housing ten ventilator-lung systems on a server rack, which captures multiple airway resistance and compliance settings that span a spectrum of patient lung conditions, as required for practical applications of ventilator systems.

We use a rack-based ventilator farm (10 ventilators / artificial lungs) to collect training data for a ventilator-lung simulator. Using this simulator, we train a DNN controller that we then validate on the physical ventilator farm.

The true underlying state of the dynamical system is not available to the model directly, but rather only through observations of the airway pressure in the system. In the simulator we model the state of the system at any time as a collection of previous pressure observations and the control actions applied to the system (up to a limited lookback window). These inputs are fed into a DNN that predicts the subsequent pressure in the system. We train this simulator on the control-pressure trajectory data collected through interactions with the test lung.

The performance of the simulator is measured via the sum of deviations of the simulator’s predictions (under self-simulation) from the ground truth.
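Below is a minimal sketch of the data layout and of the self-simulation metric. It assumes a least-squares linear model as a stand-in for the DNN simulator and a short lookback window. The state at step t is the last `lookback` pressures and controls. In self-simulation, the model's own predictions, not the recorded pressures, fill the window as the rollout proceeds. The toy system in the usage lines is hypothetical.

```python
import numpy as np


def make_windows(pressures, controls, lookback):
    """Features for step t: the previous `lookback` pressures and controls.
    Label: the pressure at step t."""
    X, y = [], []
    for t in range(lookback, len(pressures)):
        X.append(np.concatenate([pressures[t - lookback:t], controls[t - lookback:t]]))
        y.append(pressures[t])
    return np.array(X), np.array(y)


def fit_linear_simulator(X, y):
    """Least-squares linear model (plus bias), standing in for the DNN."""
    A = np.c_[X, np.ones(len(X))]
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return lambda window: float(np.r_[window, 1.0] @ coef)


def self_simulate(predict, pressures, controls, lookback):
    """Replay the recorded controls, but feed the simulator its own past
    predictions instead of the measured pressures after the first window."""
    sim = list(pressures[:lookback])
    for t in range(lookback, len(pressures)):
        window = np.concatenate([sim[t - lookback:t], controls[t - lookback:t]])
        sim.append(predict(window))
    return np.array(sim)


def self_sim_error(sim, pressures):
    # Sum of absolute deviations from the ground-truth pressure trace.
    return float(np.abs(sim - pressures).sum())


# Hypothetical usage on a noiseless toy system p[t] = 0.9 p[t-1] + 0.2 u[t-1].
rng = np.random.default_rng(0)
u = rng.uniform(0.0, 1.0, 300)
p = np.zeros(300)
for t in range(1, 300):
    p[t] = 0.9 * p[t - 1] + 0.2 * u[t - 1]
X, y = make_windows(p, u, lookback=3)
sim = self_simulate(fit_linear_simulator(X, y), p, u, lookback=3)
print(self_sim_error(sim, p))  # near zero, since the toy system is exactly linear
```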

While it is infeasible to compare real dynamics with their simulated counterparts over all possible trajectories and control inputs, we measure the distance between simulation and the known safe trajectories. We introduce some random exploration around these safe trajectories for robustness.

Having learned an accurate simulator, we then use it to train a DNN-based controller completely offline. This approach allows us to rapidly apply updates during controller training. Furthermore, the differentiable nature of the simulator allows for the stable use of the direct policy gradient, where we analytically compute the gradient of the loss with respect to the DNN parameters. We find this method to be significantly more efficient than model-free approaches.
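A toy version shows why differentiability matters. If the simulator is differentiable, the gradient of the tracking loss with respect to the controller's parameters can be computed exactly by carrying derivatives through the rollout. Model-free RL must instead estimate it from sampled returns. In this sketch, several stand-ins are assumed:

- a known linear model p' = a·p + b·u replaces the learned DNN simulator;
- a two-parameter controller u = k·(target - p) + c replaces the DNN controller;
- derivatives are propagated by hand in forward mode, where a real implementation would use automatic differentiation.

The target waveform is hypothetical.

```python
def rollout_with_grads(k, c, target, a=0.9, b=0.1):
    """Roll out p' = a*p + b*u under u = k*(target - p) + c from p = 0.
    Returns the mean squared tracking loss and its exact gradient (dL/dk, dL/dc),
    obtained by propagating dp/dk and dp/dc alongside the state."""
    p, dp_dk, dp_dc = 0.0, 0.0, 0.0
    loss = g_k = g_c = 0.0
    n = len(target)
    for T in target:
        err = T - p
        # Chain rule through one simulator step (uses the old p, dp/dk, dp/dc).
        dp_dk, dp_dc = (a * dp_dk + b * (err - k * dp_dk),
                        a * dp_dc + b * (1.0 - k * dp_dc))
        p = a * p + b * (k * err + c)
        resid = p - T
        loss += resid ** 2 / n
        g_k += 2.0 * resid * dp_dk / n
        g_c += 2.0 * resid * dp_dc / n
    return loss, g_k, g_c


def train_controller(target, steps=200, lr=0.1):
    """Plain gradient descent on the controller parameters, fully offline."""
    k, c = 0.0, 0.0
    for _ in range(steps):
        _, g_k, g_c = rollout_with_grads(k, c, target)
        k, c = k - lr * g_k, c - lr * g_c
    return k, c


# Hypothetical target: a high "inhale" plateau followed by a low "exhale" level.
target = [1.0] * 30 + [0.2] * 30
k, c = train_controller(target)
print(rollout_with_grads(0.0, 0.0, target)[0], rollout_with_grads(k, c, target)[0])
```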

To establish a baseline, we run an exhaustive grid of PID controllers for multiple lung settings and select the best performing PID controller as measured by average absolute deviation between the desired pressure waveform and the actual pressure waveform. We compare these to our controllers and provide evidence that our DNN controllers are better performing and more robust.
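The baseline procedure can be sketched as an exhaustive grid over (kP, kI, kD) for each lung setting, scored by the mean absolute deviation between desired and actual pressure. A second helper picks the single set of gains that does best on average across all settings. Several parts are hypothetical placeholders for the physical test lungs: the first-order lung with a per-setting gain, the target waveform, and the grid values.

```python
import itertools
import math


def pid_mae(kp, ki, kd, target, lung_gain):
    """Run a PID loop on a hypothetical first-order lung (p += g * (u - p))
    and return the mean absolute error between desired and actual pressure."""
    p, integral, prev_err, total = 0.0, 0.0, 0.0, 0.0
    for T in target:
        err = T - p
        integral += err
        u = kp * err + ki * integral + kd * (err - prev_err)
        prev_err = err
        p += lung_gain * (u - p)
        total += abs(T - p)
    mae = total / len(target)
    return mae if math.isfinite(mae) else math.inf  # diverged gains lose


def best_pid_per_setting(target, kps, kis, kds, lung_gains):
    """Exhaustive grid search: for each lung setting keep the lowest-MAE gains."""
    best = {}
    for g in lung_gains:
        scores = {gains: pid_mae(*gains, target, g)
                  for gains in itertools.product(kps, kis, kds)}
        best[g] = min(scores, key=scores.get)
    return best


def best_single_pid(target, kps, kis, kds, lung_gains):
    """One set of gains for every setting: lowest MAE summed over settings."""
    return min(itertools.product(kps, kis, kds),
               key=lambda gains: sum(pid_mae(*gains, target, g) for g in lung_gains))


target = [1.0] * 30 + [0.2] * 30  # hypothetical desired pressure waveform
grid = ([0.5, 1.0, 2.0], [0.0, 0.05, 0.1], [0.0, 0.5])
print(best_pid_per_setting(target, *grid, [0.1, 0.3]))
print(best_single_pid(target, *grid, [0.1, 0.3]))
```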

  1. Breathing waveform tracking performance:

    We compare the best PID controller for a given lung setting against our controller trained on the learned simulator for the same setting. Our learned controller shows a 22% lower mean absolute error (MAE) between target and actual pressure waveforms.

    Comparison of the MAE between target and actual pressure waveforms (lower is better) for the best PID controller (orange) for a given lung setting (shown for two settings, R=5 and R=20) against our controller (blue) trained on the learned simulator for the same setting. The learned controller performs up to 22% better.
  2. Robustness:

    Further, we compare the performance of the single best PID controller across the entire set of lung settings with our controller trained on a set of learned simulators over the same settings. Our controller performs up to 32% better in MAE between target and actual pressure waveforms, suggesting that it could require less manual intervention between patients or even as a patient's condition changes.

    As above, but comparing the single best PID controller across the entire set of lung settings against our controller trained over the same settings. The learned controller performs up to 32% better, suggesting that it may require less manual intervention.

Finally, we investigated the feasibility of using model-free and other popular RL algorithms (PPO, DQN), in comparison to a direct policy gradient trained on the simulator. We find that the simulator-trained direct policy gradient achieves slightly better scores and does so with a more stable training process that uses orders of magnitude fewer training samples and a significantly smaller hyperparameter search space.

In the simulator, we find that model-free and other popular algorithms (PPO, DQN) perform approximately as well as our method. However, these other methods take an order of magnitude more episodes to train to similar levels.

Conclusions and the Road Forward
We have described a deep-learning approach to mechanical ventilation based on simulated dynamics learned from a physical test lung. However, this is only the beginning. To make an impact on real-world ventilators, there are numerous other considerations and issues to take into account. Most important among them is non-invasive ventilation, which is significantly more challenging because it is difficult to distinguish lung pressure from mask pressure. Other directions include handling spontaneous breathing and coughing. To learn more and become involved in this important intersection of machine learning and health, see an ICML tutorial on control theory and learning, and consider participating in one of our Kaggle competitions for creating better ventilator simulators!

The primary work was based in the Google AI Princeton lab, in collaboration with the Cohen lab in the Department of Mechanical and Aerospace Engineering at Princeton University. The research paper was authored by contributors from Google and Princeton University, including: Daniel Suo, Naman Agarwal, Wenhan Xia, Xinyi Chen, Udaya Ghai, Alexander Yu, Paula Gradu, Karan Singh, Cyril Zhang, Edgar Minasyan, Julienne LaChance, Tom Zajdel, Manuel Schottdorf, Daniel Cohen, and Elad Hazan.

Source: Google AI Blog

Predicting Text Selections with Federated Learning

Smart Text Selection, launched in 2017 as part of Android O, is one of Android’s most frequently used features. It helps users select, copy, and use text easily and quickly by predicting the desired word or set of words around a user’s tap and automatically expanding the selection appropriately. For selections with defined classification types, e.g., addresses and phone numbers, users are also offered an app with which to open the selection, saving them even more time.

Today we describe how we have improved the performance of Smart Text Selection by using federated learning to train the neural network model on user interactions responsibly while preserving user privacy. This work, which is part of Android’s new Private Compute Core secure environment, enabled us to improve the model’s selection accuracy by up to 20% on some types of entities.

Server-Side Proxy Data for Entity Selections
Smart Text Selection, which is the same technology behind Smart Linkify, does not predict arbitrary selections, but focuses on well-defined entities, such as addresses or phone numbers, and tries to predict the selection bounds for those categories. In the absence of multi-word entities, the model is trained to only select a single word in order to minimize the frequency of making multi-word selections in error.

The Smart Text Selection feature was originally trained using proxy data sourced from web pages to which schema.org annotations had been applied. These entities were then embedded in a selection of random text, and the model was trained to select just the entity, without spilling over into the random text surrounding it.

While this approach of training on schema.org annotations worked, it had several limitations. The data was quite different from the text we expect users to see on-device. For example, websites with schema.org annotations typically have entities with more proper formatting than what users might type on their phones. In addition, the text samples in which the entities were embedded for training were random and did not reflect realistic context on-device.

On-Device Feedback Signal for Federated Learning
With this new launch, the model no longer uses proxy data for span prediction, but is instead trained on-device on real interactions using federated learning. This is a training approach for machine learning models in which a central server coordinates model training that is split among many devices, while the raw data used stays on the local device. A standard federated learning training process works as follows: The server starts by initializing the model. Then, an iterative process begins in which (a) devices get sampled, (b) selected devices improve the model using their local data, and (c) then send back only the improved model, not the data used for training. The server then averages the updates it received to create the model that is sent out in the next iteration.
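The loop described above corresponds to the widely used federated averaging (FedAvg) recipe. Below is a minimal NumPy sketch of it. Several stand-ins are assumed: a linear least-squares model replaces the selection network, the per-device datasets are synthetic, and the average is weighted by each device's example count. No privacy protections are modeled here.

```python
import numpy as np


def client_update(weights, X, y, lr=0.1, epochs=5):
    """On device: a few epochs of gradient descent on local data only.
    Returns the improved weights; the data (X, y) never leaves this function."""
    w = weights.copy()
    for _ in range(epochs):
        w -= lr * X.T @ (X @ w - y) / len(y)
    return w


def federated_averaging(clients, dim, rounds=50, clients_per_round=5, seed=0):
    """Server: initialize, then repeatedly (a) sample devices, (b) let them
    improve the model locally, (c) average the returned models."""
    rng = np.random.default_rng(seed)
    w = np.zeros(dim)
    for _ in range(rounds):
        picked = rng.choice(len(clients), size=clients_per_round, replace=False)
        updates = [client_update(w, *clients[i]) for i in picked]
        sizes = [len(clients[i][1]) for i in picked]
        w = np.average(updates, axis=0, weights=sizes)
    return w


# Hypothetical population: 20 devices whose data share one true model.
rng = np.random.default_rng(1)
w_true = np.array([1.0, -2.0, 0.5])
clients = []
for _ in range(20):
    X = rng.normal(size=(30, 3))
    clients.append((X, X @ w_true))
print(federated_averaging(clients, dim=3))
```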

For Smart Text Selection, each time a user taps to select text and corrects the model’s suggestion, Android gets precise feedback for what selection span the model should have predicted. In order to preserve user privacy, the selections are temporarily kept on the device, without being visible server-side, and are then used to improve the model by applying federated learning techniques. This technique has the advantage of training the model on the same kind of data that it sees during inference.

Federated Learning & Privacy
One of the advantages of the federated learning approach is that it enables user privacy, because raw data is not exposed to a server. Instead, the server only receives updated model weights. Still, to protect against various threats, we explored ways to protect the on-device data, securely aggregate gradients, and reduce the risk of model memorization.

The on-device code for training Federated Smart Text Selection models is part of Android’s Private Compute Core secure environment, which makes it particularly well situated to securely handle user data. This is because the training environment in Private Compute Core is isolated from the network, and data egress is only allowed when federated and other privacy-preserving techniques are applied. In addition to network isolation, data in Private Compute Core is protected by policies that restrict how it can be used, thus protecting against malicious code that may have found its way onto the device.

To aggregate model updates produced by the on-device training code, we use Secure Aggregation, a cryptographic protocol that allows servers to compute the mean update for federated learning model training without reading the updates provided by individual devices. In addition to being individually protected by Secure Aggregation, the updates are also protected by transport encryption, creating two layers of defense against attackers on the network.
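The core idea that lets a server compute a sum without seeing individual updates can be illustrated with pairwise additive masking. This is a toy sketch, not the Secure Aggregation protocol itself. The real protocol also derives masks via key agreement, secret-shares them to tolerate devices that drop out, and encodes floating-point updates as integers. Here, a seeded PRNG stands in for the shared secrets, and the updates are already small integers.

```python
import random


def mask_updates(updates, modulus=2**32, seed=0):
    """Toy pairwise masking over integer-encoded updates: for every pair of
    devices (i, j), a shared random mask is added by i and subtracted by j,
    so each masked vector looks random but the masks cancel in the sum."""
    rng = random.Random(seed)  # stands in for per-pair secrets from key agreement
    dim = len(updates[0])
    masked = [list(u) for u in updates]
    for i in range(len(updates)):
        for j in range(i + 1, len(updates)):
            mask = [rng.randrange(modulus) for _ in range(dim)]
            for k in range(dim):
                masked[i][k] = (masked[i][k] + mask[k]) % modulus
                masked[j][k] = (masked[j][k] - mask[k]) % modulus
    return masked


def aggregate(masked, modulus=2**32):
    """Server side: add the masked vectors; only the total is revealed."""
    return [sum(column) % modulus for column in zip(*masked)]


updates = [[3, 5, 7], [10, 0, 1], [2, 2, 2]]  # hypothetical quantized updates
masked = mask_updates(updates)
total = aggregate(masked)
print(total, [t / len(updates) for t in total])  # sum [15, 7, 10] and its mean
```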

Finally, we looked into model memorization. In principle, it is possible for characteristics of the training data to be encoded in the updates sent to the server, survive the aggregation process, and end up being memorized by the global model. This could make it possible for an attacker to attempt to reconstruct the training data from the model. We used methods from Secret Sharer, an analysis technique that quantifies to what degree a model unintentionally memorizes its training data, to empirically verify that the model was not memorizing sensitive information. Further, we employed data masking techniques to prevent certain kinds of sensitive data from ever being seen by the model.
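Secret Sharer's central quantity is exposure. The procedure is to:

1. insert a random "canary" into training;
2. rank it by model likelihood among all candidates from the same template;
3. compute log2 of the number of candidates minus log2 of the canary's rank.

High exposure means the model prefers the canary over almost all alternatives, which is a sign of memorization. Below is a minimal sketch. It assumes that the candidate space is small enough to enumerate and that scores are log-perplexities (lower means more likely); the scores in the usage lines are synthetic.

```python
import math


def exposure(canary_score, candidate_scores):
    """Secret Sharer exposure: log2 |R| - log2 rank(canary).

    candidate_scores: log-perplexities of every candidate in the randomness
    space R (including the canary); rank 1 means the canary is the most
    likely candidate, giving the maximum exposure log2 |R|."""
    rank = 1 + sum(s < canary_score for s in candidate_scores)
    return math.log2(len(candidate_scores)) - math.log2(rank)


scores = [float(i) for i in range(1024)]  # synthetic scores for 1024 candidates
print(exposure(0.0, scores))     # 10.0: canary ranked first, fully exposed
print(exposure(1023.0, scores))  # 0.0: canary ranked last, no exposure
```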

In combination, these techniques help ensure that Federated Smart Text Selection is trained in a way that preserves user privacy.

Achieving Superior Model Quality
Initial attempts to train the model using federated learning were unsuccessful. The loss did not converge, and predictions were essentially random. Debugging the training process was difficult because the training data was on-device rather than centrally collected, so it could not be examined or verified. In fact, in such a case it’s not even possible to determine whether the data looks as expected, which is often the first step in debugging machine learning pipelines.

To overcome this challenge, we carefully designed high-level metrics that gave us an understanding of how the model behaved during training. Such metrics included the number of training examples, selection accuracy, and recall and precision metrics for each entity type. These metrics are collected during federated training via federated analytics, a similar process as the collection of the model weights. Through these metrics and many analyses, we were able to better understand which aspects of the system worked well and where bugs could exist.
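Such metrics can be assembled without raw data. Each device reduces its examples to counts, the server sums the counts, and precision, recall, and accuracy are computed from the totals. The sketch below shows this flow. The example format (true entity type, predicted entity type, whether the selected span was correct) and the function names are hypothetical, and precision and recall here are defined at the entity-type level. Real federated analytics deployments may add further privacy protections.

```python
from collections import Counter


def device_report(examples):
    """On device: reduce (true_type, predicted_type, span_correct) triples to
    counts. Only these counts, never the underlying text, are reported."""
    c = Counter()
    for true_t, pred_t, span_ok in examples:
        c["examples"] += 1
        c["span_correct"] += int(span_ok)
        if pred_t == true_t:
            c[("tp", true_t)] += 1
        else:
            c[("fp", pred_t)] += 1
            c[("fn", true_t)] += 1
    return c


def server_metrics(reports):
    """Server: sum the per-device counts, then derive global metrics."""
    total = sum(reports, Counter())
    types = {key[1] for key in total if isinstance(key, tuple)}
    metrics = {"examples": total["examples"],
               "selection_accuracy": total["span_correct"] / total["examples"]}
    for t in sorted(types):
        tp, fp, fn = total[("tp", t)], total[("fp", t)], total[("fn", t)]
        metrics[t] = {"precision": tp / (tp + fp) if tp + fp else 0.0,
                      "recall": tp / (tp + fn) if tp + fn else 0.0}
    return metrics


d1 = device_report([("address", "address", True), ("phone", "address", False)])
d2 = device_report([("phone", "phone", True), ("address", "address", False)])
print(server_metrics([d1, d2]))
```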

After fixing these bugs and making additional improvements, such as implementing on-device data filters, using better federated optimization methods and applying more robust gradient aggregators, the model trained successfully.
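The post does not say which robust aggregators were used. One widely used option, shown here only as an example, is to clip each client update to a maximum L2 norm before averaging, which bounds how far any single device can move the global model. This is a minimal sketch assuming updates are plain Python lists and a hypothetical `max_norm` of 1.0.

```python
import math

def clip_by_norm(update, max_norm):
    """Scale `update` down so its L2 norm is at most `max_norm`."""
    norm = math.sqrt(sum(x * x for x in update))
    if norm <= max_norm:
        return list(update)
    return [x * (max_norm / norm) for x in update]

def clipped_mean(updates, max_norm):
    """Average client updates after clipping each one, so no single device
    can move the aggregate by more than max_norm / len(updates)."""
    clipped = [clip_by_norm(u, max_norm) for u in updates]
    return [sum(column) / len(clipped) for column in zip(*clipped)]

# Hypothetical updates; the last device sends an outlier with L2 norm 50.
updates = [[0.1, 0.2], [0.3, -0.1], [30.0, 40.0]]
robust = clipped_mean(updates, max_norm=1.0)
naive = [sum(column) / len(updates) for column in zip(*updates)]
```

Here the clipped mean stays close to the two well-behaved updates, while the naive mean is pulled above 10 in its first coordinate by the single outlier. Coordinate-wise median and trimmed mean are other common robust choices.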

Using this new federated approach, we were able to significantly improve Smart Text Selection models, with the degree of improvement depending on the language. Typical improvements ranged from 5% to 7% for multi-word selection accuracy, with no drop in single-word performance. The accuracy of correctly selecting addresses (the most complex type of entity supported) increased by 8% to 20%, again depending on the language. These improvements lead to millions of additional selections being automatically expanded for users every day.

An additional advantage of this federated learning approach for Smart Text Selection is its ability to scale to additional languages. Server-side training required manually tweaking the proxy data for each language to make it more similar to on-device data. This only works to some degree, and it takes a tremendous amount of effort for each additional language.

The federated learning pipeline, however, trains on user interactions, without the need for such manual adjustments. Once the model achieved good results for English, we applied the same pipeline to Japanese and saw even greater improvements, without needing to tune the system specifically for Japanese selections.

We hope that this new federated approach lets us scale Smart Text Selection to many more languages. Ideally this will also work without manual tuning of the system, making it possible to support even low-resource languages.

We developed a federated way of learning to predict text selections based on user interactions, resulting in much improved Smart Text Selection models deployed to Android users. This approach relies on federated learning, which works without collecting user data on the server. Additionally, we used many state-of-the-art privacy approaches, such as Android's new Private Compute Core, Secure Aggregation and the Secret Sharer method. The results show that privacy does not have to be a limiting factor when training models. Instead, we managed to obtain a significantly better model, while ensuring that users' data stays private.

Many people contributed to this work. We would like to thank Lukas Zilka, Asela Gunawardana, Silvano Bonacina, Seth Welna, Tony Mak, Chang Li, Abodunrinwa Toki, Sergey Volnov, Matt Sharifi, Abhanshu Sharma, Eugenio Marchiori, Jacek Jurewicz, Nicholas Carlini, Jordan McClead, Sophia Kovaleva, Evelyn Kao, Tom Hume, Alex Ingerman, Brendan McMahan, Fei Zheng, Zachary Charles, Sean Augenstein, Zachary Garrett, Stefan Dierauf, David Petrou, Vishwath Mohan, Hunter King, Emily Glanz, Hubert Eichner, Krzysztof Ostrowski, Jakub Konecny, Shanshan Wu, Janel Thamkul, Elizabeth Kemp, and everyone else involved in the project.

Source: Google AI Blog