
DP-Auditorium: A flexible library for auditing differential privacy

Differential privacy (DP) is a property of randomized mechanisms that limits the influence of any individual user’s information while processing and analyzing data. DP offers a robust solution to address growing concerns about data protection, enabling technologies across industries and government applications (e.g., the US census) without compromising individual user identities. As its adoption increases, it’s important to identify the potential risks of developing mechanisms with faulty implementations. Researchers have recently found errors in both the mathematical proofs of private mechanisms and their implementations. For example, researchers compared six sparse vector technique (SVT) variations and found that only two of the six actually met the asserted privacy guarantee. Even when mathematical proofs are correct, the code implementing the mechanism is vulnerable to human error.

However, practical and efficient DP auditing is challenging, primarily due to the inherent randomness of the mechanisms and the probabilistic nature of the tested guarantees. In addition, a range of guarantee types exist (e.g., pure DP, approximate DP, Rényi DP, and concentrated DP), and this diversity contributes to the complexity of formulating the auditing problem. Further, debugging mathematical proofs and code bases is an intractable task given the volume of proposed mechanisms. While ad hoc testing techniques exist under specific assumptions of mechanisms, few efforts have been made to develop an extensible tool for testing DP mechanisms.

To that end, in “DP-Auditorium: A Large Scale Library for Auditing Differential Privacy”, we introduce an open source library for auditing DP guarantees with only black-box access to a mechanism (i.e., without any knowledge of the mechanism’s internal properties). DP-Auditorium is implemented in Python and provides a flexible interface that allows contributions to continuously improve its testing capabilities. We also introduce new testing algorithms that perform divergence optimization over function spaces for Rényi DP, pure DP, and approximate DP. We demonstrate that DP-Auditorium can efficiently identify DP guarantee violations, and suggest which tests are most suitable for detecting particular bugs under various privacy guarantees.


DP guarantees

The output of a DP mechanism M on a dataset D is a sample drawn from a probability distribution M(D) that satisfies a mathematical property ensuring the privacy of user data. A DP guarantee is thus tightly related to properties between pairs of probability distributions. A mechanism is differentially private if the probability distributions determined by M on dataset D and a neighboring dataset D’, which differ by only one record, are indistinguishable under a given divergence metric.

For example, the classical approximate DP definition states that a mechanism is approximately DP with parameters (ε, δ) if the hockey-stick divergence of order e^ε between M(D) and M(D’) is at most δ. Pure DP is a special instance of approximate DP where δ = 0. Finally, a mechanism is considered Rényi DP with parameters (𝛼, ε) if the Rényi divergence of order 𝛼 between M(D) and M(D’) is at most ε (where ε is a small positive value). In these three definitions, ε is not interchangeable but intuitively conveys the same concept; larger values of ε imply larger divergences between the two distributions or less privacy, since the two distributions are easier to distinguish.
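To make these definitions concrete, the following minimal sketch computes both divergences for distributions over a small finite outcome space. It is an illustration only and is not taken from DP-Auditorium; the function names are hypothetical, and the Rényi helper assumes an order greater than 1 and a second distribution with full support.

```python
import math


def hockey_stick_divergence(p, q, epsilon):
    """Hockey-stick divergence of order e^epsilon between discrete p and q.

    p and q are lists of probabilities over the same finite outcome space.
    """
    scale = math.exp(epsilon)
    return sum(max(pi - scale * qi, 0.0) for pi, qi in zip(p, q))


def renyi_divergence(p, q, alpha):
    """Renyi divergence of order alpha > 1 between discrete p and q.

    Assumes every entry of q is strictly positive.
    """
    total = sum(pi ** alpha * qi ** (1.0 - alpha) for pi, qi in zip(p, q) if pi > 0)
    return math.log(total) / (alpha - 1.0)


# Output distributions of a binary randomized response on two neighboring inputs.
p = [0.75, 0.25]
q = [0.25, 0.75]
delta_at_ln3 = hockey_stick_divergence(p, q, math.log(3))  # essentially zero
delta_at_half = hockey_stick_divergence(p, q, 0.5)         # strictly positive
```

For this pair, the hockey-stick divergence vanishes at ε = ln 3 but is positive at ε = 0.5, so on these two inputs a pure DP claim with ε = ln 3 is consistent while a claim with ε = 0.5 is not.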


DP-Auditorium

DP-Auditorium comprises two main components: property testers and dataset finders. Property testers take samples from a mechanism evaluated on specific datasets as input and aim to identify privacy guarantee violations in the provided datasets. Dataset finders suggest datasets where the privacy guarantee may fail. By combining both components, DP-Auditorium enables (1) automated testing of diverse mechanisms and privacy definitions, and (2) detection of bugs in privacy-preserving mechanisms. We implement various private and non-private mechanisms, including simple mechanisms that compute the mean of records and more complex mechanisms, such as different SVT and gradient descent mechanism variants.

Property testers determine if evidence exists to reject the hypothesis that a given divergence between two probability distributions, P and Q, is bounded by a prespecified budget determined by the DP guarantee being tested. They compute a lower bound on the divergence from samples from P and Q, rejecting the property if the lower bound exceeds the budget. If the lower bound does not exceed the budget, the test provides no guarantee that the divergence is indeed bounded. To test for a range of privacy guarantees, DP-Auditorium introduces three novel testers: (1) HockeyStickPropertyTester, (2) RényiPropertyTester, and (3) MMDPropertyTester. Unlike other approaches, these testers don’t depend on explicit histogram approximations of the tested distributions. They rely on variational representations of the hockey-stick divergence, Rényi divergence, and maximum mean discrepancy (MMD) that enable the estimation of divergences through optimization over function spaces. As a baseline, we implement HistogramPropertyTester, a commonly used approximate DP tester. While our three testers follow a similar approach, for brevity, we focus on the HockeyStickPropertyTester in this post.

Given two neighboring datasets, D and D’, the HockeyStickPropertyTester finds a lower bound δ̂ for the hockey-stick divergence between M(D) and M(D’) that holds with high probability. Hockey-stick divergence enforces that the two distributions M(D) and M(D’) are close under an approximate DP guarantee. Therefore, if a privacy guarantee claims that the hockey-stick divergence is at most δ, and δ̂ > δ, then with high probability the divergence is higher than what was promised on D and D’ and the mechanism cannot satisfy the given approximate DP guarantee. The lower bound δ̂ is computed as an empirical and tractable counterpart of a variational formulation of the hockey-stick divergence (see the paper for more details). The accuracy of δ̂ increases with the number of samples drawn from the mechanism, but decreases as the variational formulation is simplified. We balance these factors in order to ensure that δ̂ is both accurate and easy to compute.
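The idea of a sample-based lower bound can be sketched in a heavily simplified form. The hockey-stick divergence of order e^ε equals the supremum, over functions g with values in [0, 1], of E_P[g] − e^ε E_Q[g], so any fixed g yields a lower bound. The sketch below is not the library's tester: it assumes scalar outputs, uses threshold functions g(x) = 1[x > t] as a stand-in function class, and applies a Hoeffding correction on held-out samples. All names are hypothetical.

```python
import math


def hockey_stick_lower_bound(samples_p, samples_q, epsilon, failure_prob=0.05):
    """Simplified high-probability lower bound on the hockey-stick divergence.

    samples_p and samples_q are lists of scalar mechanism outputs on D and D'.
    """
    half_p, half_q = len(samples_p) // 2, len(samples_q) // 2
    train_p, test_p = samples_p[:half_p], samples_p[half_p:]
    train_q, test_q = samples_q[:half_q], samples_q[half_q:]
    scale = math.exp(epsilon)

    def objective(threshold, xs_p, xs_q):
        # Empirical value of E_P[g] - e^epsilon * E_Q[g] for g(x) = 1[x > threshold].
        mean_p = sum(x > threshold for x in xs_p) / len(xs_p)
        mean_q = sum(x > threshold for x in xs_q) / len(xs_q)
        return mean_p - scale * mean_q

    # Choose the threshold on the training half only, so that the held-out
    # half is independent of the selected function.
    step = max(1, (half_p + half_q) // 50)
    candidates = sorted(train_p + train_q)[::step]
    best = max(candidates, key=lambda t: objective(t, train_p, train_q))

    # Hoeffding slack covering both empirical means on the held-out half.
    n = min(len(test_p), len(test_q))
    slack = (1.0 + scale) * math.sqrt(math.log(2.0 / failure_prob) / (2.0 * n))
    return objective(best, test_p, test_q) - slack
```

If the returned value exceeds the claimed δ, the claimed guarantee is rejected for that pair of datasets; a value below δ says nothing either way. A richer function class tightens the bound at the cost of a harder optimization, which is the trade-off described above.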

Dataset finders use black-box optimization to find datasets D and D’ that maximize δ̂, a lower bound on the divergence value δ. Note that black-box optimization techniques are specifically designed for settings where deriving gradients for an objective function may be impractical or even impossible. These optimization techniques oscillate between exploration and exploitation phases to estimate the shape of the objective function and predict areas where the objective can have optimal values. In contrast, a full exploration algorithm, such as the grid search method, searches over the full space of neighboring datasets D and D’. DP-Auditorium implements different dataset finders through the open sourced black-box optimization library Vizier.

Running existing components on a new mechanism only requires defining the mechanism as a Python function that takes an array of data D and a desired number of samples n to be output by the mechanism computed on D. In addition, we provide flexible wrappers for testers and dataset finders that allow practitioners to implement their own testing and dataset search algorithms.
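As an illustration of the shape such a function might take, here is a minimal sketch of a Laplace mean mechanism that accepts a dataset and a sample count. It only mirrors the description above; the exact types and wrappers expected by the library are not shown in this post, so the signature, the extra keyword arguments, and the assumption that records lie in [0, 1] with neighboring datasets differing by replacing one record are all hypothetical.

```python
import random


def laplace_noise(scale, rng):
    # The difference of two i.i.d. exponentials with mean `scale` is Laplace(0, scale).
    return rng.expovariate(1.0 / scale) - rng.expovariate(1.0 / scale)


def dp_laplace_mean(data, num_samples, epsilon=1.0, rng=None):
    """Returns num_samples independent noisy means of `data`.

    Records are clipped to [0, 1], so replacing one record changes the mean
    by at most 1 / len(data).
    """
    rng = rng or random.Random()
    clipped = [min(max(x, 0.0), 1.0) for x in data]
    sensitivity = 1.0 / len(clipped)
    true_mean = sum(clipped) / len(clipped)
    return [
        true_mean + laplace_noise(sensitivity / epsilon, rng)
        for _ in range(num_samples)
    ]


samples = dp_laplace_mean([0.2, 0.4, 0.6, 0.8], num_samples=2000, rng=random.Random(0))
```

Samples produced this way on two neighboring datasets are exactly what a property tester consumes.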


Key results

We assess the effectiveness of DP-Auditorium on five private and nine non-private mechanisms with diverse output spaces. For each property tester, we repeat the test ten times on fixed datasets using different values of ε, and report the number of times each tester identifies privacy bugs. While no tester consistently outperforms the others, we identify bugs that would be missed by previous techniques (HistogramPropertyTester). Note that the HistogramPropertyTester is not applicable to SVT mechanisms.

Number of times each property tester finds the privacy violation for the tested non-private mechanisms. NonDPLaplaceMean and NonDPGaussianMean mechanisms are faulty implementations of the Laplace and Gaussian mechanisms for computing the mean.

We also analyze the implementation of a DP gradient descent algorithm (DP-GD) in TensorFlow that computes gradients of the loss function on private data. To preserve privacy, DP-GD employs a clipping mechanism to bound the ℓ2-norm of the gradients by a value G, followed by the addition of Gaussian noise. This implementation incorrectly assumes that the noise added has a scale of G, while in reality, the scale is sG, where s is a positive scalar. This discrepancy leads to an approximate DP guarantee that holds only for values of s greater than or equal to 1.
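The clipping and noise steps can be sketched as follows. This is a plain-Python illustration and not the TensorFlow implementation under test; it assumes per-example clipping followed by a sum, and the function names and the way s enters the noise scale are hypothetical stand-ins for the discrepancy described above.

```python
import math
import random


def clip_by_l2_norm(gradient, clip_norm):
    """Rescales `gradient` so that its l2-norm is at most `clip_norm`."""
    norm = math.sqrt(sum(g * g for g in gradient))
    factor = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [g * factor for g in gradient]


def noisy_clipped_gradient_sum(per_example_grads, clip_norm, s=1.0, rng=None):
    """Sums clipped per-example gradients and adds Gaussian noise of scale s * clip_norm.

    The privacy analysis assumes s = 1; s < 1 adds less noise than assumed.
    """
    rng = rng or random.Random()
    total = [0.0] * len(per_example_grads[0])
    for grad in per_example_grads:
        for i, g in enumerate(clip_by_l2_norm(grad, clip_norm)):
            total[i] += g
    stddev = s * clip_norm
    return [t + rng.gauss(0.0, stddev) for t in total]


grads = [[3.0, 4.0], [0.3, 0.4]]
noise_free = noisy_clipped_gradient_sum(grads, clip_norm=1.0, s=0.0)
```

With s below 1, the output distributions on neighboring datasets are easier to tell apart than the claimed guarantee allows, which is what the property testers pick up on.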

We evaluate the effectiveness of property testers in detecting this bug and show that HockeyStickPropertyTester and RényiPropertyTester exhibit superior performance in identifying privacy violations, outperforming MMDPropertyTester and HistogramPropertyTester. Notably, these testers detect the bug even for values of s as high as 0.6. It is worth highlighting that s = 0.5 corresponds to a common error in the literature that involves missing a factor of two when accounting for the privacy budget ε. DP-Auditorium successfully captures this bug as shown below. For more details, see Section 5.6 of the paper.

Estimated divergences and test thresholds for different values of s when testing DP-GD with the HistogramPropertyTester (left) and the HockeyStickPropertyTester (right).

Estimated divergences and test thresholds for different values of s when testing DP-GD with the RényiPropertyTester (left) and the MMDPropertyTester (right).

To test dataset finders, we compute the number of datasets explored before finding a privacy violation. On average, the majority of bugs are discovered in fewer than 10 calls to dataset finders. Randomized and exploration/exploitation methods are more efficient at finding datasets than grid search. For more details, see the paper.


Conclusion

DP is one of the most powerful frameworks for data protection. However, proper implementation of DP mechanisms can be challenging and prone to errors that cannot be easily detected using traditional unit testing methods. A unified testing framework can help auditors, regulators, and academics ensure that private mechanisms are indeed private.

DP-Auditorium is a new approach to testing DP via divergence optimization over function spaces. Our results show that this type of function-based estimation consistently outperforms previous black-box access testers. Finally, we demonstrate that these function-based estimators allow for a better discovery rate of privacy bugs compared to histogram estimation. By open sourcing DP-Auditorium, we aim to establish a standard for end-to-end testing of new differentially private algorithms.


Acknowledgements

The work described here was done jointly with Andrés Muñoz Medina, William Kong and Umar Syed. We thank Chris Dibak and Vadym Doroshenko for helpful engineering support and interface suggestions for our library.

Source: Google AI Blog


Graph neural networks in TensorFlow

Objects and their relationships are ubiquitous in the world around us, and relationships can be as important to understanding an object as its own attributes viewed in isolation — take for example transportation networks, production networks, knowledge graphs, or social networks. Discrete mathematics and computer science have a long history of formalizing such networks as graphs, consisting of nodes connected by edges in various irregular ways. Yet most machine learning (ML) algorithms allow only for regular and uniform relations between input objects, such as a grid of pixels, a sequence of words, or no relation at all.

Graph neural networks, or GNNs for short, have emerged as a powerful technique to leverage both the graph’s connectivity (as in the older algorithms DeepWalk and Node2Vec) and the input features on the various nodes and edges. GNNs can make predictions for graphs as a whole (Does this molecule react in a certain way?), for individual nodes (What’s the topic of this document, given its citations?) or for potential edges (Is this product likely to be purchased together with that product?). Apart from making predictions about graphs, GNNs are a powerful tool used to bridge the chasm to more typical neural network use cases. They encode a graph's discrete, relational information in a continuous way so that it can be included naturally in another deep learning system.

We are excited to announce the release of TensorFlow GNN 1.0 (TF-GNN), a production-tested library for building GNNs at large scales. It supports both modeling and training in TensorFlow as well as the extraction of input graphs from huge data stores. TF-GNN is built from the ground up for heterogeneous graphs, where types of objects and relations are represented by distinct sets of nodes and edges. Real-world objects and their relations occur in distinct types, and TF-GNN's heterogeneous focus makes it natural to represent them.

Inside TensorFlow, such graphs are represented by objects of type tfgnn.GraphTensor. This is a composite tensor type (a collection of tensors in one Python class) accepted as a first-class citizen in tf.data.Dataset, tf.function, etc. It stores both the graph structure and its features attached to nodes, edges and the graph as a whole. Trainable transformations of GraphTensors can be defined as Layers objects in the high-level Keras API, or directly using the tfgnn.GraphTensor primitive.


GNNs: Making predictions for an object in context

For illustration, let’s look at one typical application of TF-GNN: predicting a property of a certain type of node in a graph defined by cross-referencing tables of a huge database. For example, consider a citation database of Computer Science (CS) arXiv papers with one-to-many cites and many-to-one cited relationships, where we would like to predict the subject area of each paper.

Like most neural networks, a GNN is trained on a dataset of many labeled examples (~millions), but each training step consists only of a much smaller batch of training examples (say, hundreds). To scale to millions, the GNN gets trained on a stream of reasonably small subgraphs from the underlying graph. Each subgraph contains enough of the original data to compute the GNN result for the labeled node at its center and train the model. This process — typically referred to as subgraph sampling — is extremely consequential for GNN training. Most existing tooling accomplishes sampling in a batch way, producing static subgraphs for training. TF-GNN provides tooling to improve on this by sampling dynamically and interactively.
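A toy version of rooted subgraph sampling can be sketched in a few lines. This is not TF-GNN's sampler API; the adjacency-dictionary input, the per-hop fan-out limits, and the function name are assumptions made for illustration.

```python
import random


def sample_subgraph(adjacency, root, fanouts, rng=None):
    """Samples a subgraph around `root`.

    adjacency maps each node to the list of nodes it can receive messages from.
    At hop k, at most fanouts[k] neighbors are kept per frontier node.
    Returns the sampled node set and a list of (source, target) edges.
    """
    rng = rng or random.Random()
    nodes = {root}
    edges = []
    frontier = [root]
    for fanout in fanouts:
        next_frontier = []
        for node in frontier:
            neighbors = adjacency.get(node, [])
            for nbr in rng.sample(neighbors, min(fanout, len(neighbors))):
                edges.append((nbr, node))  # messages flow from neighbor toward the root
                if nbr not in nodes:
                    nodes.add(nbr)
                    next_frontier.append(nbr)
        frontier = next_frontier
    return nodes, edges


citations = {"a": ["b", "c", "d"], "b": ["e"], "c": ["e", "f"], "d": [], "e": [], "f": []}
nodes, edges = sample_subgraph(citations, "a", fanouts=[2, 1], rng=random.Random(0))
```

Each labeled paper would serve as the root of one such subgraph, and a stream of these small subgraphs forms the training batches.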

Pictured, the process of subgraph sampling where small, tractable subgraphs are sampled from a larger graph to create input examples for GNN training.

TF-GNN 1.0 debuts a flexible Python API to configure dynamic or batch subgraph sampling at all relevant scales: interactively in a Colab notebook, for efficient sampling of a small dataset stored in the main memory of a single training host, or distributed by Apache Beam for huge datasets stored on a network filesystem (up to hundreds of millions of nodes and billions of edges). For details, please refer to our user guides for in-memory and Beam-based sampling, respectively.

On those same sampled subgraphs, the GNN’s task is to compute a hidden (or latent) state at the root node; the hidden state aggregates and encodes the relevant information of the root node's neighborhood. One classical approach is message-passing neural networks. In each round of message passing, nodes receive messages from their neighbors along incoming edges and update their own hidden state from them. After n rounds, the hidden state of the root node reflects the aggregate information from all nodes within n edges (pictured below for n = 2). The messages and the new hidden states are computed by hidden layers of the neural network. In a heterogeneous graph, it often makes sense to use separately trained hidden layers for the different types of nodes and edges.
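The claim that n rounds reach nodes within n edges can be checked with a toy example. The sketch below replaces the trained hidden layers with a plain sum over scalar states, which is an assumption made purely to keep the arithmetic visible; the names are hypothetical.

```python
def message_passing(states, incoming, num_rounds):
    """Runs sum-pooling message passing on scalar node states.

    incoming[v] lists the nodes that send messages to v. In each round every
    node adds the previous-round states of its senders to its own state.
    """
    for _ in range(num_rounds):
        states = {
            node: state + sum(states[src] for src in incoming.get(node, []))
            for node, state in states.items()
        }
    return states


# Chain c -> b -> a with root a; the decimal digits show whose state arrived.
initial = {"a": 1, "b": 10, "c": 100}
incoming = {"a": ["b"], "b": ["c"]}
after_one = message_passing(initial, incoming, 1)["a"]  # 11: only b has reached a
after_two = message_passing(initial, incoming, 2)["a"]  # 121: c's state arrives via b
```

After one round the root has only heard from its direct neighbor; the contribution of the node two edges away appears in the second round.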

Pictured, a simple message-passing neural network where, at each step, the node state is propagated from outer to inner nodes where it is pooled to compute new node states. Once the root node is reached, a final prediction can be made.

The training setup is completed by placing an output layer on top of the GNN’s hidden state for the labeled nodes, computing the loss (to measure the prediction error), and updating model weights by backpropagation, as usual in any neural network training.

Beyond supervised training (i.e., minimizing a loss defined by labels), GNNs can also be trained in an unsupervised way (i.e., without labels). This lets us compute a continuous representation (or embedding) of the discrete graph structure of nodes and their features. These representations are then typically utilized in other ML systems. In this way, the discrete, relational information encoded by a graph can be included in more typical neural network use cases. TF-GNN supports a fine-grained specification of unsupervised objectives for heterogeneous graphs.


Building GNN architectures

The TF-GNN library supports building and training GNNs at various levels of abstraction.

At the highest level, users can take any of the predefined models bundled with the library that are expressed in Keras layers. Besides a small collection of models from the research literature, TF-GNN comes with a highly configurable model template that provides a curated selection of modeling choices that we have found to provide strong baselines on many of our in-house problems. The templates implement GNN layers; users need only to initialize the Keras layers.

At the lowest level, users can write a GNN model from scratch in terms of primitives for passing data around the graph, such as broadcasting data from a node to all its outgoing edges or pooling data into a node from all its incoming edges (e.g., computing the sum of incoming messages). TF-GNN’s graph data model treats nodes, edges and whole input graphs equally when it comes to features or hidden states, making it straightforward to express not only node-centric models like the MPNN discussed above but also more general forms of GraphNets. This can, but need not, be done with Keras as a modeling framework on the top of core TensorFlow. For more details, and intermediate levels of modeling, see the TF-GNN user guide and model collection.


Training orchestration

While advanced users are free to do custom model training, the TF-GNN Runner also provides a succinct way to orchestrate the training of Keras models in the common cases, typically with a single invocation.

The Runner provides ready-to-use solutions for ML pains like distributed training and tfgnn.GraphTensor padding for fixed shapes on Cloud TPUs. Beyond training on a single task, it supports joint training on multiple (two or more) tasks in concert. For example, unsupervised tasks can be mixed with supervised ones to inform a final continuous representation (or embedding) with application-specific inductive biases. Callers need only substitute the task argument with a mapping of tasks.

The TF-GNN Runner also includes an implementation of integrated gradients for use in model attribution. The integrated gradients output is a GraphTensor with the same connectivity as the observed GraphTensor but with its features replaced by gradient values, where larger values indicate a larger contribution to the GNN prediction. Users can inspect gradient values to see which features their GNN uses the most.
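For readers unfamiliar with the technique, integrated gradients can be illustrated on an ordinary function of a feature vector. The sketch below is a generic numerical version, not the Runner's GraphTensor implementation; it assumes a scalar-valued function, a straight-line path from a baseline, and finite-difference gradients, and all names are hypothetical.

```python
def integrated_gradients(f, x, baseline, steps=200, h=1e-5):
    """Approximates integrated gradients of scalar function f at x.

    Integrates the gradient along the straight line from `baseline` to `x`
    with a midpoint Riemann sum; gradients use central finite differences.
    """
    attributions = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i in range(len(x)):
            up = list(point)
            down = list(point)
            up[i] += h
            down[i] -= h
            grad_i = (f(up) - f(down)) / (2.0 * h)
            attributions[i] += grad_i * (x[i] - baseline[i]) / steps
    return attributions


def toy_model(v):
    return v[0] ** 2 + 3.0 * v[1]


attributions = integrated_gradients(toy_model, [2.0, 1.0], [0.0, 0.0])
```

The attributions sum to the difference between the model output at the input and at the baseline, which is the property that makes them usable for ranking feature contributions.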


Conclusion

In short, we hope TF-GNN will be useful to advance the application of GNNs in TensorFlow at scale and fuel further innovation in the field. If you’re curious to find out more, please try our Colab demo with the popular OGBN-MAG benchmark (in your browser, no installation required), browse the rest of our user guides and Colabs, or take a look at our paper.


Acknowledgements

The TF-GNN release 1.0 was developed by a collaboration between Google Research: Sami Abu-El-Haija, Neslihan Bulut, Bahar Fatemi, Johannes Gasteiger, Pedro Gonnet, Jonathan Halcrow, Liangze Jiang, Silvio Lattanzi, Brandon Mayer, Vahab Mirrokni, Bryan Perozzi, Anton Tsitsulin, Dustin Zelle, Google Core ML: Arno Eigenwillig, Oleksandr Ferludin, Parth Kothari, Mihir Paradkar, Jan Pfeifer, Rachael Tamakloe, and Google DeepMind: Alvaro Sanchez-Gonzalez and Lisa Wang.

Source: Google AI Blog


A decoder-only foundation model for time-series forecasting

Time-series forecasting is ubiquitous in various domains, such as retail, finance, manufacturing, healthcare and natural sciences. In retail use cases, for example, it has been observed that improving demand forecasting accuracy can meaningfully reduce inventory costs and increase revenue. Deep learning (DL) models have emerged as a popular approach for forecasting rich, multivariate, time-series data because they have proven to perform well in a variety of settings (e.g., DL models dominated the M5 competition leaderboard).

At the same time, there has been rapid progress in large foundation language models used for natural language processing (NLP) tasks, such as translation, retrieval-augmented generation, and code completion. These models are trained on massive amounts of textual data derived from a variety of sources like common crawl and open-source code that allows them to identify patterns in languages. This makes them very powerful zero-shot tools; for instance, when paired with retrieval, they can answer questions about and summarize current events.

Despite DL-based forecasters largely outperforming traditional methods and progress being made in reducing training and inference costs, they face challenges: most DL architectures require long and involved training and validation cycles before a customer can test the model on a new time-series. A foundation model for time-series forecasting, in contrast, can provide decent out-of-the-box forecasts on unseen time-series data with no additional training, enabling users to focus on refining forecasts for the actual downstream task like retail demand planning.

To that end, in “A decoder-only foundation model for time-series forecasting”, we introduce TimesFM, a single forecasting model pre-trained on a large time-series corpus of 100 billion real world time-points. Compared to the latest large language models (LLMs), TimesFM is much smaller (200M parameters), yet we show that even at such scales, its zero-shot performance on a variety of unseen datasets of different domains and temporal granularities comes close to the state-of-the-art supervised approaches trained explicitly on these datasets. Later this year we plan to make this model available for external customers in Google Cloud Vertex AI.


A decoder-only foundation model for time-series forecasting

LLMs are usually trained in a decoder-only fashion that involves three steps. First, text is broken down into subwords called tokens. Then, the tokens are fed into stacked causal transformer layers that produce an output corresponding to each input token (it cannot attend to future tokens). Finally, the output corresponding to the i-th token summarizes all the information from previous tokens and predicts the (i+1)-th token. During inference, the LLM generates the output one token at a time. For example, when prompted with “What is the capital of France?”, it might generate the token “The”, then condition on “What is the capital of France? The” to generate the next token “capital” and so on until it generates the complete answer: “The capital of France is Paris”.

A foundation model for time-series forecasting should adapt to variable context (what we observe) and horizon (what we query the model to forecast) lengths, while having enough capacity to encode all patterns from a large pretraining dataset. Similar to LLMs, we use stacked transformer layers (self-attention and feedforward layers) as the main building blocks for the TimesFM model. In the context of time-series forecasting, we treat a patch (a group of contiguous time-points) as a token, an approach popularized by a recent long-horizon forecasting work. The task then is to forecast the (i+1)-th patch of time-points given the i-th output at the end of the stacked transformer layers.

However, there are several key differences from language models. Firstly, we need a multilayer perceptron block with residual connections to convert a patch of time-series into a token that can be input to the transformer layers along with positional encodings (PE). For that, we use a residual block similar to our prior work in long-horizon forecasting. Secondly, at the other end, an output token from the stacked transformer can be used to predict a longer length of subsequent time-points than the input patch length, i.e., the output patch length can be larger than the input patch length.

Consider a time-series of length 512 time-points being used to train a TimesFM model with input patch length 32 and output patch length 128. During training, the model is simultaneously trained to use the first 32 time-points to forecast the next 128 time-points, the first 64 time-points to forecast time-points 65 to 192, the first 96 time-points to forecast time-points 97 to 224 and so on. During inference, suppose the model is given a new time-series of length 256 and tasked with forecasting the next 256 time-points into the future. The model will first generate the future predictions for time-points 257 to 384, then condition on the initial 256-length input plus the generated output to generate time-points 385 to 512. On the other hand, if in our model the output patch length was equal to the input patch length of 32, then for the same task we would have to go through eight generation steps instead of just the two above. This increases the chances of more errors accumulating and therefore, in practice, we see that a longer output patch length yields better performance for long-horizon forecasting.
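The arithmetic in this example can be reproduced with a short sketch. The helper names are hypothetical, indices are 1-based and inclusive to match the text, and the sketch assumes that only contexts with a full output patch available are used, which the post does not state.

```python
import math


def training_targets(series_len, input_patch_len, output_patch_len):
    """Lists (context_end, forecast_start, forecast_end) triples for training.

    Contexts grow by one input patch at a time; indices are 1-based and inclusive.
    """
    triples = []
    context_end = input_patch_len
    while context_end + output_patch_len <= series_len:
        triples.append((context_end, context_end + 1, context_end + output_patch_len))
        context_end += input_patch_len
    return triples


def decoding_steps(horizon_len, output_patch_len):
    """Number of autoregressive generation steps needed to cover the horizon."""
    return math.ceil(horizon_len / output_patch_len)


targets = training_targets(512, input_patch_len=32, output_patch_len=128)
steps_long_patch = decoding_steps(256, 128)   # 2
steps_short_patch = decoding_steps(256, 32)   # 8
```

The first three triples are (32, 33, 160), (64, 65, 192) and (96, 97, 224), matching the training example, and the step counts match the two-versus-eight comparison for inference.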

TimesFM architecture.


Pretraining data

Just like LLMs get better with more tokens, TimesFM requires a large volume of legitimate time series data to learn and improve. We have spent a great amount of time creating and assessing our training datasets, and the following is what we have found works best:

Synthetic data helps with the basics. Meaningful synthetic time-series data can be generated using statistical models or physical simulations. These basic temporal patterns can teach the model the grammar of time series forecasting.

Real-world data adds real-world flavor. We comb through available public time series datasets, and selectively put together a large corpus of 100 billion time-points. Among these datasets there are Google Trends and Wikipedia Pageviews, which track what people are interested in, and that nicely mirrors trends and patterns in many other real-world time series. This helps TimesFM understand the bigger picture and generalize better when provided with domain-specific contexts not seen during training.


Zero-shot evaluation results

We evaluate TimesFM zero-shot on data not seen during training using popular time-series benchmarks. We observe that TimesFM performs better than most statistical methods like ARIMA and ETS, and can match or outperform powerful DL models like DeepAR and PatchTST that have been explicitly trained on the target time-series.

We used the Monash Forecasting Archive to evaluate TimesFM’s out-of-the-box performance. This archive contains tens of thousands of time-series from various domains like traffic, weather, and demand forecasting covering frequencies ranging from a few minutes to yearly data. Following existing literature, we inspect the mean absolute error (MAE) appropriately scaled so that it can be averaged across the datasets. We see that zero-shot (ZS) TimesFM is better than most supervised approaches, including recent deep learning models. We also compare TimesFM to GPT-3.5 for forecasting using a specific prompting technique proposed by llmtime(ZS). We demonstrate that TimesFM performs better than llmtime(ZS) despite being orders of magnitude smaller.
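The post does not spell out the scaling. One common choice, assumed here purely for illustration, is to divide a model's MAE by the MAE of a naive forecast that repeats the last observed value, which makes errors comparable across datasets with very different magnitudes. The function names are hypothetical.

```python
def mae(forecast, actual):
    """Mean absolute error between two equally long sequences."""
    return sum(abs(f - a) for f, a in zip(forecast, actual)) / len(actual)


def scaled_mae(forecast, actual, last_observed):
    """MAE divided by the MAE of a naive last-value forecast (assumed scaling)."""
    naive = [last_observed] * len(actual)
    return mae(forecast, actual) / mae(naive, actual)


# Two datasets whose values differ by a factor of 100 give the same scaled error.
small = scaled_mae([11.0, 13.0], [12.0, 14.0], last_observed=10.0)
large = scaled_mae([1100.0, 1300.0], [1200.0, 1400.0], last_observed=1000.0)
```

Because both scaled errors equal one third here, averaging them across datasets does not let the larger-valued series dominate.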

Scaled MAE (the lower the better) of TimesFM(ZS) against other supervised and zero-shot approaches on Monash datasets.

Most of the Monash datasets are short or medium horizon, i.e., the prediction length is not too long. We also test TimesFM on popular benchmarks for long horizon forecasting against a recent state-of-the-art baseline PatchTST (and other long-horizon forecasting baselines). In the next figure, we plot the MAE on ETT datasets for the task of predicting 96 and 192 time-points into the future. The metric has been calculated on the last test window of each dataset (as done by the llmtime paper). We see that TimesFM not only surpasses the performance of llmtime(ZS) but also matches that of the supervised PatchTST model explicitly trained on the respective datasets.

Last window MAE (the lower the better) of TimesFM(ZS) against llmtime(ZS) and long-horizon forecasting baselines on ETT datasets.


Conclusion

We train a decoder-only foundation model for time-series forecasting using a large pretraining corpus of 100B real world time-points, the majority of which was search interest time-series data derived from Google Trends and pageviews from Wikipedia. We show that even a relatively small 200M parameter pretrained model that uses our TimesFM architecture displays impressive zero-shot performance on a variety of public benchmarks from different domains and granularities.


Acknowledgements

This work is the result of a collaboration between several individuals across Google Research and Google Cloud, including (in alphabetical order): Abhimanyu Das, Weihao Kong, Andrew Leach, Mike Lawrence, Alex Martin, Rajat Sen, Yang Yang and Yichen Zhou.

Source: Google AI Blog


Intervening on early readouts for mitigating spurious features and simplicity bias

Machine learning models in the real world are often trained on limited data that may contain unintended statistical biases. For example, in the CelebA celebrity image dataset, a disproportionate number of female celebrities have blond hair, leading to classifiers incorrectly predicting “blond” as the hair color for most female faces. Here, gender is a spurious feature for predicting hair color. Such unfair biases could have significant consequences in critical applications such as medical diagnosis.

Surprisingly, recent work has also discovered an inherent tendency of deep networks to amplify such statistical biases, through the so-called simplicity bias of deep learning. This bias is the tendency of deep networks to identify weakly predictive features early in the training, and continue to anchor on these features, failing to identify more complex and potentially more accurate features.

With the above in mind, we propose simple and effective fixes to this dual challenge of spurious features and simplicity bias by applying early readouts and feature forgetting. First, in “Using Early Readouts to Mediate Featural Bias in Distillation”, we show that making predictions from early layers of a deep network (referred to as “early readouts”) can automatically signal issues with the quality of the learned representations. In particular, these predictions are more often wrong, and more confidently wrong, when the network is relying on spurious features. We use this erroneous confidence to improve outcomes in model distillation, a setting where a larger “teacher” model guides the training of a smaller “student” model. Then in “Overcoming Simplicity Bias in Deep Networks using a Feature Sieve”, we intervene directly on these indicator signals by making the network “forget” the problematic features and consequently look for better, more predictive features. This substantially improves the model’s ability to generalize to unseen domains compared to previous approaches. Our AI Principles and our Responsible AI practices guide how we research and develop these advanced applications and help us address the challenges posed by statistical biases.

Animation comparing hypothetical responses from two models trained with and without the feature sieve.

Early readouts for debiasing distillation

We first illustrate the diagnostic value of early readouts and their application in debiased distillation, i.e., making sure that the student model inherits the teacher model’s resilience to feature bias through distillation. We start with a standard distillation framework where the student is trained with a mixture of label matching (minimizing the cross-entropy loss between student outputs and the ground-truth labels) and teacher matching (minimizing the KL divergence loss between student and teacher outputs for any given input).

Suppose one trains a linear decoder, i.e., a small auxiliary neural network that we call Aux, on top of an intermediate representation of the student model. We refer to the output of this linear decoder as an early readout of the network representation. Our finding is that early readouts make more errors on instances that contain spurious features, and further, that the confidence on those errors is higher than the confidence associated with other errors. This suggests that confidence on errors from early readouts is a fairly strong, automated indicator of the model’s dependence on potentially spurious features.

Illustrating the usage of early readouts (i.e., output from the auxiliary layer) in debiasing distillation. Instances that are confidently mispredicted in the early readouts are upweighted in the distillation loss.

We used this signal to modulate the contribution of the teacher in the distillation loss on a per-instance basis, and found significant improvements in the trained student model as a result.
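As a rough illustration of this per-instance modulation, the sketch below combines label matching and teacher matching, and upweights the teacher term when the early readout is confidently wrong. The weighting rule (1 + beta × confidence of the wrong prediction), the KL direction (teacher to student), the mixing coefficient alpha, and all function names are assumptions made for illustration; the paper's exact formulation may differ.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with q > 0 everywhere."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def instance_weight(readout_logits, label, beta=1.0):
    """Hypothetical weight: grows with the confidence of a wrong early readout."""
    probs = softmax(readout_logits)
    predicted = max(range(len(probs)), key=probs.__getitem__)
    if predicted == label:
        return 1.0  # correct readout: leave the teacher term unchanged
    return 1.0 + beta * probs[predicted]

def distillation_loss(student_logits, teacher_logits, readout_logits, label,
                      alpha=0.5, beta=1.0):
    """Label matching plus per-instance weighted teacher matching."""
    student = softmax(student_logits)
    teacher = softmax(teacher_logits)
    label_term = -math.log(student[label])          # cross-entropy with the label
    teacher_term = kl_divergence(teacher, student)  # assumed KL direction
    weight = instance_weight(readout_logits, label, beta)
    return (1 - alpha) * label_term + alpha * weight * teacher_term

# Same student and teacher outputs; only the early readout differs.
print(distillation_loss([1.0, 0.0], [0.0, 1.0], [2.0, 0.0], label=0))
print(distillation_loss([1.0, 0.0], [0.0, 1.0], [0.0, 3.0], label=0))
```

The second call yields a larger loss because the early readout confidently predicts the wrong class, so the teacher's guidance counts for more on that instance.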

We evaluated our approach on standard benchmark datasets known to contain spurious correlations (Waterbirds, CelebA, CivilComments, MNLI). Each of these datasets contains groupings of data that share an attribute potentially correlated with the label in a spurious manner. As an example, the CelebA dataset mentioned above includes groups such as {blond male, blond female, non-blond male, non-blond female}, with models typically performing the worst on the {non-blond female} group when predicting hair color. Thus, a measure of model performance is its worst group accuracy, i.e., the lowest accuracy among all known groups present in the dataset. We improved the worst group accuracy of student models on all datasets; moreover, we also improved overall accuracy in three of the four datasets, showing that our improvement on any one group does not come at the expense of accuracy on other groups. More details are available in our paper.
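Worst group accuracy, as defined above, can be computed in a few lines. The sketch below is a minimal version; the function name and the toy predictions are made up for illustration and are not results from the paper.

```python
from collections import defaultdict

def worst_group_accuracy(predictions, labels, groups):
    """Return (worst group, its accuracy, accuracy of every group)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for pred, label, group in zip(predictions, labels, groups):
        total[group] += 1
        correct[group] += int(pred == label)
    per_group = {g: correct[g] / total[g] for g in total}
    worst = min(per_group, key=per_group.get)
    return worst, per_group[worst], per_group

# Toy data (hypothetical): hair-color predictions for six images.
preds  = ["blond", "blond", "blond", "dark", "dark", "dark"]
labels = ["blond", "blond", "dark",  "dark", "dark", "dark"]
groups = ["blond female", "blond female", "non-blond female",
          "non-blond male", "non-blond male", "non-blond female"]

group, accuracy, _ = worst_group_accuracy(preds, labels, groups)
print(group, accuracy)  # non-blond female 0.5
```

In this toy example the overall accuracy is 5/6, which hides the fact that one group is only classified correctly half of the time.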

Comparison of Worst Group Accuracies of different distillation techniques relative to that of the Teacher model. Our method outperforms other methods on all datasets.

Overcoming simplicity bias with a feature sieve

In a second, closely related project, we intervene directly on the information provided by early readouts, to improve feature learning and generalization. The workflow alternates between identifying problematic features and erasing identified features from the network. Our primary hypothesis is that early features are more prone to simplicity bias, and that by erasing (“sieving”) these features, we allow richer feature representations to be learned.

Training workflow with feature sieve. We alternate between identifying problematic features (using training iteration) and erasing them from the network (using forgetting iteration).

We describe the identification and erasure steps in more detail:

  • Identifying simple features: We train the primary model and the readout model (Aux above) in conventional fashion via forward- and back-propagation. Note that feedback from the auxiliary layer does not back-propagate to the main network. This forces the auxiliary layer to learn from already-available features rather than create or reinforce them in the main network.
  • Applying the feature sieve: We aim to erase the identified features in the early layers of the neural network with the use of a novel forgetting loss, Lf, which is simply the cross-entropy between the readout and a uniform distribution over labels. Essentially, all information that leads to nontrivial readouts is erased from the primary network. In this step, the auxiliary network and upper layers of the main network are kept unchanged.
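The forgetting loss from the second step can be sketched as follows. This is a minimal illustration, assuming the uniform distribution is the target and the readout's softmax output is the prediction; the function names are hypothetical.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def forgetting_loss(readout_logits):
    """Cross-entropy between a uniform target and the readout distribution."""
    num_classes = len(readout_logits)
    return -sum(log_softmax(readout_logits)) / num_classes

print(forgetting_loss([0.0, 0.0, 0.0]))  # uninformative readout: log(3)
print(forgetting_loss([5.0, 0.0, 0.0]))  # confident readout: larger loss
```

The loss reaches its minimum, log(K) for K classes, exactly when the readout is uniform, so minimizing it with respect to the early layers pushes them to drop whatever information made the readout confident.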

We can control specifically how the feature sieve is applied to a given dataset through a small number of configuration parameters. By changing the position and complexity of the auxiliary network, we control the complexity of the identified and erased features. By modifying the mixing of learning and forgetting steps, we control the degree to which the model is challenged to learn more complex features. These choices, which are dataset-dependent, are made via hyperparameter search to maximize validation accuracy, a standard measure of generalization. Since we include “no-forgetting” (i.e., the baseline model) in the search space, we expect to find settings that are at least as good as the baseline.

Below we show features learned by the baseline model (middle row) and our model (bottom row) on two benchmark datasets: biased activity recognition (BAR) and animal categorization (NICO). Feature importance was estimated using post-hoc gradient-based importance scoring (GRAD-CAM), with the orange-red end of the spectrum indicating high importance, while green-blue indicates low importance. As shown below, our trained models focus on the primary object of interest, whereas the baseline model tends to focus on background features that are simpler and spuriously correlated with the label.

Feature importance scoring using GRAD-CAM on activity recognition (BAR) and animal categorization (NICO) generalization benchmarks. Our approach (last row) focuses on the relevant objects in the image, whereas the baseline (ERM; middle row) relies on background features that are spuriously correlated with the label.

Through this ability to learn better, generalizable features, we show substantial gains over a range of relevant baselines on real-world spurious feature benchmark datasets: BAR, CelebA Hair, NICO and ImageNet-A, by margins of up to 11% (see figure below). More details are available in our paper.

Our feature sieve method improves accuracy by significant margins relative to the nearest baseline for a range of feature generalization benchmark datasets.

Conclusion

We hope that our work on early readouts and their use in feature sieving for generalization will both spur the development of a new class of adversarial feature learning approaches and help improve the generalization capability and robustness of deep learning systems.


Acknowledgements

The work on applying early readouts to debiasing distillation was conducted in collaboration with our academic partners Durga Sivasubramanian, Anmol Reddy and Prof. Ganesh Ramakrishnan at IIT Bombay. We extend our sincere gratitude to Praneeth Netrapalli and Anshul Nasery for their feedback and recommendations. We are also grateful to Nishant Jain, Shreyas Havaldar, Rachit Bansal, Kartikeya Badola, Amandeep Kaur and the whole cohort of pre-doctoral researchers at Google Research India for taking part in research discussions. Special thanks to Tom Small for creating the animation used in this post.

Source: Google AI Blog


MobileDiffusion: Rapid text-to-image generation on-device

Text-to-image diffusion models have shown exceptional capabilities in generating high-quality images from text prompts. However, leading models feature billions of parameters and are consequently expensive to run, requiring powerful desktops or servers (e.g., Stable Diffusion, DALL·E, and Imagen). While recent advancements in inference solutions on Android via MediaPipe and iOS via Core ML have been made in the past year, rapid (sub-second) text-to-image generation on mobile devices has remained out of reach.

To that end, in “MobileDiffusion: Subsecond Text-to-Image Generation on Mobile Devices”, we introduce a novel approach with the potential for rapid text-to-image generation on-device. MobileDiffusion is an efficient latent diffusion model specifically designed for mobile devices. We also adopt DiffusionGAN to achieve one-step sampling during inference, which fine-tunes a pre-trained diffusion model while leveraging a GAN to model the denoising step. We have tested MobileDiffusion on iOS and Android premium devices, and it can run in half a second to generate a 512x512 high-quality image. Its comparably small model size of just 520M parameters makes it uniquely suited for mobile deployment.

      
Rapid text-to-image generation on-device.

Background

The relative inefficiency of text-to-image diffusion models arises from two primary challenges. First, the inherent design of diffusion models requires iterative denoising to generate images, necessitating multiple evaluations of the model. Second, the complexity of the network architecture in text-to-image diffusion models involves a substantial number of parameters, regularly reaching into the billions and resulting in computationally expensive evaluations. As a result, despite the potential benefits of deploying generative models on mobile devices, such as enhancing user experience and addressing emerging privacy concerns, it remains relatively unexplored within the current literature.

The optimization of inference efficiency in text-to-image diffusion models has been an active research area. Previous studies predominantly concentrate on addressing the first challenge, seeking to reduce the number of function evaluations (NFEs). Leveraging advanced numerical solvers (e.g., DPM) or distillation techniques (e.g., progressive distillation, consistency distillation), the number of necessary sampling steps has been significantly reduced from several hundred to single digits. Some recent techniques, like DiffusionGAN and Adversarial Diffusion Distillation, even reduce sampling to a single step.

However, on mobile devices, even a small number of evaluation steps can be slow due to the complexity of model architecture. Thus far, the architectural efficiency of text-to-image diffusion models has received comparatively less attention. A handful of earlier works briefly touch upon this matter, involving the removal of redundant neural network blocks (e.g., SnapFusion). However, these efforts lack a comprehensive analysis of each component within the model architecture, thereby falling short of providing a holistic guide for designing highly efficient architectures.


MobileDiffusion

Effectively overcoming the challenges imposed by the limited computational power of mobile devices requires an in-depth and holistic exploration of the model's architectural efficiency. In pursuit of this objective, our research undertakes a detailed examination of each constituent and computational operation within Stable Diffusion’s UNet architecture. We present a comprehensive guide for crafting highly efficient text-to-image diffusion models, culminating in MobileDiffusion.

The design of MobileDiffusion follows that of latent diffusion models. It contains three components: a text encoder, a diffusion UNet, and an image decoder. For the text encoder, we use CLIP-ViT/L14, which is a small model (125M parameters) suitable for mobile. We then turn our focus to the diffusion UNet and image decoder.


Diffusion UNet

As illustrated in the figure below, diffusion UNets commonly interleave transformer blocks and convolution blocks. We conduct a comprehensive investigation of these two fundamental building blocks. Throughout the study, we control the training pipeline (e.g., data, optimizer) to study the effects of different architectures.

In classic text-to-image diffusion models, a transformer block consists of a self-attention layer (SA) for modeling long-range dependencies among visual features, a cross-attention layer (CA) to capture interactions between text conditioning and visual features, and a feed-forward layer (FF) to post-process the output of attention layers. These transformer blocks hold a pivotal role in text-to-image diffusion models, serving as the primary components responsible for text comprehension. However, they also pose a significant efficiency challenge, given the computational expense of the attention operation, which is quadratic in the sequence length. We follow the idea of the UViT architecture, which places more transformer blocks at the bottleneck of the UNet. This design choice is motivated by the fact that the attention computation is less resource-intensive at the bottleneck due to its lower dimensionality.
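To make the quadratic cost concrete, the sketch below counts the entries of the self-attention score matrix when every spatial position of a feature map is treated as one token. The 64×64 size corresponds to the latent of a 512×512 image under 8× downsampling; the 32×32 and 16×16 levels are assumed here purely for illustration and are not taken from the MobileDiffusion architecture.

```python
def attention_score_entries(height, width):
    """Number of query-key pairs when each spatial position is one token."""
    tokens = height * width
    return tokens * tokens

for side in (64, 32, 16):
    entries = attention_score_entries(side, side)
    print(f"{side}x{side} feature map: {side * side} tokens, {entries:,} score entries")
```

Halving the side of the feature map cuts the number of tokens by 4× and the number of score entries by 16×, which is why attention at the bottleneck is far cheaper than at the highest resolution.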

Our UNet architecture incorporates more transformers in the middle, and skips self-attention (SA) layers at higher resolutions.

Convolution blocks, in particular ResNet blocks, are deployed at each level of the UNet. While these blocks are instrumental for feature extraction and information flow, the associated computational costs, especially at high-resolution levels, can be substantial. One proven approach in this context is separable convolution. We observed that replacing regular convolution layers with lightweight separable convolution layers in the deeper segments of the UNet yields similar performance.
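The saving from separable convolution can be estimated by counting weights. The sketch below assumes "separable" means depthwise separable (a per-channel k×k filter followed by a 1×1 pointwise convolution) and ignores biases; the channel count of 320 is a hypothetical example rather than a MobileDiffusion setting.

```python
def conv_params(c_in, c_out, kernel=3):
    """Weights in a regular k x k convolution (biases ignored)."""
    return kernel * kernel * c_in * c_out

def separable_conv_params(c_in, c_out, kernel=3):
    """Weights in a depthwise k x k convolution followed by a 1 x 1 pointwise one."""
    depthwise = kernel * kernel * c_in
    pointwise = c_in * c_out
    return depthwise + pointwise

regular = conv_params(320, 320)
separable = separable_conv_params(320, 320)
print(regular, separable, round(regular / separable, 2))  # 921600 105280 8.75
```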

In the figure below, we compare the UNets of several diffusion models. Our MobileDiffusion exhibits superior efficiency in terms of FLOPs (floating-point operations) and number of parameters.

Comparison of some diffusion UNets.

Image decoder

In addition to the UNet, we also optimized the image decoder. We trained a variational autoencoder (VAE) to encode an RGB image to an 8-channel latent variable whose spatial size is 8× smaller than that of the image. The decoder maps a latent variable back to an image that is 8× larger in spatial size. To further enhance efficiency, we design a lightweight decoder architecture by pruning the original’s width and depth. The resulting lightweight decoder leads to a significant performance boost, with nearly 50% latency improvement and better quality. For more details, please refer to our paper.
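As a quick worked example of these shapes, the sketch below computes the latent size for a 512×512 RGB image. The helper name and the height-width-channel ordering are assumptions for illustration.

```python
def latent_shape(height, width, downsample=8, latent_channels=8):
    """Shape (H, W, C) of the latent for an image of the given size."""
    if height % downsample or width % downsample:
        raise ValueError("image size must be divisible by the downsampling factor")
    return (height // downsample, width // downsample, latent_channels)

h, w, c = latent_shape(512, 512)
image_values = 512 * 512 * 3
latent_values = h * w * c
print((h, w, c), image_values // latent_values)  # (64, 64, 8) 24
```

The diffusion UNet therefore operates on 24× fewer values than the RGB image it ultimately produces.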

VAE reconstruction. Our VAE decoders have better visual quality than SD (Stable Diffusion).

Decoder      #Params (M)   PSNR↑   SSIM↑   LPIPS↓
SD           49.5          26.7    0.76    0.037
Ours         39.3          30.0    0.83    0.032
Ours-Lite    9.8           30.2    0.84    0.032

Quality evaluation of VAE decoders. Our lite decoder is much smaller than SD, with better quality metrics, including peak signal-to-noise ratio (PSNR), structural similarity index measure (SSIM), and Learned Perceptual Image Patch Similarity (LPIPS).

One-step sampling

In addition to optimizing the model architecture, we adopt a DiffusionGAN hybrid to achieve one-step sampling. Training DiffusionGAN hybrid models for text-to-image generation encounters several intricacies. Notably, the discriminator, a classifier distinguishing real data and generated data, must make judgments based on both texture and semantics. Moreover, the cost of training text-to-image models can be extremely high, particularly in the case of GAN-based models, where the discriminator introduces additional parameters. Purely GAN-based text-to-image models (e.g., StyleGAN-T, GigaGAN) confront similar complexities, resulting in highly intricate and expensive training.

To overcome these challenges, we use a pre-trained diffusion UNet to initialize the generator and discriminator. This design enables seamless initialization with the pre-trained diffusion model. We postulate that the internal features within the diffusion model contain rich information of the intricate interplay between textual and visual data. This initialization strategy significantly streamlines the training.

The figure below illustrates the training procedure. After initialization, a noisy image is sent to the generator for one-step diffusion. The result is evaluated against ground truth with a reconstruction loss, similar to diffusion model training. We then add noise to the output and send it to the discriminator, whose result is evaluated with a GAN loss, effectively adopting the GAN to model a denoising step. By using pre-trained weights to initialize the generator and the discriminator, the training becomes a fine-tuning process, which converges in less than 10K iterations.

Illustration of DiffusionGAN fine-tuning.

Results

Below we show example images generated by our MobileDiffusion with DiffusionGAN one-step sampling. With such a compact model (520M parameters in total), MobileDiffusion can generate high-quality diverse images for various domains.

Images generated by our MobileDiffusion

We measured the performance of our MobileDiffusion on both iOS and Android devices, using different runtime optimizers. The latency numbers are reported below. We see that MobileDiffusion is very efficient and can run within half a second to generate a 512x512 image. This lightning speed potentially enables many interesting use cases on mobile devices.

Latency measurements (s) on mobile devices.

Conclusion

With superior efficiency in terms of latency and size, MobileDiffusion has the potential to be a very friendly option for mobile deployments given its capability to enable a rapid image generation experience while typing text prompts. We will ensure that any application of this technology is in line with Google’s responsible AI practices.


Acknowledgments

We would like to thank our collaborators and contributors who helped bring MobileDiffusion on-device: Zhisheng Xiao, Yanwu Xu, Jiuqiang Tang, Haolin Jia, Lutz Justen, Daniel Fenner, Ronald Wotzlaw, Jianing Wei, Raman Sarokin, Juhyun Lee, Andrei Kulik, Chuo-Ling Chang, and Matthias Grundmann.

Source: Google AI Blog


Mixed-input matrix multiplication performance optimizations

AI-driven technologies are weaving themselves into the fabric of our daily routines, with the potential to enhance our access to knowledge and boost our overall productivity. The backbone of these applications lies in large language models (LLMs). LLMs are memory-intensive and typically require specialized hardware accelerators to efficiently deliver tens of exaflops of computing power. This blog post shows how we can start addressing the computational challenges by utilizing memory more effectively.

The bulk of an LLM’s memory and compute are consumed by weights in matrix multiplication operations. Using narrower data types reduces memory consumption. For example, storing weights in the 8-bit integer (i.e., U8 or S8) data type reduces the memory footprint by 4× relative to single-precision (F32) and 2× relative to half-precision (F16) or bfloat16 (BF16). Furthermore, previous work has shown that running LLM matrix multiplications with weights in S8 and input in F16 (preserving the higher precision of the user input) is an effective method for increasing efficiency with acceptable trade-offs in accuracy. This technique is known as weight-only quantization and requires an efficient implementation of matrix multiplication with mixed inputs, e.g., half-precision input multiplied with 8-bit integer weights. Hardware accelerators, including GPUs, support a fixed set of data types, and thus, mixed-input matrix multiplication requires software transformations to map to the hardware operations.

To that end, in this blog we focus on mapping mixed-input matrix multiplication onto the NVIDIA Ampere architecture. We present software techniques addressing data type conversion and layout conformance to map mixed-input matrix multiplication efficiently onto hardware-supported data types and layouts. Our results show that the overhead of additional work in software is minimal and enables performance close to the peak hardware capabilities. The software techniques described here are released in the open-source NVIDIA/CUTLASS repository.

Memory footprint for a 175B-parameter LLM with various data type formats.
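The footprint ratios quoted above follow directly from the element sizes. The sketch below counts weight storage only for a 175B-parameter model (activations, caches, and any quantization scales are ignored); the helper name is illustrative.

```python
BYTES_PER_ELEMENT = {"F32": 4, "F16": 2, "BF16": 2, "S8": 1, "U8": 1}

def weight_footprint_gb(num_params, dtype):
    """Weight storage in gigabytes (10^9 bytes) for the given data type."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 1e9

NUM_PARAMS = 175_000_000_000
for dtype in ("F32", "F16", "BF16", "S8"):
    print(f"{dtype}: {weight_footprint_gb(NUM_PARAMS, dtype):.0f} GB")
```

This gives 700 GB for F32, 350 GB for F16 or BF16, and 175 GB for 8-bit weights, matching the 4× and 2× reductions described in the text.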

The matrix-multiply-accumulate operation

Modern AI hardware accelerators such as Google’s TPU and NVIDIA’s GPU multiply matrices natively in the hardware by targeting Tensor Cores, which are specialized processing elements to accelerate matrix operations, particularly for AI workloads. In this blog, we focus on NVIDIA Ampere Tensor Cores, which provide the matrix-multiply-accumulate (mma) operation. For the rest of the blog, mma refers to the operation on Ampere Tensor Cores. The supported data types, shapes, and data layout of the two input matrices (called operands) for the mma operation are fixed in hardware. This means that matrix multiplications with various data types and larger shapes are implemented in the software by tiling the problem onto hardware-supported data types, shapes, and layouts.

The Tensor Core mma operation is defined by specifying two input matrices (e.g., A & B, shown below) to produce a result matrix, C. The mma operation natively supports mixed-precision. Mixed-precision Tensor Cores allow mixing input (A and B) data type with the result (C) data type. In contrast, mixed-input matrix multiplication involves mixing the input data types, and it is not supported by the hardware, so it needs to be implemented in the software.

Tensor Core operation of M-by-N-by-K on input matrix A of M-by-K and matrix B of K-by-N produces output matrix C of M-by-N.
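The idea of tiling a larger matrix multiplication onto a fixed-shape multiply-accumulate can be sketched in plain Python. Here `mma_tile` is a stand-in for the hardware operation, and the 2×2×2 tile shape is chosen only for readability; it is not a shape supported by Ampere Tensor Cores.

```python
def mma_tile(a_tile, b_tile, c_tile):
    """C += A @ B on one fixed-size tile, standing in for a hardware mma."""
    for i in range(len(a_tile)):
        for j in range(len(b_tile[0])):
            for k in range(len(b_tile)):
                c_tile[i][j] += a_tile[i][k] * b_tile[k][j]

def tiled_matmul(a, b, tile_m=2, tile_n=2, tile_k=2):
    """Multiply an M-by-K matrix by a K-by-N matrix using fixed-shape tiles."""
    m, k, n = len(a), len(b), len(b[0])
    if m % tile_m or n % tile_n or k % tile_k:
        raise ValueError("problem shape must be a multiple of the tile shape")
    c = [[0.0] * n for _ in range(m)]
    for i0 in range(0, m, tile_m):
        for j0 in range(0, n, tile_n):
            # One accumulator tile is reused across the whole K dimension.
            c_tile = [[0.0] * tile_n for _ in range(tile_m)]
            for k0 in range(0, k, tile_k):
                a_tile = [row[k0:k0 + tile_k] for row in a[i0:i0 + tile_m]]
                b_tile = [row[j0:j0 + tile_n] for row in b[k0:k0 + tile_k]]
                mma_tile(a_tile, b_tile, c_tile)
            for i in range(tile_m):
                c[i0 + i][j0:j0 + tile_n] = c_tile[i]
    return c

a = [[1, 2, 3, 4], [5, 6, 7, 8]]      # 2-by-4
b = [[1, 0], [0, 1], [1, 0], [0, 1]]  # 4-by-2
print(tiled_matmul(a, b))  # [[4.0, 6.0], [12.0, 14.0]]
```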

Challenges of mixed-input matrix multiplication

To simplify the discussion, we restrict to a specific example of mixed-input matrix multiplication: F16 for user input and U8 for the model weights (written as F16 * U8). The techniques described here work for various combinations of mixed-input data types.

A GPU programmer can access a hierarchy of memory, including global memory, shared memory, and registers, which are arranged in order of decreasing capacity but increasing speed. NVIDIA Ampere Tensor Core mma operations consume input matrices from registers. Furthermore, input and output matrices are required to conform to a layout of data within a group of 32 threads known as a warp. The supported data type and layout within a warp are fixed for an mma operation, so to implement mixed-input multiplication efficiently, it is necessary to solve the challenges of data type conversion and layout conformance in software.


Data type conversion

The mma operation requires two input matrices with the same data type. Thus, mixed-input matrix multiplication, where one of the operands is stored in U8 in global memory and the other in F16, requires a data type conversion from U8 to F16. The conversion will bring two operands to F16, mapping the mixed-input matrix multiplication to hardware-supported mixed-precision Tensor Cores. Given the large number of weights, there are a large number of such operations, and our techniques show how to reduce their latency and improve performance.


Layout conformance

The mma operation also requires the layout of two input matrices, within the registers of a warp, to be conformant with the hardware specification. The layout for the input matrix B of U8 data type in mixed-input matrix multiplication (F16 * U8) needs to conform with the converted F16 data type. This is called layout conformance and needs to be achieved in the software.

The figure below shows an mma operation consuming matrix A and matrix B from registers to produce matrix C in registers, distributed across one warp. The thread T0 is highlighted and zoomed in to show the weight matrix B goes through data type conversion and needs a layout conformance to be able to map to the hardware-supported Tensor Core operation.

The mapping of mixed-input (F32 = F16 * U8) operation in software to natively supported warp-level Tensor Cores in hardware (F32 = F16 * F16). (Original figure source Developing CUDA kernels to push Tensor Cores to the Absolute Limit on NVIDIA A100.)

Software strategies addressing challenges

A typical data type conversion involves a sequence of operations on 32-bit registers, shown below. Each rectangular block represents a register and the adjoining text describes the operations. The entire sequence shows the conversion from 4xU8 to 2x(2xF16). The sequence involves roughly 10 operations.

NumericArrayConvertor from 4xU8 to 2x(2xF16) in 32-bit registers.

There are many ways of achieving layout conformance. Two of the existing solutions are:

  1. Narrower bitwidth shared memory loads: In this approach, threads issue narrow bitwidth memory loads moving the U8 data from shared memory to registers. This results in two 32-bit registers, with each register containing 2xF16 values (shown above for the matrix B’s thread T0). The narrower shared memory load achieves layout conformance directly into registers without needing any shuffles; however, it does not utilize the full shared memory bandwidth.
  2. Pre-processing in global memory: An alternative strategy involves rearranging the data within the global memory (one level above the shared memory in memory hierarchy), allowing wider shared memory loads. This approach maximizes the shared memory bandwidth utilization and ensures that the data is loaded in a conformant layout directly in the registers. Although the rearrangement process can be executed offline prior to the LLM deployment, ensuring no impact on the application performance, it introduces an additional, non-trivial hardware-specific pre-processing step that requires an extra program to rearrange the data. NVIDIA/FasterTransformer adopts this method to effectively address layout conformance challenges.

Optimized software strategies

To further optimize and reduce the overhead of data type conversion and layout conformance, we have implemented FastNumericArrayConvertor and FragmentShuffler, respectively.

FastNumericArrayConvertor operates on 4xU8 in 32-bit registers without unpacking individual 1xU8 values. Furthermore, it uses less expensive arithmetic operations which reduces the number of instructions and increases the speed of the conversion.

The conversion sequence for U8-to-F16 is shown below. The operations use packed 32b registers, avoiding explicit unpacking and packing. FastNumericArrayConvertor uses the permute-byte operation to rearrange bytes of 4xU8 into two registers. Additionally, FastNumericArrayConvertor does not use expensive integer to floating-point conversion instructions and employs vectorized operations to obtain the packed results in two 32-bit registers containing 2x(2xF16) values. The FastNumericArrayConvertor for U8-to-F16 uses approximately six operations, a 1.6× reduction relative to the approach shown above.

FastNumericArrayConvertor utilizes permute bytes and packed arithmetic, reducing the number of instructions in the data type conversion.
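One well-known way to convert U8 to F16 without an integer-to-float instruction is a "magic number" trick: placing a byte XX under the high byte 0x64 yields the half-precision bit pattern 0x64XX, whose value is 1024 + XX, so a single subtraction of 1024 recovers XX. The sketch below emulates this on Python integers standing in for 32-bit registers. Whether it matches the exact instruction sequence of FastNumericArrayConvertor is an assumption, and the helper names are illustrative.

```python
import struct

HALF_1024_HI_BYTE = 0x64  # 0x6400 is the bit pattern of 1024.0 in IEEE half precision

def half_bits_to_float(bits):
    """Interpret a 16-bit pattern as an IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]

def permute_to_half_pairs(packed_u8x4):
    """Emulate a byte permute: spread 4 packed U8 values over two 32-bit registers.

    Each 16-bit lane becomes 0x64XX, where XX is one of the input bytes.
    """
    lanes = [(HALF_1024_HI_BYTE << 8) | ((packed_u8x4 >> (8 * i)) & 0xFF)
             for i in range(4)]
    reg0 = lanes[0] | (lanes[1] << 16)
    reg1 = lanes[2] | (lanes[3] << 16)
    return reg0, reg1

def fast_u8x4_to_f16x4(packed_u8x4):
    """Convert 4 packed U8 values to 4 half-precision values, lowest byte first."""
    out = []
    for reg in permute_to_half_pairs(packed_u8x4):
        for shift in (0, 16):
            biased = half_bits_to_float((reg >> shift) & 0xFFFF)  # equals 1024 + XX
            out.append(biased - 1024.0)  # stands in for a packed half2 subtract
    return out

print(fast_u8x4_to_f16x4(0xFF802A00))  # [0.0, 42.0, 128.0, 255.0]
```

On a GPU the two subtractions per register would be a single packed instruction, which is where the instruction-count saving comes from.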

FragmentShuffler handles the layout conformance by shuffling data in a way that allows the use of wider bitwidth load operation, increasing shared memory bandwidth utilization and reducing the total number of operations.

NVIDIA Ampere architecture provides a load matrix instruction (ldmatrix). The ldmatrix is a warp-level operation, where 32 threads of a warp move the data from shared memory to registers in the shape and layout that mma matrix A and B consume. The use of ldmatrix reduces the number of load instructions and increases the memory bandwidth utilization. Since the ldmatrix instruction moves U8 data to registers, the layout after the load conforms with U8*U8 mma operation, and not with F16*F16 mma operation. We implemented FragmentShuffler to rearrange the data within registers using shuffle (shfl.sync) operations to achieve the layout conformance.

The most significant contribution of this work is to achieve layout conformance through register shuffles, avoiding offline pre-processing in global memory or narrower bitwidth shared memory loads. Furthermore, we provide implementations for FastNumericArrayConvertor covering data type conversion from U8-to-F16, S8-to-F16, U8-to-BF16, and S8-to-BF16.


Performance results

We measured the performance of eight mixed-input variants of our method (shown below in blue and red; varying the data types of matrix A and B) and two mixed-precision data types (shown in green) on an NVIDIA A100 SXM chip. The performance results are shown in FLOPS (higher is better). Notably, the first eight matrix multiplications require additional operations relative to the last two, because the mixed-precision variants directly target hardware-accelerated Tensor Core operations and do not need data type conversion and layout conformance. Even so, our approach demonstrates mixed-input matrix multiplication performance only slightly below or on par with mixed-precision.

Mixed-input matrix multiplication performance on NVIDIA A100 40GB SXM4 chip for a compute-bound matrix problem shape m=3456, n=4096, k=2048.

Acknowledgements

We would like to mention several folks who have contributed through technical brainstorming and improving the blog post, including Quentin Colombet, Jacques Pienaar, Allie Culp, Calin Cascaval, Ashish Gondimalla, Matt Walsh, Marek Kolodziej, and Aman Bhatia. We would like to thank our NVIDIA partners Rawn Henry, Pradeep Ramani, Vijay Thakkar, Haicheng Wu, Andrew Kerr, Matthew Nicely, and Vartika Singh.

Source: Google AI Blog


Exphormer: Scaling transformers for graph-structured data

Graphs, in which objects and their relations are represented as nodes (or vertices) and edges (or links) between pairs of nodes, are ubiquitous in computing and machine learning (ML). For example, social networks, road networks, and molecular structure and interactions are all domains in which underlying datasets have a natural graph structure. ML can be used to learn the properties of nodes, edges, or entire graphs.

A common approach to learning on graphs is graph neural networks (GNNs), which operate on graph data by applying an optimizable transformation on node, edge, and global attributes. The most typical class of GNNs operates via a message-passing framework, whereby each layer aggregates the representation of a node with those of its immediate neighbors.

Recently, graph transformer models have emerged as a popular alternative to message-passing GNNs. These models build on the success of Transformer architectures in natural language processing (NLP), adapting them to graph-structured data. The attention mechanism in graph transformers can be modeled by an interaction graph, in which edges represent pairs of nodes that attend to each other. Unlike message passing architectures, graph transformers have an interaction graph that is separate from the input graph. The typical interaction graph is a complete graph, which signifies a full attention mechanism that models direct interactions between all pairs of nodes. However, this creates quadratic computational and memory bottlenecks that limit the applicability of graph transformers to datasets on small graphs with at most a few thousand nodes. Making graph transformers scalable has been considered one of the most important research directions in the field (see the first open problem here).

A natural remedy is to use a sparse interaction graph with fewer edges. Many sparse and efficient transformers have been proposed to eliminate the quadratic bottleneck for sequences. However, they do not generally extend to graphs in a principled manner.

In “Exphormer: Sparse Transformers for Graphs”, presented at ICML 2023, we address the scalability challenge by introducing a sparse attention framework for transformers that is designed specifically for graph data. The Exphormer framework makes use of expander graphs, a powerful tool from spectral graph theory, and is able to achieve strong empirical results on a wide variety of datasets. Our implementation of Exphormer is now available on GitHub.


Expander graphs

A key idea at the heart of Exphormer is the use of expander graphs, which are sparse yet well-connected graphs that have some useful properties: 1) the matrix representation of the graph has linear-algebraic properties similar to those of a complete graph, and 2) they exhibit rapid mixing of random walks, i.e., a small number of steps in a random walk from any starting node is enough to ensure convergence to a “stable” distribution on the nodes of the graph. Expanders have found applications in diverse areas, such as algorithms, pseudorandomness, complexity theory, and error-correcting codes.

A common class of expander graphs is d-regular expanders, in which there are d edges from every node (i.e., every node has degree d). The quality of an expander graph is measured by its spectral gap, an algebraic property of its adjacency matrix (a matrix representation of the graph in which rows and columns are indexed by nodes and entries indicate whether pairs of nodes are connected by an edge). Those that maximize the spectral gap are known as Ramanujan graphs: they achieve a gap of d - 2*√(d-1), which is essentially the best possible among d-regular graphs. A number of deterministic and randomized constructions of Ramanujan graphs have been proposed over the years for various values of d. We use a randomized expander construction of Friedman, which produces near-Ramanujan graphs.
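To make the spectral gap concrete, here is a minimal sketch that samples a random d-regular multigraph and compares its gap to the Ramanujan value d - 2*√(d-1). It assumes the permutation model (a union of d/2 random permutations) as the sampling procedure and takes the gap to be the difference between the two largest adjacency eigenvalues; the function names are illustrative and are not taken from the Exphormer code.

```python
import numpy as np

def random_regular_multigraph(n, d, seed=0):
    """Adjacency matrix of a d-regular multigraph built from d/2 random permutations.

    Assumption: the permutation model. Self-loops and repeated edges may occur.
    """
    if d % 2 != 0:
        raise ValueError("this sketch needs an even degree d")
    rng = np.random.default_rng(seed)
    nodes = np.arange(n)
    adj = np.zeros((n, n))
    for _ in range(d // 2):
        perm = rng.permutation(n)
        # Each permutation contributes the undirected edges {i, perm[i]}.
        np.add.at(adj, (nodes, perm), 1)
        np.add.at(adj, (perm, nodes), 1)
    return adj

def spectral_gap(adj):
    """Difference between the largest and second-largest adjacency eigenvalues."""
    eigenvalues = np.linalg.eigvalsh(adj)  # returned in ascending order
    return eigenvalues[-1] - eigenvalues[-2]

n, d = 200, 4
adj = random_regular_multigraph(n, d)
gap = spectral_gap(adj)
ramanujan_gap = d - 2 * np.sqrt(d - 1)
print(f"spectral gap: {gap:.3f}, Ramanujan value: {ramanujan_gap:.3f}")
```

Every row of the adjacency matrix sums to d, so the largest eigenvalue is exactly d; how close the measured gap comes to the Ramanujan value depends on the random draw.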

Expander graphs are at the heart of Exphormer. A good expander is sparse yet exhibits rapid mixing of random walks, making its global connectivity suitable for an interaction graph in a graph transformer model.

Exphormer replaces the dense, fully-connected interaction graph of a standard Transformer with edges of a sparse d-regular expander graph. Intuitively, the spectral approximation and mixing properties of an expander graph allow distant nodes to communicate with each other after one stacks multiple attention layers in a graph transformer architecture, even though the nodes may not attend to each other directly. Furthermore, by ensuring that d is constant (independent of the number of nodes), we obtain a linear number of edges in the resulting interaction graph.
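The idea of attention restricted to an interaction graph can be sketched in a few lines. The example below is an illustrative single-head version, not the Exphormer implementation: it assumes plain scaled dot-product attention without edge features, and it builds a dense boolean mask purely for readability (a scalable implementation would work on the edge list directly). The name `graph_attention` is hypothetical.

```python
import numpy as np

def graph_attention(q, k, v, edges):
    """Single-head attention where node i attends only to j for (i, j) in edges.

    Assumption: every node has at least one outgoing edge (e.g., a self-loop),
    otherwise its softmax row would be undefined.
    """
    n, dim = q.shape
    mask = np.zeros((n, n), dtype=bool)
    src, dst = zip(*edges)
    mask[list(src), list(dst)] = True
    # Pairs outside the interaction graph get a score of -inf, i.e., weight 0.
    scores = np.where(mask, q @ k.T / np.sqrt(dim), -np.inf)
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, 5, 4))  # 5 nodes, feature dimension 4
ring = [(i, (i + 1) % 5) for i in range(5)] + [(i, i) for i in range(5)]
out = graph_attention(q, k, v, ring)
print(out.shape)
```

With the complete graph as the edge list, this reduces to ordinary dense attention; with a sparse edge list, each layer only mixes information along edges, which is why stacking layers matters for distant nodes.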


Exphormer: Constructing a sparse interaction graph

Exphormer combines expander edges with the input graph and virtual nodes. More specifically, the sparse attention mechanism of Exphormer builds an interaction graph consisting of three types of edges:

  • Edges from the input graph (local attention)
  • Edges from a constant-degree expander graph (expander attention)
  • Edges from every node to a small set of virtual nodes (global attention)
Exphormer builds an interaction graph by combining three types of edges. The resulting graph has good connectivity properties and retains the inductive bias of the input dataset graph while still remaining sparse.

Each component serves a specific purpose: the edges from the input graph retain the inductive bias from the input graph structure (which typically gets lost in a fully-connected attention module). Meanwhile, expander edges allow good global connectivity and random walk mixing properties (which spectrally approximate the complete graph with far fewer edges). Finally, virtual nodes serve as global “memory sinks” that can directly communicate with every node. While this results in additional edges from each virtual node equal to the number of nodes in the input graph, the resulting graph is still sparse. The degree of the expander graph and the number of virtual nodes are hyperparameters to tune for improving the quality metrics.

Furthermore, since we use an expander graph of constant degree and a small constant number of virtual nodes for the global attention, the resulting sparse attention mechanism is linear in the size of the original input graph, i.e., it models a number of direct interactions on the order of the total number of nodes and edges.
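The three edge types and the linear edge count can be illustrated with a small sketch. The function name, the convention of numbering virtual nodes after the input nodes, and the toy edge lists are all assumptions made for this example; the stand-in "expander" edges below are hand-written and not a real expander.

```python
def build_interaction_edges(num_nodes, input_edges, expander_edges, num_virtual):
    """Return the set of directed attention edges (source attends to target)."""
    edges = set()
    # Local and expander attention, symmetrized so both endpoints attend to each other.
    for u, v in list(input_edges) + list(expander_edges):
        edges.add((u, v))
        edges.add((v, u))
    # Global attention: virtual nodes are numbered num_nodes, num_nodes + 1, ...
    for virtual in range(num_nodes, num_nodes + num_virtual):
        for node in range(num_nodes):
            edges.add((node, virtual))
            edges.add((virtual, node))
    return edges

# Toy input: a path graph on 6 nodes, three stand-in expander edges, one virtual node.
path_edges = [(i, i + 1) for i in range(5)]
stand_in_expander_edges = [(0, 3), (1, 4), (2, 5)]
interaction = build_interaction_edges(6, path_edges, stand_in_expander_edges, 1)
print(len(interaction))
```

The number of directed edges is at most 2 * (|E| + n*d/2 + n * num_virtual), which is linear in the number of nodes and edges when d and the number of virtual nodes are constants.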

We additionally show that Exphormer is as expressive as the dense transformer and obeys universal approximation properties. In particular, when the sparse attention graph of Exphormer is augmented with self loops (edges connecting a node to itself), it can universally approximate continuous functions [1, 2].


Relation to sparse Transformers for sequences

It is interesting to compare Exphormer to sparse attention methods for sequences. Perhaps the architecture most conceptually similar to our approach is BigBird, which builds an interaction graph by combining different components. BigBird also uses virtual nodes, but, unlike Exphormer, it uses window attention and random attention from an Erdős-Rényi random graph model for the remaining components.

Window attention in BigBird looks at the tokens surrounding a token in a sequence — the local neighborhood attention in Exphormer can be viewed as a generalization of window attention to graphs.

The Erdős-Rényi graph on n nodes, G(n, p), which connects every pair of nodes independently with probability p, also functions as an expander graph for suitably high p. However, a superlinear number of edges (Ω(n log n)) is needed to ensure that an Erdős-Rényi graph is connected, let alone a good expander. On the other hand, the expanders used in Exphormer have only a linear number of edges.
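A quick back-of-the-envelope comparison shows the gap between the two edge counts. The sketch below uses the standard connectivity threshold p = ln(n)/n for G(n, p) and an assumed expander degree of d = 4 chosen only for illustration.

```python
import math

def erdos_renyi_expected_edges(n):
    """Expected number of edges of G(n, p) at the connectivity threshold p = ln(n) / n."""
    p = math.log(n) / n
    return p * n * (n - 1) / 2

def regular_graph_edges(n, d):
    """Exact number of edges of a d-regular graph on n nodes (n * d must be even)."""
    return n * d // 2

for n in (1_000, 100_000, 10_000_000):
    ratio = erdos_renyi_expected_edges(n) / regular_graph_edges(n, d=4)
    print(f"n={n}: Erdős-Rényi / 4-regular edge ratio = {ratio:.2f}")
```

The ratio grows like ln(n), so the savings from a constant-degree expander increase with graph size.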


Experimental results

Earlier works have shown the use of full graph Transformer-based models on datasets with graphs of size up to 5,000 nodes. To evaluate the performance of Exphormer, we build upon the celebrated GraphGPS framework [3], which combines both message passing and graph transformers and achieves state-of-the-art performance on a number of datasets. We show that replacing dense attention with Exphormer for the graph attention component in the GraphGPS framework allows one to achieve models with comparable or better performance, often with fewer trainable parameters.

Furthermore, Exphormer notably allows graph transformer architectures to scale well beyond the usual graph size limits mentioned above. Exphormer can scale up to datasets of 10,000+ node graphs, such as the Coauthor dataset, and even beyond to larger graphs such as the well-known ogbn-arxiv dataset, a citation network, which consists of 170K nodes and 1.1 million edges.

Results comparing Exphormer to standard GraphGPS on the five Long Range Graph Benchmark datasets. We note that Exphormer achieved state-of-the-art results on four of the five datasets (PascalVOC-SP, COCO-SP, Peptides-Struct, PCQM-Contact) at the time of the paper’s publication.

Finally, we observe that Exphormer, which creates an overlay graph of small diameter via expanders, exhibits the ability to effectively learn long-range dependencies. The Long Range Graph Benchmark is a suite of five graph learning datasets designed to measure the ability of models to capture long-range interactions. Results show that Exphormer-based models outperform standard GraphGPS models (which were previously state-of-the-art on four out of five datasets at the time of publication).


Conclusion

Graph transformers have emerged as an important architecture for ML that adapts the highly successful sequence-based transformers used in NLP to graph-structured data. Scalability has, however, proven to be a major challenge in enabling the use of graph transformers on datasets with large graphs. In this post, we have presented Exphormer, a sparse attention framework that uses expander graphs to improve scalability of graph transformers. Exphormer is shown to have important theoretical properties and exhibit strong empirical performance, particularly on datasets where it is crucial to learn long range dependencies. For more information, we point the reader to a short presentation video from ICML 2023.


Acknowledgements

We thank our research collaborators Hamed Shirzad and Danica J. Sutherland from The University of British Columbia as well as Ali Kemal Sinop from Google Research. Special thanks to Tom Small for creating the animation used in this post.

Source: Google AI Blog


Introducing ASPIRE for selective prediction in LLMs

In the fast-evolving landscape of artificial intelligence, large language models (LLMs) have revolutionized the way we interact with machines, pushing the boundaries of natural language understanding and generation to unprecedented heights. Yet, the leap into high-stakes decision-making applications remains a chasm too wide, primarily due to the inherent uncertainty of model predictions. Traditional LLMs generate responses autoregressively, yet they lack an intrinsic mechanism to assign a confidence score to these responses. Although one can derive a confidence score by summing up the log-probabilities of individual tokens in the sequence, traditional approaches typically fall short in reliably distinguishing between correct and incorrect answers. But what if LLMs could gauge their own confidence and only make predictions when they're sure?

Selective prediction aims to do this by enabling LLMs to output an answer along with a selection score, which indicates the probability that the answer is correct. With selective prediction, one can better understand the reliability of LLMs deployed in a variety of applications. Prior research, such as semantic uncertainty and self-evaluation, has attempted to enable selective prediction in LLMs. A typical approach is to use heuristic prompts like “Is the proposed answer True or False?” to trigger self-evaluation in LLMs. However, this approach may not work well on challenging question answering (QA) tasks.

The OPT-2.7B model incorrectly answers a question from the TriviaQA dataset: “Which vitamin helps regulate blood clotting?” with “Vitamin C”. Without selective prediction, LLMs may output the wrong answer which, in this case, could lead users to take the wrong vitamin. With selective prediction, LLMs will output an answer along with a selection score. If the selection score is low (0.1), LLMs will further output “I don’t know!” to warn users not to trust the answer, or to verify it using other sources.

In "Adaptation with Self-Evaluation to Improve Selective Prediction in LLMs", presented at Findings of EMNLP 2023, we introduce ASPIRE — a novel framework meticulously designed to enhance the selective prediction capabilities of LLMs. ASPIRE fine-tunes LLMs on QA tasks via parameter-efficient fine-tuning, and trains them to evaluate whether their generated answers are correct. ASPIRE allows LLMs to output an answer along with a confidence score for that answer. Our experimental results demonstrate that ASPIRE significantly outperforms state-of-the-art selective prediction methods on a variety of QA datasets, such as the CoQA benchmark.


The mechanics of ASPIRE

Imagine teaching an LLM to not only answer questions but also evaluate those answers — akin to a student verifying their answers in the back of the textbook. That's the essence of ASPIRE, which involves three stages: (1) task-specific tuning, (2) answer sampling, and (3) self-evaluation learning.

Task-specific tuning: ASPIRE performs task-specific tuning to train adaptable parameters (θp) while freezing the LLM. Given a training dataset for a generative task, it fine-tunes the pre-trained LLM to improve its prediction performance. Towards this end, parameter-efficient tuning techniques (e.g., soft prompt tuning and LoRA) might be employed to adapt the pre-trained LLM on the task, given their effectiveness in obtaining strong generalization with small amounts of target task data. Specifically, the LLM parameters (θ) are frozen and adaptable parameters (θp) are added for fine-tuning. Only θp are updated to minimize the standard LLM training loss (e.g., cross-entropy). Such fine-tuning can improve selective prediction performance because it not only improves the prediction accuracy, but also enhances the likelihood of correct output sequences.

Answer sampling: After task-specific tuning, ASPIRE uses the LLM with the learned θp to generate different answers for each training question and create a dataset for self-evaluation learning. We aim to generate output sequences that have a high likelihood. We use beam search as the decoding algorithm to generate high-likelihood output sequences and the Rouge-L metric to determine if the generated output sequence is correct.
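As an illustration of how Rouge-L can label a sampled answer as correct or incorrect, the sketch below computes a word-level Rouge-L F-measure from the longest common subsequence (LCS). The whitespace tokenization, lowercasing, and the 0.7 threshold are assumptions for this example and are not taken from the post.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l_f1(candidate, reference):
    """Rouge-L F-measure (equal weight on precision and recall) over words."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    if not cand or not ref:
        return 0.0
    lcs = lcs_length(cand, ref)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)

def is_correct(candidate, reference, threshold=0.7):
    """Label an answer as correct when its Rouge-L score clears the (assumed) threshold."""
    return rouge_l_f1(candidate, reference) >= threshold

print(is_correct("Vitamin K", "vitamin k"), is_correct("Vitamin C", "vitamin k"))
```

Labels produced this way can then serve as the correct/incorrect targets for the self-evaluation stage.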

Self-evaluation learning: After sampling high-likelihood outputs for each query, ASPIRE adds adaptable parameters (θs) and only fine-tunes θs for learning self-evaluation. Since the output sequence generation only depends on θ and θp, freezing θ and the learned θp can avoid changing the prediction behaviors of the LLM when learning self-evaluation. We optimize θs such that the adapted LLM can distinguish between correct and incorrect answers on its own.

The three stages of the ASPIRE framework.

In the proposed framework, θp and θs can be trained using any parameter-efficient tuning approach. In this work, we use soft prompt tuning, a simple yet effective mechanism for learning “soft prompts” to condition frozen language models to perform specific downstream tasks more effectively than traditional discrete text prompts. The driving force behind this approach lies in the recognition that if we can develop prompts that effectively stimulate self-evaluation, it should be possible to discover these prompts through soft prompt tuning in conjunction with targeted training objectives.

Implementation of the ASPIRE framework via soft prompt tuning. We first generate the answer to the question with the first soft prompt and then compute the learned self-evaluation score with the second soft prompt.

After training θp and θs, we obtain the prediction for the query via beam search decoding. We then define a selection score that combines the likelihood of the generated answer with the learned self-evaluation score (i.e., the likelihood of the prediction being correct for the query) to make selective predictions.
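The post does not spell out how the two quantities are combined, so the following is only one plausible sketch: a weighted sum, in log space, of the length-normalized answer likelihood and the self-evaluation probability. The weight `alpha`, the length normalization, and the abstention threshold are all assumptions made for illustration.

```python
import math

def sequence_log_likelihood(token_probs, length_normalize=True):
    """Sum of token log-probabilities, optionally averaged over the sequence length."""
    total = sum(math.log(p) for p in token_probs)
    return total / len(token_probs) if length_normalize else total

def selection_score(token_probs, self_eval_prob, alpha=0.5):
    """Assumed combination: (1 - alpha) * log-likelihood + alpha * log P(correct)."""
    return (1 - alpha) * sequence_log_likelihood(token_probs) + alpha * math.log(self_eval_prob)

def answer_or_abstain(answer, score, threshold):
    """Return the answer only when the selection score clears the threshold."""
    return answer if score >= threshold else "I don't know!"

confident = selection_score([0.9, 0.8], self_eval_prob=0.95)
doubtful = selection_score([0.9, 0.8], self_eval_prob=0.10)
print(answer_or_abstain("Vitamin K", confident, threshold=-0.5))
print(answer_or_abstain("Vitamin C", doubtful, threshold=-0.5))
```

With identical answer likelihoods, the learned self-evaluation score is what separates the answer that is returned from the one that is withheld.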


Results

To demonstrate ASPIRE’s efficacy, we evaluate it across three question-answering datasets — CoQA, TriviaQA, and SQuAD — using various open pre-trained transformer (OPT) models. By training θp with soft prompt tuning, we observed a substantial hike in the LLMs' accuracy. For example, the OPT-2.7B model adapted with ASPIRE demonstrated improved performance over the larger, pre-trained OPT-30B model using the CoQA and SQuAD datasets. These results suggest that with suitable adaptations, smaller LLMs might have the capability to match or potentially surpass the accuracy of larger models in some scenarios.

When delving into the computation of selection scores with fixed model predictions, ASPIRE received a higher AUROC score (the probability that a randomly chosen correct output sequence has a higher selection score than a randomly chosen incorrect output sequence) than baseline methods across all datasets. For example, on the CoQA benchmark, ASPIRE improves the AUROC from 51.3% to 80.3% compared to the baselines.
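The AUROC definition in parentheses translates directly into code. The sketch below counts, over all (correct, incorrect) pairs, how often the correct output has the higher selection score, counting ties as one half (a common convention that we assume here). It is quadratic in the number of examples and is meant only to mirror the definition.

```python
def auroc(scores, is_correct):
    """Probability that a random correct output outscores a random incorrect one."""
    positives = [s for s, c in zip(scores, is_correct) if c]
    negatives = [s for s, c in zip(scores, is_correct) if not c]
    if not positives or not negatives:
        raise ValueError("need at least one correct and one incorrect output")
    wins = sum((p > q) + 0.5 * (p == q) for p in positives for q in negatives)
    return wins / (len(positives) * len(negatives))

print(auroc([0.9, 0.8, 0.2, 0.1], [True, True, False, False]))
```

A score of 0.5 means the selection score is no better than chance at separating correct from incorrect outputs, which puts the reported CoQA baseline of 51.3% in context.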

An intriguing pattern emerged from the TriviaQA dataset evaluations. While the pre-trained OPT-30B model demonstrated higher baseline accuracy, its performance in selective prediction did not improve significantly when traditional self-evaluation methods — Self-eval and P(True) — were applied. In contrast, the smaller OPT-2.7B model, when enhanced with ASPIRE, outperformed in this aspect. This discrepancy underscores a vital insight: larger LLMs utilizing conventional self-evaluation techniques may not be as effective in selective prediction as smaller, ASPIRE-enhanced models.

Our experimental journey with ASPIRE underscores a pivotal shift in the landscape of LLMs: The capacity of a language model is not the be-all and end-all of its performance. Instead, the effectiveness of models can be drastically improved through strategic adaptations, allowing for more precise, confident predictions even in smaller models. As a result, ASPIRE stands as a testament to the potential of LLMs that can judiciously ascertain their own certainty and decisively outperform larger counterparts in selective prediction tasks.


Conclusion

In conclusion, ASPIRE is not just another framework; it's a vision of a future where LLMs can be trusted partners in decision-making. By honing the selective prediction performance, we're inching closer to realizing the full potential of AI in critical applications.

Our research has opened new doors, and we invite the community to build upon this foundation. We're excited to see how ASPIRE will inspire the next generation of LLMs and beyond. To learn more about our findings, we encourage you to read our paper and join us in this thrilling journey towards creating a more reliable and self-aware AI.


Acknowledgments

We gratefully acknowledge the contributions of Sayna Ebrahimi, Sercan O Arik, Tomas Pfister, and Somesh Jha.

Source: Google AI Blog


AMIE: A research AI system for diagnostic medical reasoning and conversations

The physician-patient conversation is a cornerstone of medicine, in which skilled and intentional communication drives diagnosis, management, empathy and trust. AI systems capable of such diagnostic dialogues could increase availability, accessibility, quality and consistency of care by being useful conversational partners to clinicians and patients alike. But approximating clinicians’ considerable expertise is a significant challenge.

Recent progress in large language models (LLMs) outside the medical domain has shown that they can plan, reason, and use relevant context to hold rich conversations. However, there are many aspects of good diagnostic dialogue that are unique to the medical domain. An effective clinician takes a complete “clinical history” and asks intelligent questions that help to derive a differential diagnosis. They wield considerable skill to foster an effective relationship, provide information clearly, make joint and informed decisions with the patient, respond empathically to their emotions, and support them in the next steps of care. While LLMs can accurately perform tasks such as medical summarization or answering medical questions, there has been little work specifically aimed towards developing these kinds of conversational diagnostic capabilities.

Inspired by this challenge, we developed Articulate Medical Intelligence Explorer (AMIE), a research AI system based on an LLM and optimized for diagnostic reasoning and conversations. We trained and evaluated AMIE along many dimensions that reflect quality in real-world clinical consultations from the perspective of both clinicians and patients. To scale AMIE across a multitude of disease conditions, specialties and scenarios, we developed a novel self-play based simulated diagnostic dialogue environment with automated feedback mechanisms to enrich and accelerate its learning process. We also introduced an inference-time chain-of-reasoning strategy to improve AMIE’s diagnostic accuracy and conversation quality. Finally, we tested AMIE prospectively in real examples of multi-turn dialogue by simulating consultations with trained actors.

AMIE was optimized for diagnostic conversations, asking questions that help to reduce its uncertainty and improve diagnostic accuracy, while also balancing this with other requirements of effective clinical communication, such as empathy, fostering a relationship, and providing information clearly.

Evaluation of conversational diagnostic AI

Besides developing and optimizing AI systems themselves for diagnostic conversations, how to assess such systems is also an open question. Inspired by accepted tools used to measure consultation quality and clinical communication skills in real-world settings, we constructed a pilot evaluation rubric to assess diagnostic conversations along axes pertaining to history-taking, diagnostic accuracy, clinical management, clinical communication skills, relationship fostering and empathy.

We then designed a randomized, double-blind crossover study of text-based consultations with validated patient actors interacting either with board-certified primary care physicians (PCPs) or the AI system optimized for diagnostic dialogue. We set up our consultations in the style of an objective structured clinical examination (OSCE), a practical assessment commonly used in the real world to examine clinicians’ skills and competencies in a standardized and objective way. In a typical OSCE, clinicians might rotate through multiple stations, each simulating a real-life clinical scenario where they perform tasks such as conducting a consultation with a standardized patient actor (trained carefully to emulate a patient with a particular condition). Consultations were performed using a synchronous text-chat tool, mimicking the interface familiar to most consumers using LLMs today.

AMIE is a research AI system based on LLMs for diagnostic reasoning and dialogue.

AMIE: an LLM-based conversational diagnostic research AI system

We trained AMIE on real-world datasets comprising medical reasoning, medical summarization and real-world clinical conversations.

It is feasible to train LLMs using real-world dialogues developed by passively collecting and transcribing in-person clinical visits; however, two substantial challenges limit their effectiveness in training LLMs for medical conversations. First, existing real-world data often fails to capture the vast range of medical conditions and scenarios, hindering scalability and comprehensiveness. Second, the data derived from real-world dialogue transcripts tends to be noisy, containing ambiguous language (including slang, jargon, humor and sarcasm), interruptions, ungrammatical utterances, and implicit references.

To address these limitations, we designed a self-play based simulated learning environment with automated feedback mechanisms for diagnostic medical dialogue in a virtual care setting, enabling us to scale AMIE’s knowledge and capabilities across many medical conditions and contexts. We used this environment to iteratively fine-tune AMIE with an evolving set of simulated dialogues in addition to the static corpus of real-world data described.

This process consisted of two self-play loops: (1) an “inner” self-play loop, where AMIE leveraged in-context critic feedback to refine its behavior on simulated conversations with an AI patient simulator; and (2) an “outer” self-play loop where the set of refined simulated dialogues was incorporated into subsequent fine-tuning iterations. The resulting new version of AMIE could then participate in the inner loop again, creating a virtuous continuous learning cycle.
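The control flow of the two loops can be sketched as follows. This is a structural illustration only: the callables `simulate`, `critique`, `refine` and `fine_tune`, along with the iteration counts, are hypothetical placeholders supplied by the caller, and nothing here reflects AMIE's actual training code.

```python
def self_play_training(model, scenarios, simulate, critique, refine, fine_tune,
                       outer_iterations=2, inner_rounds=2):
    """Nested self-play: refine dialogues in-context, then fine-tune on them."""
    for _ in range(outer_iterations):
        refined_dialogues = []
        for scenario in scenarios:
            # Inner loop: converse with the patient simulator, then improve
            # the dialogue using critic feedback.
            dialogue = simulate(model, scenario)
            for _ in range(inner_rounds):
                feedback = critique(model, dialogue)
                dialogue = refine(model, scenario, dialogue, feedback)
            refined_dialogues.append(dialogue)
        # Outer loop: the refined dialogues feed the next fine-tuning iteration
        # (in the post, alongside the static corpus of real-world data).
        model = fine_tune(model, refined_dialogues)
    return model

# Toy stand-ins: the "model" is just a version number.
final_version = self_play_training(
    model=0,
    scenarios=["cough", "headache"],
    simulate=lambda m, s: f"v{m} dialogue about {s}",
    critique=lambda m, d: "ask about symptom duration",
    refine=lambda m, s, d, f: d + " (+refined)",
    fine_tune=lambda m, dialogues: m + 1,
)
print(final_version)
```

The key point is that the model returned by each outer iteration is the one used in the next round of inner-loop simulations.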

We also employed an inference-time chain-of-reasoning strategy, which enabled AMIE to progressively refine its response conditioned on the current conversation to arrive at an informed and grounded reply.

AMIE uses a novel self-play based simulated dialogue learning environment to improve the quality of diagnostic dialogue across a multitude of disease conditions, specialities and patient contexts.

We tested AMIE’s performance in consultations with simulated patients (played by trained actors) and compared it to consultations performed by 20 real PCPs, using the randomized approach described above. AMIE and PCPs were assessed from the perspectives of both specialist attending physicians and our simulated patients in a randomized, blinded crossover study that included 149 case scenarios from OSCE providers in Canada, the UK and India in a diverse range of specialties and diseases.

Notably, our study was not designed to emulate either traditional in-person OSCE evaluations or the ways clinicians usually use text, email, chat or telemedicine. Instead, our experiment mirrored the most common way consumers interact with LLMs today, a potentially scalable and familiar mechanism for AI systems to engage in remote diagnostic dialogue.

Overview of the randomized study design to perform a virtual remote OSCE with simulated patients via online multi-turn synchronous text chat.

Performance of AMIE

In this setting, we observed that AMIE performed simulated diagnostic conversations at least as well as PCPs when both were evaluated along multiple clinically-meaningful axes of consultation quality. AMIE had greater diagnostic accuracy and superior performance for 28 of 32 axes from the perspective of specialist physicians, and 24 of 26 axes from the perspective of patient actors.

AMIE outperformed PCPs on multiple evaluation axes for diagnostic dialogue in our evaluations.
Specialist-rated top-k diagnostic accuracy. AMIE’s and PCPs’ top-k differential diagnosis (DDx) accuracy are compared across 149 scenarios with respect to the ground truth diagnosis (a) and all diagnoses listed within the accepted differential diagnoses (b). Bootstrapping (n=10,000) confirms all top-k differences between AMIE and PCP DDx accuracy are significant with p < 0.05 after false discovery rate (FDR) correction.
Diagnostic conversation and reasoning qualities as assessed by specialist physicians. On 28 out of 32 axes, AMIE outperformed PCPs while being comparable on the rest.
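For readers unfamiliar with the metric, top-k DDx accuracy is the fraction of cases in which the ground-truth diagnosis appears among the first k entries of the ranked differential. The sketch below uses exact string matching on made-up toy data as a simplification; in the study, matches were judged by specialist physicians.

```python
def top_k_accuracy(ranked_lists, ground_truths, k):
    """Fraction of cases whose ground-truth diagnosis is within the top k of the DDx."""
    hits = sum(truth in ranked[:k] for ranked, truth in zip(ranked_lists, ground_truths))
    return hits / len(ground_truths)

# Toy data for illustration only.
ddx_lists = [
    ["migraine", "tension headache", "sinusitis"],
    ["gastritis", "peptic ulcer", "pancreatitis"],
]
truths = ["tension headache", "appendicitis"]
print(top_k_accuracy(ddx_lists, truths, k=1), top_k_accuracy(ddx_lists, truths, k=3))
```

By construction, accuracy can only stay the same or increase as k grows.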

Limitations

Our research has several limitations and should be interpreted with appropriate caution. Firstly, our evaluation technique likely underestimates the real-world value of human conversations, as the clinicians in our study were limited to an unfamiliar text-chat interface, which permits large-scale LLM–patient interactions but is not representative of usual clinical practice. Secondly, any research of this type must be seen as only a first exploratory step on a long journey. Transitioning from the LLM research prototype that we evaluated in this study to a safe and robust tool that could be used by people and those who provide care for them will require significant additional research. There are many important limitations to be addressed, including experimental performance under real-world constraints and dedicated exploration of such important topics as health equity and fairness, privacy, robustness, and many more, to ensure the safety and reliability of the technology.


AMIE as an aid to clinicians

In a recently released preprint, we evaluated the ability of an earlier iteration of the AMIE system to generate a DDx alone or as an aid to clinicians. Twenty (20) generalist clinicians evaluated 303 challenging, real-world medical cases sourced from the New England Journal of Medicine (NEJM) ClinicoPathologic Conferences (CPCs). Each case report was read by two clinicians randomized to one of two assistive conditions: either assistance from search engines and standard medical resources, or AMIE assistance in addition to these tools. All clinicians provided a baseline, unassisted DDx prior to using the respective assistive tools.

Assisted randomized reader study setup to investigate the assistive effect of AMIE to clinicians in solving complex diagnostic case challenges from the New England Journal of Medicine.

AMIE exhibited standalone performance that exceeded that of unassisted clinicians (top-10 accuracy 59.1% vs. 33.6%, p = 0.04). Comparing the two assisted study arms, the top-10 accuracy was higher for clinicians assisted by AMIE, compared to clinicians without AMIE assistance (24.6%, p < 0.01) and clinicians with search (5.45%, p = 0.02). Further, clinicians assisted by AMIE arrived at more comprehensive differential lists than those without AMIE assistance.

In addition to strong standalone performance, using the AMIE system led to significant assistive effect and improvements in diagnostic accuracy of the clinicians in solving these complex case challenges.

It's worth noting that NEJM CPCs are not representative of everyday clinical practice. They are unusual case reports in only a few hundred individuals, so they offer limited scope for probing important issues like equity or fairness.


Bold and responsible research in healthcare — the art of the possible

Access to clinical expertise remains scarce around the world. While AI has shown great promise in specific clinical applications, engagement in the dynamic, conversational diagnostic journeys of clinical practice requires many capabilities not yet demonstrated by AI systems. Doctors wield not only knowledge and skill but a dedication to myriad principles, including safety and quality, communication, partnership and teamwork, trust, and professionalism. Realizing these attributes in AI systems is an inspiring challenge that should be approached responsibly and with care. AMIE is our exploration of the “art of the possible”, a research-only system for safely exploring a vision of the future where AI systems might be better aligned with attributes of the skilled clinicians entrusted with our care. It is early experimental-only work, not a product, and has several limitations that we believe merit rigorous and extensive further scientific studies in order to envision a future in which conversational, empathic and diagnostic AI systems might become safe, helpful and accessible.


Acknowledgements

The research described here is joint work across many teams at Google Research and Google DeepMind. We are grateful to all our co-authors: Tao Tu, Mike Schaekermann, Anil Palepu, Daniel McDuff, Jake Sunshine, Khaled Saab, Jan Freyberg, Ryutaro Tanno, Amy Wang, Brenna Li, Mohamed Amin, Sara Mahdavi, Karan Singhal, Shekoofeh Azizi, Nenad Tomasev, Yun Liu, Yong Cheng, Le Hou, Albert Webson, Jake Garrison, Yash Sharma, Anupam Pathak, Sushant Prakash, Philip Mansfield, Shwetak Patel, Bradley Green, Ewa Dominowska, Renee Wong, Juraj Gottweis, Dale Webster, Katherine Chou, Christopher Semturs, Joelle Barral, Greg Corrado and Yossi Matias. We also thank Sami Lachgar, Lauren Winer and John Guilyard for their support with narratives and the visuals. Finally, we are grateful to Michael Howell, James Manyika, Jeff Dean, Karen DeSalvo, Zoubin Ghahramani and Demis Hassabis for their support during the course of this project.


Source: Google AI Blog


AMIE: A research AI system for diagnostic medical reasoning and conversations

The physician-patient conversation is a cornerstone of medicine, in which skilled and intentional communication drives diagnosis, management, empathy and trust. AI systems capable of such diagnostic dialogues could increase availability, accessibility, quality and consistency of care by being useful conversational partners to clinicians and patients alike. But approximating clinicians’ considerable expertise is a significant challenge.

Recent progress in large language models (LLMs) outside the medical domain has shown that they can plan, reason, and use relevant context to hold rich conversations. However, there are many aspects of good diagnostic dialogue that are unique to the medical domain. An effective clinician takes a complete “clinical history” and asks intelligent questions that help to derive a differential diagnosis. They wield considerable skill to foster an effective relationship, provide information clearly, make joint and informed decisions with the patient, respond empathically to their emotions, and support them in the next steps of care. While LLMs can accurately perform tasks such as medical summarization or answering medical questions, there has been little work specifically aimed towards developing these kinds of conversational diagnostic capabilities.

Inspired by this challenge, we developed Articulate Medical Intelligence Explorer (AMIE), a research AI system based on an LLM and optimized for diagnostic reasoning and conversations. We trained and evaluated AMIE along many dimensions that reflect quality in real-world clinical consultations from the perspective of both clinicians and patients. To scale AMIE across a multitude of disease conditions, specialties and scenarios, we developed a novel self-play based simulated diagnostic dialogue environment with automated feedback mechanisms to enrich and accelerate its learning process. We also introduced an inference-time chain-of-reasoning strategy to improve AMIE’s diagnostic accuracy and conversation quality. Finally, we tested AMIE prospectively in real examples of multi-turn dialogue by simulating consultations with trained actors.

AMIE was optimized for diagnostic conversations, asking questions that help to reduce its uncertainty and improve diagnostic accuracy, while also balancing this with other requirements of effective clinical communication, such as empathy, fostering a relationship, and providing information clearly.

Evaluation of conversational diagnostic AI

Besides developing and optimizing AI systems themselves for diagnostic conversations, how to assess such systems is also an open question. Inspired by accepted tools used to measure consultation quality and clinical communication skills in real-world settings, we constructed a pilot evaluation rubric to assess diagnostic conversations along axes pertaining to history-taking, diagnostic accuracy, clinical management, clinical communication skills, relationship fostering and empathy.

We then designed a randomized, double-blind crossover study of text-based consultations with validated patient actors interacting either with board-certified primary care physicians (PCPs) or the AI system optimized for diagnostic dialogue. We set up our consultations in the style of an objective structured clinical examination (OSCE), a practical assessment commonly used in the real world to examine clinicians’ skills and competencies in a standardized and objective way. In a typical OSCE, clinicians might rotate through multiple stations, each simulating a real-life clinical scenario where they perform tasks such as conducting a consultation with a standardized patient actor (trained carefully to emulate a patient with a particular condition). Consultations were performed using a synchronous text-chat tool, mimicking the interface familiar to most consumers using LLMs today.

AMIE is a research AI system based on LLMs for diagnostic reasoning and dialogue.

AMIE: an LLM-based conversational diagnostic research AI system

We trained AMIE on real-world datasets comprising medical reasoning, medical summarization and real-world clinical conversations.

It is feasible to train LLMs using real-world dialogues developed by passively collecting and transcribing in-person clinical visits. However, two substantial challenges limit their effectiveness in training LLMs for medical conversations. First, existing real-world data often fails to capture the vast range of medical conditions and scenarios, which hinders scalability and comprehensiveness. Second, the data derived from real-world dialogue transcripts tends to be noisy, containing ambiguous language (including slang, jargon, humor and sarcasm), interruptions, ungrammatical utterances, and implicit references.

To address these limitations, we designed a self-play based simulated learning environment with automated feedback mechanisms for diagnostic medical dialogue in a virtual care setting, enabling us to scale AMIE’s knowledge and capabilities across many medical conditions and contexts. We used this environment to iteratively fine-tune AMIE with an evolving set of simulated dialogues in addition to the static corpus of real-world data described.

This process consisted of two self-play loops: (1) an “inner” self-play loop, where AMIE leveraged in-context critic feedback to refine its behavior on simulated conversations with an AI patient simulator; and (2) an “outer” self-play loop, where the set of refined simulated dialogues was incorporated into subsequent fine-tuning iterations. The resulting new version of AMIE could then participate in the inner loop again, creating a virtuous continuous learning cycle.
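
The control flow of the two loops can be sketched in a few lines of Python. This is only an illustrative skeleton, not AMIE's implementation: the `Agent` class, the `simulate`, `critique`, `refine` and `fine_tune` callables, and the number of rounds and iterations are all hypothetical stand-ins, and the toy stubs simply manipulate strings where a real system would call models.

```python
from dataclasses import dataclass, field
from typing import List

Dialogue = List[str]  # one string per turn; a toy representation


@dataclass
class Agent:
    """Hypothetical stand-in for a fine-tunable dialogue model."""
    version: int = 0
    seen_dialogues: List[Dialogue] = field(default_factory=list)


def inner_loop(agent, scenario, simulate, critique, refine, rounds=2):
    """Inner loop: simulate one dialogue, then revise it with critic feedback."""
    dialogue = simulate(agent, scenario)
    for _ in range(rounds):
        feedback = critique(dialogue)
        dialogue = refine(agent, dialogue, feedback)
    return dialogue


def outer_loop(agent, scenarios, simulate, critique, refine, fine_tune,
               iterations=3):
    """Outer loop: fine-tune on refined dialogues, then repeat with the new agent."""
    for _ in range(iterations):
        refined = [inner_loop(agent, s, simulate, critique, refine)
                   for s in scenarios]
        agent = fine_tune(agent, refined)
    return agent


# Toy stubs so the sketch runs end to end (assumptions, not real components).
def toy_simulate(agent, scenario):
    return [f"patient: I have {scenario}",
            f"doctor(v{agent.version}): tell me more"]


def toy_critique(dialogue):
    return "ask about symptom onset"


def toy_refine(agent, dialogue, feedback):
    return dialogue + [f"doctor(v{agent.version}): revised using: {feedback}"]


def toy_fine_tune(agent, dialogues):
    return Agent(agent.version + 1, agent.seen_dialogues + dialogues)


final_agent = outer_loop(Agent(), ["a cough", "chest pain"],
                         toy_simulate, toy_critique, toy_refine, toy_fine_tune)
```

The point of the sketch is the nesting: each outer iteration runs the inner loop with the most recently fine-tuned agent, so later simulated dialogues are produced by a newer version than earlier ones.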

We also employed an inference-time chain-of-reasoning strategy, which enabled AMIE to progressively refine its response, conditioned on the current conversation, to arrive at an informed and grounded reply.

AMIE uses a novel self-play based simulated dialogue learning environment to improve the quality of diagnostic dialogue across a multitude of disease conditions, specialties and patient contexts.

We tested AMIE’s performance in consultations with simulated patients (played by trained actors), compared to those performed by 20 real PCPs using the randomized approach described above. AMIE and PCPs were assessed from the perspectives of both specialist attending physicians and our simulated patients in a randomized, blinded crossover study that included 149 case scenarios from OSCE providers in Canada, the UK and India in a diverse range of specialties and diseases.

Notably, our study was not designed to emulate either traditional in-person OSCE evaluations or the ways clinicians usually use text, email, chat or telemedicine. Instead, our experiment mirrored the most common way consumers interact with LLMs today, a potentially scalable and familiar mechanism for AI systems to engage in remote diagnostic dialogue.

Overview of the randomized study design to perform a virtual remote OSCE with simulated patients via online multi-turn synchronous text chat.

Performance of AMIE

In this setting, we observed that AMIE performed simulated diagnostic conversations at least as well as PCPs when both were evaluated along multiple clinically-meaningful axes of consultation quality. AMIE had greater diagnostic accuracy and superior performance for 28 of 32 axes from the perspective of specialist physicians, and 24 of 26 axes from the perspective of patient actors.

AMIE outperformed PCPs on multiple evaluation axes for diagnostic dialogue in our evaluations.
Specialist-rated top-k diagnostic accuracy. AMIE and PCP top-k differential diagnosis (DDx) accuracy are compared across 149 scenarios with respect to the ground truth diagnosis (a) and all diagnoses listed within the accepted differential diagnoses (b). Bootstrapping (n=10,000) confirms all top-k differences between AMIE and PCP DDx accuracy are significant with p < 0.05 after false discovery rate (FDR) correction.
Diagnostic conversation and reasoning qualities as assessed by specialist physicians. On 28 out of 32 axes, AMIE outperformed PCPs while being comparable on the rest.
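
For readers unfamiliar with the statistics named in the caption above, the following is a minimal sketch of top-k accuracy, a paired bootstrap comparison, and a false discovery rate correction. It makes several assumptions that are not stated in this post: the exact bootstrap procedure used in the study is not described here, so a simple one-sided paired resampling over cases is shown; the Benjamini-Hochberg procedure is assumed as the FDR method; and the function names and the rank-based input format are hypothetical.

```python
import random


def top_k_accuracy(ranks, k):
    """Fraction of cases whose correct diagnosis appears in the top k.

    ranks[i] is the 1-based position of the correct diagnosis in case i's
    differential list, or None if it is absent (assumed input format).
    """
    return sum(1 for r in ranks if r is not None and r <= k) / len(ranks)


def bootstrap_p_value(ranks_a, ranks_b, k, n_boot=10_000, seed=0):
    """One-sided paired bootstrap p-value for accuracy(a) > accuracy(b).

    Cases are resampled with replacement, keeping each case's pair of
    ranks together. The +1 terms avoid reporting a p-value of exactly 0.
    """
    rng = random.Random(seed)
    n = len(ranks_a)
    not_better = 0
    for _ in range(n_boot):
        idx = [rng.randrange(n) for _ in range(n)]
        diff = (top_k_accuracy([ranks_a[i] for i in idx], k)
                - top_k_accuracy([ranks_b[i] for i in idx], k))
        if diff <= 0:
            not_better += 1
    return (not_better + 1) / (n_boot + 1)


def benjamini_hochberg(p_values):
    """Return Benjamini-Hochberg adjusted p-values in the original order."""
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])
    adjusted = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down, enforcing monotonicity.
    for rank in range(m, 0, -1):
        i = order[rank - 1]
        running_min = min(running_min, p_values[i] * m / rank)
        adjusted[i] = running_min
    return adjusted
```

In use, one would compute a p-value for each value of k and pass the whole list through `benjamini_hochberg`, then compare each adjusted value against the 0.05 threshold.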

Limitations

Our research has several limitations and should be interpreted with appropriate caution. Firstly, our evaluation technique likely underestimates the real-world value of human conversations, as the clinicians in our study were limited to an unfamiliar text-chat interface, which permits large-scale LLM–patient interactions but is not representative of usual clinical practice. Secondly, any research of this type must be seen as only a first exploratory step on a long journey. Transitioning from the LLM research prototype that we evaluated in this study to a safe and robust tool that could be used by people and those who provide care for them will require significant additional research. There are many important limitations to be addressed, including experimental performance under real-world constraints and dedicated exploration of such important topics as health equity and fairness, privacy, robustness, and many more, to ensure the safety and reliability of the technology.


AMIE as an aid to clinicians

In a recently released preprint, we evaluated the ability of an earlier iteration of the AMIE system to generate a DDx alone or as an aid to clinicians. Twenty (20) generalist clinicians evaluated 303 challenging, real-world medical cases sourced from the New England Journal of Medicine (NEJM) ClinicoPathologic Conferences (CPCs). Each case report was read by two clinicians randomized to one of two assistive conditions: either assistance from search engines and standard medical resources, or AMIE assistance in addition to these tools. All clinicians provided a baseline, unassisted DDx prior to using the respective assistive tools.

Assisted randomized reader study setup to investigate the assistive effect of AMIE for clinicians solving complex diagnostic case challenges from the New England Journal of Medicine.

AMIE exhibited standalone performance that exceeded that of unassisted clinicians (top-10 accuracy 59.1% vs. 33.6%, p = 0.04). Comparing the two assisted study arms, the top-10 accuracy was higher for clinicians assisted by AMIE, compared to clinicians without AMIE assistance (24.6%, p < 0.01) and clinicians with search (5.45%, p = 0.02). Further, clinicians assisted by AMIE arrived at more comprehensive differential lists than those without AMIE assistance.

In addition to strong standalone performance, using the AMIE system led to a significant assistive effect and improvements in the diagnostic accuracy of clinicians solving these complex case challenges.

It's worth noting that NEJM CPCs are not representative of everyday clinical practice. They are unusual case reports covering only a few hundred individuals, so they offer limited scope for probing important issues like equity or fairness.


Bold and responsible research in healthcare — the art of the possible

Access to clinical expertise remains scarce around the world. While AI has shown great promise in specific clinical applications, engagement in the dynamic, conversational diagnostic journeys of clinical practice requires many capabilities not yet demonstrated by AI systems. Doctors wield not only knowledge and skill but a dedication to myriad principles, including safety and quality, communication, partnership and teamwork, trust, and professionalism. Realizing these attributes in AI systems is an inspiring challenge that should be approached responsibly and with care. AMIE is our exploration of the “art of the possible”, a research-only system for safely exploring a vision of the future where AI systems might be better aligned with attributes of the skilled clinicians entrusted with our care. It is early experimental-only work, not a product, and has several limitations that we believe merit rigorous and extensive further scientific studies in order to envision a future in which conversational, empathic and diagnostic AI systems might become safe, helpful and accessible.


Acknowledgements

The research described here is joint work across many teams at Google Research and Google DeepMind. We are grateful to all our co-authors: Tao Tu, Mike Schaekermann, Anil Palepu, Daniel McDuff, Jake Sunshine, Khaled Saab, Jan Freyberg, Ryutaro Tanno, Amy Wang, Brenna Li, Mohamed Amin, Sara Mahdavi, Karan Singhal, Shekoofeh Azizi, Nenad Tomasev, Yun Liu, Yong Cheng, Le Hou, Albert Webson, Jake Garrison, Yash Sharma, Anupam Pathak, Sushant Prakash, Philip Mansfield, Shwetak Patel, Bradley Green, Ewa Dominowska, Renee Wong, Juraj Gottweis, Dale Webster, Katherine Chou, Christopher Semturs, Joelle Barral, Greg Corrado and Yossi Matias. We also thank Sami Lachgar, Lauren Winer and John Guilyard for their support with narratives and the visuals. Finally, we are grateful to Michael Howell, James Manyika, Jeff Dean, Karen DeSalvo, Zoubin Ghahramani and Demis Hassabis for their support during the course of this project.

Source: Google AI Blog