Tag Archives: machine learning

MediaPipe on the Web

Posted by Michael Hays and Tyler Mullen from the MediaPipe team

MediaPipe is a framework for building cross-platform multimodal applied ML pipelines. We have previously demonstrated building and running ML pipelines as MediaPipe graphs on mobile (Android, iOS) and on edge devices like Google Coral. In this article, we are excited to present MediaPipe graphs running live in the web browser, enabled by WebAssembly and accelerated by the XNNPack ML Inference Library. By integrating this preview functionality into our web-based Visualizer tool, we provide a playground for quickly iterating on a graph design. Since everything runs directly in the browser, video never leaves the user’s computer, and each iteration can be tested immediately on a live webcam stream (and soon, on arbitrary video).

Figure 1: The MediaPipe face detection example running in the Visualizer

MediaPipe Visualizer

MediaPipe Visualizer (see Figure 2) is hosted at viz.mediapipe.dev. MediaPipe graphs can be inspected by pasting graph code into the Editor tab or by uploading a graph file into the Visualizer. Users can pan and zoom the graphical representation of the graph with the mouse and scroll wheel, and the graph updates in real time as changes are made in the editor.

Figure 2: MediaPipe Visualizer hosted at https://viz.mediapipe.dev

Demos on MediaPipe Visualizer

We have created several sample Visualizer demos from existing MediaPipe graph examples. These can be seen within the Visualizer by visiting the following addresses in your Chrome browser:

  • Edge Detection
  • Face Detection
  • Hair Segmentation
  • Hand Tracking

Each of these demos can be run in the browser by clicking the small running-man icon at the top of the editor (the icon is greyed out if a non-demo workspace is loaded). This opens a new tab that runs the current graph (a webcam is required).

Implementation Details

In order to maximize portability, we use Emscripten to compile all of the necessary C++ code directly into WebAssembly, a low-level binary instruction format designed for execution in web browsers. At runtime, the browser executes these instructions in a sandboxed virtual machine, typically much faster than equivalent JavaScript code.

We also created a simple API for all necessary communication between JavaScript and C++, allowing us to change and interact with the MediaPipe graph directly from JavaScript. For readers familiar with Android development, this is similar to authoring a C++/Java bridge using the Android NDK.

Finally, we packaged all the requisite demo assets (ML models and auxiliary text/data files) as individual binary data packages to be loaded at runtime. For graphics and rendering, MediaPipe automatically taps directly into WebGL, so most OpenGL-based calculators can “just work” on the web.

Performance

While executing WebAssembly is generally much faster than pure JavaScript, it is also usually much slower than native C++, so we made several optimizations to provide a better user experience. We use the GPU for image operations when possible, and we opt for the lightest-weight versions of all our ML models (giving up some quality for speed). However, since compute shaders are not widely available on the web, we cannot easily make use of TensorFlow Lite GPU inference, and the resulting CPU inference often becomes a significant performance bottleneck. To help alleviate this, we automatically augment our “TfLiteInferenceCalculator” to use the XNNPack ML Inference Library, which gives us a 2-3x speedup in most of our applications.

Currently, support for web-based MediaPipe has some important limitations:

  • Only calculators in the demo graphs above may be used
  • The user must edit one of the template graphs; they cannot provide their own from scratch
  • The user cannot add or alter assets
  • The executor for the graph must be single-threaded (i.e. ApplicationThreadExecutor)
  • TensorFlow Lite inference on GPU is not supported

We plan to continue building upon this new platform to provide developers with much more control, removing many if not all of these limitations (e.g., by allowing for dynamic management of assets). For updates, follow the MediaPipe tag on the Google Developers Blog and the Google Developers Twitter account (@googledevs).

Acknowledgements

We would like to thank Marat Dukhan, Chuo-Ling Chang, Jianing Wei, Ming Guang Yong, and Matthias Grundmann for contributing to this blog post.

Reformer: The Efficient Transformer



Understanding sequential data, such as language, music or videos, is a challenging task, especially when there is dependence on extensive surrounding context. For example, if a person or an object disappears from view in a video only to re-appear much later, many models will forget how it looked. In the language domain, long short-term memory (LSTM) neural networks cover enough context to translate sentence-by-sentence. In this case, the context window (i.e., the span of data taken into consideration in the translation) covers from dozens to about a hundred words. The more recent Transformer model not only improved performance in sentence-by-sentence translation, but could be used to generate entire Wikipedia articles through multi-document summarization. This is possible because the context window used by Transformer extends to thousands of words. With such a large context window, Transformer can also be applied beyond text, to sequences of pixels or musical notes, enabling it to generate images and music.

However, extending Transformer to even larger context windows runs into limitations. The power of Transformer comes from attention, the process by which it considers all possible pairs of words within the context window to understand the connections between them. So, in the case of a text of 100K words, this would require assessment of 100K x 100K word pairs, or 10 billion pairs for each step, which is impractical. Another problem is with the standard practice of storing the output of each model layer. For applications using large context windows, the memory requirement for storing the output of multiple model layers quickly becomes prohibitively large (from gigabytes with a few layers to terabytes in models with thousands of layers). This means that realistic Transformer models, using numerous layers, can only be used on a few paragraphs of text or generate short pieces of music.

Today, we introduce the Reformer, a Transformer model designed to handle context windows of up to 1 million words, all on a single accelerator and using only 16GB of memory. It combines two crucial techniques to solve the problems of attention and memory allocation that limit Transformer’s application to long context windows. Reformer uses locality-sensitive hashing (LSH) to reduce the complexity of attending over long sequences, and reversible residual layers to use the available memory more efficiently.

The Attention Problem
The first challenge when applying a Transformer model to a very large text sequence is how to handle the attention layer. Reformer addresses this with LSH: instead of searching through all possible pairs of vectors, it computes a hash function that matches similar vectors together. For example, in a translation task, where each vector from the first layer of the network represents a word (even larger contexts in subsequent layers), vectors corresponding to the same words in different languages may get the same hash. In the figure below, different colors depict different hashes, with similar words having the same color. Once the hashes are assigned, the sequence is rearranged to bring elements with the same hash together and divided into segments (or chunks) to enable parallel processing. Attention is then applied within these much shorter chunks (and their adjoining neighbors, to cover the overflow), greatly reducing the computational load.
Locality-sensitive hashing: Reformer takes in an input sequence of keys, where each key is a vector representing individual words (or pixels, in the case of images) in the first layer and larger contexts in subsequent layers. LSH is applied to the sequence, after which the keys are sorted by their hash and chunked. Attention is applied only within a single chunk and its immediate neighbors.
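The hash, sort and chunk procedure above can be sketched in a few lines of NumPy. This is a simplified illustration, not Reformer's implementation: it uses a single round of angular LSH (a random projection R, with bucket = argmax over [xR, -xR], the scheme described in the Reformer paper), shared query/key vectors, and hypothetical values for the number of buckets and the chunk size. Refinements such as multiple hash rounds and masking are omitted.

```python
import numpy as np


def lsh_buckets(vectors, n_buckets, rng):
    """Angular LSH: bucket = argmax over [xR, -xR] for a random projection R."""
    projection = rng.standard_normal((vectors.shape[-1], n_buckets // 2))
    rotated = vectors @ projection
    return np.argmax(np.concatenate([rotated, -rotated], axis=-1), axis=-1)


def lsh_chunked_attention(qk, v, n_buckets=4, chunk_size=4, seed=0):
    """Shared-QK attention restricted to each sorted chunk plus the chunk before it."""
    rng = np.random.default_rng(seed)
    buckets = lsh_buckets(qk, n_buckets, rng)
    # Stable sort groups equal hashes while keeping sequence order inside a bucket.
    order = np.argsort(buckets, kind="stable")
    out = np.zeros_like(v)
    for start in range(0, len(order), chunk_size):
        queries = order[start:start + chunk_size]
        # Keys come from this chunk and the previous one, to catch bucket overflow.
        keys = order[max(0, start - chunk_size):start + chunk_size]
        scores = qk[queries] @ qk[keys].T / np.sqrt(qk.shape[-1])
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        out[queries] = weights @ v[keys]
    return out, buckets


rng = np.random.default_rng(1)
qk = rng.standard_normal((16, 8))  # 16 positions, 8-dim shared query/key vectors
v = rng.standard_normal((16, 8))
out, buckets = lsh_chunked_attention(qk, v)
print(out.shape, np.bincount(buckets, minlength=4))
```

Each query attends to at most 2 × chunk_size keys, so the work grows roughly linearly with sequence length instead of quadratically. Setting chunk_size to the full sequence length recovers ordinary softmax attention, which is a handy sanity check.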
The Memory Problem
While LSH solves the problem with attention, there is still a memory issue. A single layer of a network often requires up to a few GB of memory and usually fits on a single GPU, so even a model with long sequences could be executed if it only had one layer. But when training a multi-layer model with gradient descent, activations from each layer need to be saved for use in the backward pass. A typical Transformer model has a dozen or more layers, so memory quickly runs out if the activations from each of those layers are cached.

The second novel approach implemented in Reformer is to recompute the input of each layer on-demand during back-propagation, rather than storing it in memory. This is accomplished by using reversible layers, where activations from the last layer of the network are used to recover activations from any intermediate layer, by what amounts to running the network in reverse. In a typical residual network, each layer in the stack keeps adding to vectors that pass through the network. Reversible layers, instead, have two sets of activations for each layer. One follows the standard procedure just described and is progressively updated from one layer to the next, but the other captures only the changes to the first. Thus, to run the network in reverse, one simply subtracts the activations applied at each layer.
Reversible layers: (A) In a standard residual network, the activations from each layer are used to update the inputs into the next layer. (B) In a reversible network, two sets of activations are maintained, only one of which is updated after each layer. (C) This approach enables running the network in reverse in order to recover all intermediate values.
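The reversible trick can be checked numerically. The sketch below follows the RevNet-style formulation the Reformer paper builds on, y1 = x1 + F(x2) and y2 = x2 + G(y1), where in Reformer F is the attention layer and G the feed-forward layer. Here F and G are hypothetical tanh layers with random weights, and every block reuses the same F and G for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
# Hypothetical stand-ins for the attention (F) and feed-forward (G) sub-layers.
w_f = rng.standard_normal((d, d))
w_g = rng.standard_normal((d, d))


def f(x):
    return np.tanh(x @ w_f)


def g(x):
    return np.tanh(x @ w_g)


def reversible_forward(x1, x2):
    """One reversible block: y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2


def reversible_inverse(y1, y2):
    """Recover the block's inputs from its outputs by subtracting in reverse order."""
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2


x1 = rng.standard_normal((3, d))
x2 = rng.standard_normal((3, d))

# Forward through a stack of blocks; only the final pair of activations is kept.
y1, y2 = x1, x2
for _ in range(6):
    y1, y2 = reversible_forward(y1, y2)

# Walk back through the stack to recompute every intermediate input on demand.
r1, r2 = y1, y2
for _ in range(6):
    r1, r2 = reversible_inverse(r1, r2)

print(np.allclose(r1, x1), np.allclose(r2, x2))  # True True
```

Because each block's inputs can be recomputed from its outputs, the backward pass only needs the last layer's activations, trading some extra computation for memory.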
Applications of Reformer
The novel application of these two approaches in Reformer makes it highly efficient, enabling it to process text sequences of lengths up to 1 million words on a single accelerator using only 16GB of memory. Since Reformer has such high efficiency, it can be applied directly to data with context windows much larger than virtually all current state-of-the-art text domain datasets. Perhaps Reformer’s ability to deal with such large datasets will stimulate the community to create them.

One area where there is no shortage of large-context data is image generation, so we experiment with the Reformer on images. In this colab, we present examples of how Reformer can be used to “complete” partial images. Starting with the image fragments shown in the top row of the figure below, Reformer can generate full frame images (bottom row), pixel-by-pixel.
Top: Image fragments used as input to Reformer. Bottom: “Completed” full-frame images. Original images are from the Imagenet64 dataset.
While the application of Reformer to imaging and video tasks shows great potential, its application to text is even more exciting. Reformer can process entire novels, all at once and on a single device. Processing the entirety of Crime and Punishment in a single training example is demonstrated in this colab. In the future, when there are more datasets with long-form text to train, techniques such as the Reformer may make it possible to generate long coherent compositions.

Conclusion
We believe Reformer provides a basis for future use of Transformer models, both for long text and applications outside of natural language processing. Following our tradition of doing research in the open, we have already started exploring how to apply it to even longer sequences and how to improve handling of positional encodings. Read the Reformer paper (selected for oral presentation at ICLR 2020), explore our code and develop your own ideas too. Few long-context datasets are widely used in deep learning yet, but in the real world long context is everywhere. Maybe you can find a new application for Reformer. Start with this colab and chat with us if you have any problems or questions!

Acknowledgements
This research was conducted by Nikita Kitaev, Łukasz Kaiser and Anselm Levskaya. Additional thanks go to Afroz Mohiuddin, Jonni Kanerva and Piotr Kozakowski for their work on Trax and to the whole JAX team for their support.

Source: Google AI Blog


Using Machine Learning to “Nowcast” Precipitation in High Resolution



The weather can affect a person’s daily routine in both mundane and serious ways, and the precision of forecasting can strongly influence how they deal with it. Weather predictions can inform people about whether they should take a different route to work, if they should reschedule the picnic planned for the weekend, or even if they need to evacuate their homes due to an approaching storm. But making accurate weather predictions can be particularly challenging for localized storms or events that evolve on hourly timescales, such as thunderstorms.

In “Machine Learning for Precipitation Nowcasting from Radar Images,” we present new research into the development of machine learning models for precipitation forecasting that addresses this challenge by making highly localized “physics-free” predictions that apply to the immediate future. A significant advantage of machine learning is that inference is computationally cheap given an already-trained model, allowing forecasts that are nearly instantaneous and in the native high resolution of the input data. Our precipitation nowcasting, which focuses on 0-6 hour forecasts, generates forecasts at 1 km resolution with a total latency of just 5-10 minutes, including data collection delays, and outperforms traditional models even at these early stages of development.

Moving Beyond Traditional Weather Forecasting
Weather agencies around the world have extensive monitoring facilities. For example, Doppler radar measures precipitation in real time, weather satellites provide multispectral imaging, and ground stations measure wind and precipitation directly. The figure below, which compares false-color composite radar imaging of precipitation over the continental US to cloud cover imaged by geosynchronous satellites, illustrates the need for multi-source weather information. The existence of rain is related to, but not perfectly correlated with, the existence of clouds, so inferring precipitation from satellite images alone is challenging.
Top: Image showing the location of clouds as measured by geosynchronous satellites. Bottom: Radar image showing the location of rain as measured by Doppler radar stations. (Credit: NOAA, NWS, NSSL)
Unfortunately, not all of these measurements are equally present across the globe. For example, radar data comes largely from ground stations and is generally not available over the oceans. Further, coverage varies geographically, and some locations may have poor radar coverage even when they have good satellite coverage.

Even so, there is so much observational data in so many different varieties that forecasting systems struggle to incorporate it all. In the US, remote sensing data collected by the National Oceanic and Atmospheric Administration (NOAA) is now reaching 100 terabytes per day. NOAA uses this data to feed the massive weather forecasting engines that run on supercomputers to provide 1- to 10-day global forecasts. These engines have been developed over the course of the last half century, and are based on numerical methods that directly simulate physical processes, including atmospheric dynamics and numerous effects like thermal radiation, vegetation, lake and ocean effects, and more.

However, the availability of computational resources limits the power of numerical weather prediction in several ways. For example, computational demands limit the spatial resolution to about 5 kilometers, which is not sufficient for resolving weather patterns within urban areas and agricultural land. Numerical methods also take multiple hours to run. If it takes 6 hours to compute a forecast, that allows only 3-4 runs per day, resulting in forecasts based on data that is 6 or more hours old, which limits our knowledge of what is happening right now. By contrast, nowcasting is especially useful for immediate decisions, from traffic routing and logistics to evacuation planning.

Radar-to-Radar Forecasting
As a typical example of the type of predictions our system can generate, consider the radar-to-radar forecasting problem: given a sequence of radar images for the past hour, predict what the radar image will be N hours from now, where N typically ranges from 0-6 hours. Since radar data is organized into images, we can pose this prediction as a computer vision problem, inferring the meteorological evolution from the sequence of input images. At these short timescales, the evolution is dominated by two physical processes: advection for the cloud motion, and convection for cloud formation, both of which are significantly affected by local terrain and geography.
Top (left to right): The first three panels show radar images from 60 minutes, 30 minutes, and 0 minutes before now, the point at which a prediction is desired. The right-most panel shows the radar image 60 minutes after now, i.e., the ground truth for a nowcasting prediction. Bottom Left: For comparison, a vector field induced from applying an optical flow (OF) algorithm for modeling advection to the data from the first three panels above. Optical flow is a computer vision method that was developed in the 1940s, and is frequently used to predict short term weather evolution. Bottom Right: An example prediction made by OF. Notice that it tracks the motion of the precipitation in the bottom left corner well, but fails to account for the decaying strength of the storm.
We use a data-driven physics-free approach, meaning that the neural network will learn to approximate the atmospheric physics from the training examples alone, not by incorporating a priori knowledge of how the atmosphere actually works. We treat weather prediction as an image-to-image translation problem, and leverage the current state-of-the-art in image analysis: convolutional neural networks (CNNs).

CNNs are usually composed of a linear sequence of layers, where each layer is a set of operations that transform some input image into a new output image. Often, a layer will change the number of channels and the overall resolution of the image it’s given, in addition to convolving the image with a set of convolutional filters. These filters are themselves small images (for us, they are typically only 3x3, or 5x5). Filters drive much of the power of CNNs, and result in operations like detecting edges, identifying meaningful patterns, etc.
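To make “convolving the image with a filter” concrete, here is a plain NumPy sketch of sliding a 3x3 filter over a single-channel image. The Sobel-style kernel and the toy image are illustrative choices, not filters learned by the model described here; like most deep learning libraries, the loop computes cross-correlation (the kernel is not flipped).

```python
import numpy as np


def conv2d_valid(image, kernel):
    """Slide a small filter over a single-channel image (no padding, stride 1)."""
    kh, kw = kernel.shape
    h, w = image.shape
    out = np.zeros((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # Weighted sum of the 3x3 neighborhood under the filter.
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


# A 3x3 Sobel-style filter that responds to vertical edges.
sobel_x = np.array([[-1, 0, 1],
                    [-2, 0, 2],
                    [-1, 0, 1]], dtype=float)

# Toy field: no rain in the left three columns, uniform rain in the right three.
image = np.zeros((5, 6))
image[:, 3:] = 1.0

result = conv2d_valid(image, sobel_x)
print(result)  # each row is [0, 4, 4, 0]: the filter fires only at the boundary
```

A learned CNN layer applies many such filters at once and learns their weights from data rather than fixing them by hand.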

A particularly effective type of CNN is the U-Net. U-Nets have a sequence of layers that are arranged in an encoding phase, in which layers iteratively decrease the resolution of the images passing through them, and then a decoding phase in which the low-dimensional representations of the image created by the encoding phase are expanded back to higher resolutions. The following figure shows all of the layers in our particular U-Net.
(A) The overall structure of our U-Net. Blue boxes correspond to basic CNN layers. Pink boxes correspond to down-sample layers. Green boxes correspond to up-sample layers. Solid lines indicate input connections between layers. Dashed lines indicate long skip connections traversing the encoding and decoding phases of the U-Net. Dotted lines indicate short skip connections for individual layers. (B) The operations within our basic layer. (C) The operations within our down-sample layers. (D) The operations within our up-sample layers.
The input to the U-Net is an image that contains one channel for each multispectral satellite image in the sequence of observations over the last hour. For example, if there were 10 satellite images collected in the last hour, and each of those multispectral images was taken at 10 different wavelengths, then the image input for our model would be an image with 100 channels. For radar-to-radar forecasting, the input is a sequence of 30 radar observations over the past hour, spaced 2 minutes apart, and the output contains the prediction for N hours from now. For our initial work in the US, we trained a network from historical observations over the continental US from the period between 2017 and 2019. The data is split into periods of four weeks, where the first three weeks of each period are used for training and the fourth week is used for evaluation.
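The channel arithmetic and the four-week split above can be sketched with NumPy and the standard library. The 64x64 grid, the channels-first layout and the period anchor date are hypothetical choices for illustration, not details of the actual pipeline.

```python
from datetime import date, timedelta

import numpy as np

# Hypothetical sizes: 10 multispectral images in the last hour, each with
# 10 wavelength bands, on a small 64x64 grid (channels-first layout).
n_times, n_bands, height, width = 10, 10, 64, 64
observations = np.zeros((n_times, n_bands, height, width), dtype=np.float32)

# Fold time and wavelength into a single channel axis: 10 * 10 = 100 channels.
model_input = observations.reshape(n_times * n_bands, height, width)
print(model_input.shape)  # (100, 64, 64)


def split_by_four_week_periods(days, start):
    """Weeks 1-3 of each 4-week period go to training, week 4 to evaluation."""
    train, evaluation = [], []
    for day in days:
        week_in_period = ((day - start).days // 7) % 4
        (evaluation if week_in_period == 3 else train).append(day)
    return train, evaluation


start = date(2017, 1, 1)  # hypothetical anchor for the first period
days = [start + timedelta(days=i) for i in range(56)]  # two 4-week periods
train, evaluation = split_by_four_week_periods(days, start)
print(len(train), len(evaluation))  # 42 14
```

Splitting by contiguous weeks, rather than by random samples, keeps highly correlated neighboring radar frames from landing in both the training and evaluation sets.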

Results
We compare our results with three widely used models. First, the High Resolution Rapid Refresh (HRRR) numerical forecast from NOAA. HRRR actually contains predictions for many different weather quantities; we compared our results to its 1-hour total accumulated surface precipitation prediction, as that was its highest-quality 1-hour precipitation prediction. Second, an optical flow (OF) algorithm, which attempts to track moving objects through a sequence of images. This approach is often applied to weather prediction even though it assumes that overall rain quantities over large areas are constant over the prediction time, an assumption that is clearly violated. Third, the so-called persistence model, the trivial model in which each location is assumed to be raining in the future at the same rate it is raining now, i.e., the precipitation pattern does not change. That may seem like an overly simplistic model to compare to, but it is common practice given the difficulty of weather prediction.
A visualization of predictions made over the course of roughly one day. Left: The 1-hour HRRR prediction made at the top of each hour, the limit to how often HRRR provides predictions. Center: The ground truth, i.e., what we are trying to predict. Right: The predictions made by our model. Our predictions are every 2 minutes (displayed here every 15 minutes) at roughly 10 times the spatial resolution made by HRRR. Notice that we capture the general motion and general shape of the storm.
We use precision and recall (PR) graphs to compare the models. Since we have direct access to our own classifier, we provide a full PR curve (seen as the blue line in the figure below). However, since we don’t have direct access to the HRRR model, and since neither the persistence model nor OF can trade off precision against recall, those models are represented only by individual points. As can be seen, our neural network forecast outperforms all three of these models (the blue line is above all of the other models’ results). It is important to note, however, that the HRRR model begins to outperform our current results when the prediction horizon reaches roughly 5 to 6 hours.
Precision and recall (PR) curves comparing our results (solid blue line) with: optical flow (OF), the persistence model, and the HRRR 1-hour prediction. As we do not have direct access to their classifiers, we cannot provide a full PR curve for their results. Left: Predictions for light rain. Right: Predictions for moderate rain.
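The sketch below scores the persistence baseline as a rain/no-rain classifier at a single intensity threshold, which is why such a baseline yields one point rather than a curve; a model that outputs a continuous score can instead sweep its decision threshold to trace a full PR curve. The 1-D toy rows and the 1.0 mm/hr threshold are hypothetical and not taken from the study.

```python
import numpy as np


def persistence_forecast(latest_frame):
    """Persistence baseline: predict that the current rain field does not change."""
    return latest_frame.copy()


def precision_recall(pred_rate, true_rate, threshold):
    """Label each pixel 'rain' when its rate meets `threshold`, then score the labels."""
    pred = pred_rate >= threshold
    true = true_rate >= threshold
    tp = np.sum(pred & true)
    fp = np.sum(pred & ~true)
    fn = np.sum(~pred & true)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Toy 1-D "radar" rows in mm/hr: the storm moves one pixel to the right.
now = np.array([0.0, 2.0, 2.0, 0.0, 0.0])
later = np.array([0.0, 0.0, 2.0, 2.0, 0.0])
p, r = precision_recall(persistence_forecast(now), later, threshold=1.0)
print(p, r)  # 0.5 0.5
```

Persistence gets the overlapping pixel right but misses the leading edge and keeps the trailing edge, which is exactly the advection error that motion-aware models try to fix.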
One of the advantages of the ML method is that predictions are effectively instantaneous, meaning that our forecasts are based on fresh data, while HRRR is hindered by a computational latency of 1-3 hours. This gives computer vision methods the advantage in very short-term forecasting. In contrast, the numerical model used in HRRR can make better long-term predictions, in part because it uses a full 3D physical model; cloud formation is harder to observe from 2D images, so it is harder for ML methods to learn convective processes. It's possible that combining these two systems, our ML model for rapid forecasts and HRRR for long-term forecasts, could produce better results overall, an idea that is a focus of our future work. We're also looking at applying ML directly to 3D observations. Regardless, immediate forecasting is a key tool for real-time planning, facilitating decisions and improving lives.

Acknowledgements
Thanks to Carla Bromberg, Shreya Agrawal, Cenk Gazen, John Burge, Luke Barrington, Aaron Bell, Anand Babu, Stephan Hoyer, Lak Lakshmanan, Brian Williams, Casper Sønderby, Nal Kalchbrenner, Avital Oliver, Tim Salimans, Mostafa Dehghani, Jonathan Heek, Lasse Espeholt, Sella Nevo, Avinatan Hassidim.

Source: Google AI Blog


The On-Device Machine Learning Behind Recorder



Over the past two decades, Google has made information widely accessible through search, from textual information, photos and videos, to maps and jobs. But much of the world’s information is conveyed through speech. Yet even though many people use audio recording devices to capture important information in conversations, interviews, lectures and more, it can be very difficult to later parse through hours of recordings to identify and extract information of interest. But what if you could automatically transcribe and tag long recordings in real time, so that you could intuitively find the relevant information you need, when you need it?

For this reason, we launched Recorder, a new kind of audio recording app for Pixel phones that leverages recent developments in on-device machine learning (ML) to transcribe conversations, to detect and identify the type of audio recorded (from broad categories like music or speech to particular sounds, such as applause, laughter and whistling), and to index recordings so users can quickly find and extract segments of interest. All of these features run entirely on-device, without the need for an internet connection.
Transcription
Recorder transcribes speech in real time using an on-device automatic speech recognition model based on improvements announced earlier this year. Because transcription is a key component of many of Recorder’s smart features, we made sure that this model can reliably transcribe long audio recordings (a few hours), while also indexing the conversation by mapping words to timestamps computed by the speech recognition model. This enables the user to click on a word in the transcription and start playback from that point in the recording, or to search for a word and jump to the exact point in the recording where it was said.
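The word-to-timestamp indexing described above amounts to an inverted index from words to times. The sketch below assumes a hypothetical recognizer output of (word, start time in seconds) pairs; it is an illustration of the idea, not Recorder's data structures.

```python
from collections import defaultdict

# Hypothetical recognizer output: (word, start time in seconds) for each word.
transcript = [("welcome", 0.0), ("to", 0.4), ("the", 0.5), ("lecture", 0.7),
              ("today", 1.5), ("the", 2.1), ("lecture", 2.3), ("covers", 2.9)]

# Inverted index: lowercase word -> start times, in recording order.
index = defaultdict(list)
for word, start in transcript:
    index[word.lower()].append(start)


def search(word):
    """Every timestamp at which `word` was spoken (empty list if never)."""
    return index.get(word.lower(), [])


def playback_start(position):
    """Where playback should begin when the user taps the word at `position`."""
    return transcript[position][1]


print(search("Lecture"))  # [0.7, 2.3]
print(playback_start(4))  # 1.5
```

Using `index.get` rather than `index[...]` keeps failed searches from inserting empty entries into the defaultdict.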
Recording Content Visualization via Sound Classification
While presenting a transcript for a recording is useful and allows one to search for specific words, sometimes (especially for very long recordings) it’s more useful to visually search for sections of a recording based on specific moments or sounds. To enable this, Recorder additionally represents audio visually as a colored waveform where each color is associated with a different sound category. This is done by combining research into using CNNs to classify audio sounds (e.g., identifying a dog barking or a musical instrument playing) with previously published datasets for audio event detection to classify apparent sound events in individual audio frames.

Of course, in most situations many sounds can appear at the same time. In order to visualize the audio in a very clear way, we decided to color each waveform bar in a single color that represents the most dominant sound in a given time frame (in our case, 50ms bars). The colorized waveform lets users understand what type of content was captured in a specific recording and navigate along an ever-growing audio library more easily. This brings a visual representation of the audio recordings to the users, and also enables them to search over audio events in their recordings.
Recorder implements a sliding window that processes partially overlapping 960ms audio frames at 50ms intervals and outputs a vector of sigmoid scores representing the probability of each supported audio class within the frame. We apply a linearization process to the sigmoid scores, combined with a thresholding mechanism, to maximize system precision and report the correct sound classification. Analyzing the content of the 960ms window with small 50ms offsets makes it possible to pinpoint exact start and end times in a way that is less prone to mistakes than analyzing consecutive 960ms slices on their own.
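The windowing arithmetic can be sketched as follows. The 16 kHz sample rate and the class list are assumptions, and the simple "below threshold means silence" rule is a hypothetical stand-in for the thresholding step; the linearization step mentioned in the text is not reproduced.

```python
import numpy as np

SAMPLE_RATE = 16_000          # assumed sample rate in Hz
FRAME_MS, HOP_MS = 960, 50    # window and hop sizes from the text
CLASSES = ["silence", "speech", "music", "applause"]  # hypothetical label set


def frame_starts(n_samples, sample_rate=SAMPLE_RATE):
    """Start sample of every 960ms window, advancing 50ms at a time."""
    frame = sample_rate * FRAME_MS // 1000
    hop = sample_rate * HOP_MS // 1000
    return list(range(0, n_samples - frame + 1, hop))


def dominant_class(scores, threshold=0.5):
    """Highest-scoring class per window; below `threshold`, fall back to 'silence'."""
    best = np.argmax(scores, axis=1)
    confident = scores[np.arange(len(scores)), best] >= threshold
    return [CLASSES[b] if ok else "silence" for b, ok in zip(best, confident)]


# Two seconds of audio: floor((2000 - 960) / 50) + 1 = 21 overlapping windows.
starts = frame_starts(2 * SAMPLE_RATE)
print(len(starts))  # 21

# One sigmoid score vector per window colors one 50ms waveform bar.
scores = np.array([[0.1, 0.9, 0.2, 0.1],
                   [0.2, 0.3, 0.4, 0.1]])
print(dominant_class(scores))  # ['speech', 'silence']
```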
Since the model analyzes each audio frame independently, it can be prone to quick jittering between audio classes. This is solved with an adaptive-size median filtering technique applied to the most recent model audio class outputs, thus providing a smoothed consecutive output. The process runs continuously in real-time, requiring it to meet very strict power consumption limitations.
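A causal median filter over recent outputs can be sketched as below. Two assumptions differ from the text: the window size is fixed rather than adaptive, and the median is taken per class over the score history (one reasonable reading of filtering the model's class outputs) before choosing the winning class.

```python
import numpy as np


def causal_median_filter(scores, window=5):
    """Replace each frame's per-class scores with the median of the last `window` frames.

    Only past and current frames are used, so the filter can run in real time.
    """
    smoothed = np.empty_like(scores)
    for t in range(len(scores)):
        smoothed[t] = np.median(scores[max(0, t - window + 1):t + 1], axis=0)
    return smoothed


# Per-frame scores for two hypothetical classes: [speech, music].
# Frame 3 is a one-frame glitch in which music briefly wins.
scores = np.array([[0.9, 0.1],
                   [0.8, 0.2],
                   [0.9, 0.1],
                   [0.2, 0.9],
                   [0.9, 0.1],
                   [0.8, 0.2]])
raw = scores.argmax(axis=1)
smooth = causal_median_filter(scores, window=3).argmax(axis=1)
print(raw.tolist())     # [0, 0, 0, 1, 0, 0]
print(smooth.tolist())  # [0, 0, 0, 0, 0, 0]
```

Unlike a moving average, a median ignores a single outlier entirely, so one-frame flickers disappear while sustained changes still come through after a short delay.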

Suggesting Tags for Titles
Once a recording is done, Recorder suggests three tags that the app deems to represent the most memorable content, enabling the user to quickly compose a meaningful title.
To be able to suggest these tags immediately when the recording ends, Recorder analyzes the content of the recording as it is being transcribed. First, Recorder counts term occurrences as well as their grammatical role in the sentence. Terms identified as entities are capitalized. Then, we use an on-device part-of-speech tagger (a model that labels each word in the sentence according to its grammatical role) to detect common nouns and proper nouns, which users appear to find more memorable. Recorder uses a prior scores table that supports extraction of both unigram and bigram terms. To generate the scores, we trained a boosted decision tree on conversational data, using textual features such as document word frequency and specificity. Finally, stop words and swear words are filtered out, and the top tags are output.
Tags extraction pipeline architecture
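The pipeline above can be sketched as counting noun unigrams and noun-noun bigrams, weighting the counts by a prior score, filtering, and keeping the top three. Everything here is hypothetical: the part-of-speech tags are supplied by hand in place of the on-device tagger, the `PRIOR` table stands in for the scores from the trained boosted decision tree, and `STOP_WORDS` is a tiny illustrative list.

```python
from collections import Counter

NOUN_TAGS = {"NOUN", "PROPN"}
STOP_WORDS = {"thing", "way", "lot"}   # hypothetical noun-like filler words
BLOCKED_WORDS = set()                  # swear-word list, left empty here
# Hypothetical prior scores per term; the real table comes from a trained model.
PRIOR = {"budget": 1.5, "marketing plan": 2.0}


def suggest_tags(tagged, k=3, default_prior=1.0):
    """Score noun unigrams and noun-noun bigrams by count x prior; return the top k.

    `tagged` is a list of (word, part_of_speech) pairs from a POS tagger.
    Proper nouns keep their capitalization; other words are lowercased.
    """
    words = [w if pos == "PROPN" else w.lower() for w, pos in tagged]
    tags = [pos for _, pos in tagged]
    counts = Counter()
    for i, word in enumerate(words):
        if tags[i] in NOUN_TAGS:
            counts[word] += 1
            if i + 1 < len(words) and tags[i + 1] in NOUN_TAGS:
                counts[f"{word} {words[i + 1]}"] += 1
    blocked = STOP_WORDS | BLOCKED_WORDS
    scores = {
        term: n * PRIOR.get(term, default_prior)
        for term, n in counts.items()
        if not any(w.lower() in blocked for w in term.split())
    }
    # Highest score first; ties broken alphabetically for a stable order.
    return sorted(scores, key=lambda t: (-scores[t], t))[:k]


tagged = [("We", "PRON"), ("reviewed", "VERB"), ("the", "DET"),
          ("marketing", "NOUN"), ("plan", "NOUN"), ("for", "ADP"),
          ("Tokyo", "PROPN"), ("and", "CCONJ"), ("the", "DET"),
          ("budget", "NOUN"), (".", "PUNCT"), ("The", "DET"),
          ("budget", "NOUN"), ("for", "ADP"), ("Tokyo", "PROPN"),
          ("grew", "VERB"), ("a", "DET"), ("lot", "NOUN")]
print(suggest_tags(tagged))  # ['budget', 'Tokyo', 'marketing plan']
```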
Conclusion
Recorder brings together some of our most recent on-device ML research efforts as helpful features, running models on-device to ensure user privacy. The positive feedback loop between machine learning investigations and user needs revealed exciting opportunities to make our software even more useful. We’re excited for future research that will make everyone’s ideas and conversations even more easily accessible and searchable.

Acknowledgments
Special thanks to Dror Ayalon who played a key role in developing and forming the above features and without whom this blog post wouldn’t have been possible. We would also like to thank all our team members and collaborators who worked on this project with us: Amit Pitaru, Kelsie Van Deman, Isaac Blankensmith, Teo Soares, John Watkinson, Matt Hall, Josh Deitel, Benny Schlesinger, Yoni Tsafir, Michelle Tadmor Ramanovich, Danielle Cohen, Sushant Prakash, Renat Aksitov, Ed West, Max Gubin, Tiantian Zhang, Aaron Cohen, Yunhsuan Sung, Chung-Ching Chang, Nathan Dass, Amin Ahmad, Tiago Camolesi, Guilherme Santos, Julio da Silva, Dan Ellis, Qiao Liang, Arun Narayanan, Rohit Prabhavalkar, Benyah Shaparenko, Alex Salcianu, Mike Tsao, Shenaz Zak, Sherry Lin, James Lemieux, Jason Cho, Thomas Hall, Brian Chen, Allen Su, Vincent Peng, Richard Chou, Henry Liu, Edward Chen, Yitong Lin, Tracy Wu, Yvonne Yang.

Source: Google AI Blog


Improving Out-of-Distribution Detection in Machine Learning Models



Successful deployment of machine learning systems requires that the system be able to distinguish data that is anomalous or significantly different from the data used in training. This is particularly important for deep neural network classifiers, which might classify such out-of-distribution (OOD) inputs into in-distribution classes with high confidence, and it is critical when these predictions inform real-world decisions.

For example, one challenging real-world application of machine learning is bacteria identification based on genomic sequences. Bacteria detection is crucial for diagnosing and treating infectious diseases, such as sepsis, and for identifying foodborne pathogens. New bacterial classes continue to be discovered over the years, and while a neural network classifier trained on the known classes achieves high accuracy as measured through cross-validation, deploying such a model is challenging: real-world data is ever evolving and will inevitably contain genomes from unseen classes (OOD inputs) that are not present in the training data.
New bacterial classes are gradually discovered over the years. A classifier trained on known classes achieves high accuracy for test inputs belonging to known classes, but can wrongly classify inputs from unknown classes (i.e., out-of-distribution) into known classes with high confidence.
In “Likelihood Ratios for Out-of-Distribution Detection”, presented at NeurIPS 2019, we proposed and released a realistic benchmark dataset of genomic sequences for OOD detection that is inspired by the real-world challenges described above. We tested existing methods for OOD detection using generative models on genomic sequences and found that the likelihood values (i.e., the model's probability that an input comes from the distribution, as estimated using in-distribution data) were often in error. This phenomenon has also been observed in recent work on deep generative models of images. We explain this phenomenon through the effect of background statistics and propose a likelihood-ratio-based solution that significantly improves the accuracy of OOD detection.

Why Do Density Models Fail At OOD Detection?
To reflect the real problem and systematically evaluate different methods, we built a new bacterial dataset using data sourced from the publicly available NCBI catalog of prokaryotic genome sequences. To mimic sequencing data, we fragmented genomes into short sequences of 250 base pairs, a length commonly generated by current sequencing technology. We then separated in- and out-of-distribution data by date of discovery: bacterial classes discovered before a cutoff time were defined as in-distribution, and those discovered afterward as OOD.
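A minimal sketch of this dataset construction, assuming each genome is available as a plain nucleotide string and each class is annotated with a discovery year. The helper names, the year granularity, and the use of non-overlapping fragments are illustrative assumptions; the released benchmark's exact procedure may differ.

```python
def fragment(genome, length=250):
    """Cut a genome into non-overlapping fragments of `length` base pairs,
    dropping any shorter tail (an assumption made for this sketch)."""
    return [genome[i:i + length] for i in range(0, len(genome) - length + 1, length)]


def split_by_discovery(class_years, cutoff_year):
    """Classes discovered before the cutoff are in-distribution; the rest are OOD.

    class_years: hypothetical mapping from class name to discovery year.
    """
    in_dist = sorted(c for c, year in class_years.items() if year < cutoff_year)
    ood = sorted(c for c, year in class_years.items() if year >= cutoff_year)
    return in_dist, ood
```

Splitting by discovery date, rather than at random, is what makes the OOD classes genuinely unseen during training, as they would be in deployment.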

We then trained a deep generative model on in-distribution genomic sequences and examined how well the model discriminated between in- and out-of-distribution inputs by plotting their likelihood values. The histogram of the likelihood for OOD sequences largely overlaps with that of in-distribution sequences, indicating that the generative model was unable to distinguish between the two populations. Similar results were shown in earlier work on deep generative models of images. For instance, a PixelCNN++ model trained on images from the Fashion-MNIST dataset (which consists of images of clothing and footwear) assigns higher likelihood to OOD images from the MNIST dataset (which consists of images of digits 0-9).
Left: Histogram of likelihood values for in- and out-of-distribution (OOD) genomic sequences. The likelihood fails to separate in-distribution and OOD genomic sequences. Right: A similar plot for a model trained on Fashion-MNIST and evaluated on MNIST. The model assigns higher likelihood values to OOD (MNIST) images than to in-distribution images.
When investigating this failure mode, we observed that the likelihood can be confounded by background statistics. To understand the phenomenon more intuitively, assume that an input is composed of two components: (1) a background component characterized by background statistics, and (2) a semantic component characterized by patterns specific to the in-distribution data. For example, an MNIST image can be modeled as background plus semantics. When humans interpret the image, we can easily ignore the background and focus primarily on the semantic information, e.g., the “/” mark in the image below. But the likelihood is calculated over all pixels in an image, including both semantic and background pixels. Though we want to use just the semantic likelihood for decision making, the raw likelihood can be dominated by the background.
Left top: Sample images from Fashion-MNIST. Left bottom: Sample images from MNIST. Right: Background and semantic components in an MNIST image.
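The background/semantic decomposition can be written out explicitly. Under the simplifying assumption that the background component x_B and the semantic component x_S are independent under a model with parameters θ, the log-likelihood splits into two additive terms:

```latex
x = (x_B, x_S), \qquad
p_\theta(x) = p_\theta(x_B)\, p_\theta(x_S)
\quad\Longrightarrow\quad
\log p_\theta(x) = \log p_\theta(x_B) + \log p_\theta(x_S)
```

For an MNIST digit, x_B covers most of the pixels and is easy to model, so the first term can dominate the sum even though only the second term reflects what the image depicts.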
Likelihood Ratios For OOD Detection
We propose a likelihood ratio method that removes the effect of background and focuses on semantics. First, we train a background model on perturbed inputs. The perturbation method is inspired by genetic mutations: it randomly selects positions in the input and substitutes each selected value with another value chosen with equal probability from the possible values. For images, the values are randomly chosen from the 256 possible pixel values, and for DNA sequences, from the four possible nucleotides (A, T, C, or G). With the right amount of perturbation, the semantic structure in the data is corrupted and the background model captures only the background. We then compute the likelihood ratio between the full model and the background model; the background component cancels out, so only the likelihood for semantics remains. The likelihood ratio is thus a background contrastive score, i.e., it captures the significance of the semantics compared to the background.
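The method can be sketched end to end on integer-valued inputs (pixel intensities, or nucleotides encoded as 0 to 3). The score is LLR(x) = log p_full(x) - log p_background(x), where the background model is trained on perturbed inputs. To keep the example self-contained, a hypothetical independent-per-position categorical model stands in for the deep generative models (such as PixelCNN++) used in the paper, and the mutation rate `mu` is a tunable assumption.

```python
import numpy as np


def perturb(x, mu, num_values, rng):
    """Mutation-style perturbation: each position is selected independently with
    probability mu and replaced by a value drawn uniformly from 0..num_values-1."""
    x = np.asarray(x)
    selected = rng.random(x.shape) < mu
    random_values = rng.integers(0, num_values, size=x.shape)
    return np.where(selected, random_values, x)


def fit_position_model(data, num_values, alpha=1.0):
    """Toy density model (hypothetical stand-in for PixelCNN++): an independent,
    Laplace-smoothed categorical distribution per position.
    data: (num_examples, length) integer array. Returns log-probs (length, num_values)."""
    data = np.asarray(data)
    length = data.shape[1]
    counts = np.full((length, num_values), alpha)
    for pos in range(length):
        counts[pos] += np.bincount(data[:, pos], minlength=num_values)
    return np.log(counts / counts.sum(axis=1, keepdims=True))


def log_likelihood(log_probs, x):
    """Sum of per-position log-probabilities of the observed values."""
    x = np.asarray(x)
    return float(log_probs[np.arange(len(x)), x].sum())


def likelihood_ratio(full_log_probs, background_log_probs, x):
    """LLR(x): higher values suggest the input is in-distribution."""
    return log_likelihood(full_log_probs, x) - log_likelihood(background_log_probs, x)
```

In use, one would fit the full model on in-distribution data, fit the background model on `perturb(train, mu, num_values, rng)`, and flag inputs whose likelihood ratio falls below a threshold chosen on held-out data.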

To qualitatively evaluate the difference between the likelihood and the likelihood ratio, we plotted their values for each pixel in the Fashion-MNIST and MNIST datasets, creating heatmaps the same size as the images. This lets us see which pixels contribute most to each of the two terms. From the log-likelihood heatmaps, we see that background pixels contribute much more to the likelihood than semantic pixels. In hindsight, this is not surprising, since background pixels consist mostly of a string of zeros, a pattern the model learns very easily. A comparison between the MNIST and Fashion-MNIST heatmaps shows why MNIST receives higher likelihood values: it simply has many more background pixels! The likelihood ratio instead focuses more on the semantic pixels.
Left: Log-likelihood heatmaps for Fashion-MNIST and MNIST datasets. Right: The same examples showing heatmaps of the likelihood-ratio. Pixels with higher values are of lighter shades. The likelihood is dominated by the “background” pixels, whereas the likelihood ratio focuses on the “semantic” pixels and is thus better for OOD detection.
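One natural way to obtain such per-pixel maps, assuming an autoregressive model such as PixelCNN++ that factorizes the likelihood over pixels x_1, ..., x_D, is to plot each pixel's term of the sum. With θ the full model and θ_0 the background model:

```latex
\log p_\theta(x) = \sum_{d=1}^{D} \log p_\theta(x_d \mid x_{<d}),
\qquad
\mathrm{LLR}(x) = \sum_{d=1}^{D} \Big[ \log p_\theta(x_d \mid x_{<d}) - \log p_{\theta_0}(x_d \mid x_{<d}) \Big]
```

The log-likelihood heatmap then shows the left-hand summand at pixel d, and the likelihood-ratio heatmap shows the bracketed difference, in which terms that both models predict equally well (the background) cancel.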
Our likelihood ratio method corrects the background effect and significantly improves the OOD detection of MNIST images, from an AUROC score of 0.089 to 0.994, using a PixelCNN++ model trained on Fashion-MNIST. When applied to the genomic benchmark dataset, this method achieves state-of-the-art performance on this challenging problem compared with 12 other baseline methods.
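To read these AUROC figures: AUROC is the probability that a randomly chosen in-distribution input receives a higher score than a randomly chosen OOD input, with ties counting one half. A score of 0.089 therefore means the raw likelihood ranks OOD inputs above in-distribution ones most of the time (worse than the 0.5 of random guessing), while 0.994 means near-perfect separation. A minimal sketch of the pairwise computation, assuming higher scores mean "more in-distribution":

```python
import numpy as np


def auroc(in_dist_scores, ood_scores):
    """Probability that a random in-distribution score exceeds a random OOD score,
    counting ties as one half (the Mann-Whitney form of the AUROC)."""
    a = np.asarray(in_dist_scores, dtype=float)[:, None]
    b = np.asarray(ood_scores, dtype=float)[None, :]
    return float((a > b).mean() + 0.5 * (a == b).mean())
```

The pairwise comparison uses memory proportional to the product of the two set sizes; for large evaluations, `sklearn.metrics.roc_auc_score` computes the same quantity from sorted scores.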

For more details, please check out our recent paper at NeurIPS 2019. While our likelihood ratio method achieves state-of-the-art performance on the genomic dataset, it is not yet accurate enough to meet the standards for deploying the model in real applications. We encourage researchers to contribute their solutions to this important problem and improve on the current state of the art. The dataset is available on our GitHub repository.

Acknowledgments
The work described here was authored by Jie Ren, Peter J. Liu, Emily Fertig, Jasper Snoek, Ryan Poplin, Mark A. DePristo, Joshua V. Dillon, Balaji Lakshminarayanan, through a collaboration spanning several teams across Google AI and DeepMind. We are grateful for all the discussions and feedback on this work that we received from the reviewers at NeurIPS 2019, and our colleagues at Google and DeepMind: Alexander A. Alemi, Andreea Gane, Brian Lee, D. Sculley, Eric Jang, Jacob Burnim, Katherine Lee, Matthew D. Hoffman, Noah Fiedel, Rif A. Saurous, Suman Ravuri, Thomas Colthurst, Yaniv Ovadia, along with the Google Brain and TensorFlow teams.

Source: Google AI Blog


Fairness Indicators: Scalable Infrastructure for Fair ML Systems



While industry and academia continue to explore the benefits of using machine learning (ML) to make better products and tackle important problems, algorithms and the datasets on which they are trained also have the ability to reflect or reinforce unfair biases. For example, consistently flagging non-toxic text comments from certain groups as “spam” or “high toxicity” in a moderation system leads to exclusion of those groups from conversation.

In 2018, we shared how Google uses AI to make products more useful, highlighting AI principles that will guide our work moving forward. The second principle, “Avoid creating or reinforcing unfair bias,” outlines our commitment to reduce unjust biases and minimize their impacts on people.

As part of this commitment, at TensorFlow World, we recently released a beta version of Fairness Indicators, a suite of tools that enable regular computation and visualization of fairness metrics for binary and multi-class classification, helping teams take a first step towards identifying unjust impacts. Fairness Indicators can be used to generate metrics for transparency reporting, such as those used for model cards, to help developers make better decisions about how to deploy models responsibly. Because fairness concerns and evaluations differ case by case, we also include in this release an interactive case study with Jigsaw’s Unintended Bias in Toxicity dataset to illustrate how Fairness Indicators can be used to detect and remediate bias in a production ML model, depending on the context in which it is deployed. Fairness Indicators is now available in beta for you to try for your own use cases.

What is ML Fairness?
Bias can manifest in any part of a typical machine learning pipeline, from an unrepresentative dataset, to learned model representations, to the way in which the results are presented to the user. Errors that result from this bias can disproportionately impact some users.

To detect this unequal impact, evaluation over individual slices, or groups of users, is crucial as overall metrics can obscure poor performance for certain groups. These groups may include, but are not limited to, those defined by sensitive characteristics such as race, ethnicity, gender, nationality, income, sexual orientation, ability, and religious belief. However, it is also important to keep in mind that fairness cannot be achieved solely through metrics and measurement; high performance, even across slices, does not necessarily prove that a system is fair. Rather, evaluation should be viewed as one of the first ways, especially for classification models, to identify gaps in performance.

The Fairness Indicators Suite of Tools
The Fairness Indicators tool suite enables computation and visualization of commonly identified fairness metrics for classification models, such as false positive rate and false negative rate, making it easy to compare performance across slices or to a baseline slice. The tool computes confidence intervals, which can surface statistically significant disparities, and performs evaluation over multiple thresholds. In the UI, it is possible to toggle the baseline slice and investigate the performance of various other metrics. Users can also add their own metrics for visualization, specific to their use case.
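The core computation behind such sliced metrics can be sketched without any library: group examples by a slice feature and compute false positive and false negative rates at several thresholds. This is an illustrative NumPy sketch, not the Fairness Indicators or TensorFlow Model Analysis API, and it omits the confidence intervals the tool also reports.

```python
import numpy as np


def sliced_error_rates(labels, scores, slice_values, thresholds):
    """Return {(slice, threshold): (false_positive_rate, false_negative_rate)}.

    labels: ground-truth 0/1 labels; scores: model scores in [0, 1];
    slice_values: the slice (group) each example belongs to.
    """
    labels = np.asarray(labels, dtype=bool)
    scores = np.asarray(scores, dtype=float)
    slice_values = np.asarray(slice_values)
    results = {}
    for value in np.unique(slice_values):
        in_slice = slice_values == value
        y, s = labels[in_slice], scores[in_slice]
        negatives, positives = int((~y).sum()), int(y.sum())
        for t in thresholds:
            pred = s >= t  # predicted positive (e.g., "toxic") at this threshold
            fpr = (pred & ~y).sum() / negatives if negatives else float("nan")
            fnr = (~pred & y).sum() / positives if positives else float("nan")
            results[(value.item(), t)] = (float(fpr), float(fnr))
    return results
```

Comparing each slice's rates with a baseline slice at the same threshold is what surfaces disparities; small slices produce noisy rates, which is why confidence intervals matter before drawing conclusions.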

Furthermore, Fairness Indicators is integrated with the What-If Tool (WIT): clicking on a bar in the Fairness Indicators graph will load those specific data points into the WIT widget for further inspection, comparison, and counterfactual analysis. This is particularly useful for large datasets, where Fairness Indicators can be used to identify problematic slices before the WIT is used for a deeper analysis.
Using Fairness Indicators to visualize metrics for fairness evaluation.
Clicking on a slice in Fairness Indicators will load all the data points in that slice inside the What-If Tool widget. In this case, all data points with the “female” label are shown.
The Fairness Indicators beta launch includes the following:
How To Use Fairness Indicators in Models Today
Fairness Indicators is built on top of TensorFlow Model Analysis, a component of TensorFlow Extended (TFX) that can be used to investigate and visualize model performance. Based on the specific ML workflow, Fairness Indicators can be incorporated into a system in one of the following ways:
If using TensorFlow models and tools, such as TFX:
  • Access Fairness Indicators as part of the Evaluator component in TFX
  • Access Fairness Indicators in TensorBoard when evaluating other real-time metrics
If not using existing TensorFlow tools:
  • Download the Fairness Indicators pip package, and use TensorFlow Model Analysis as a standalone tool
For non-TensorFlow models:
Fairness Indicators Case Study
We created a case study and introductory video that illustrate how Fairness Indicators can be used with a combination of tools to detect and mitigate bias in a model trained on Jigsaw’s Unintended Bias in Toxicity dataset. The dataset was developed by Conversation AI, a team within Jigsaw that works to train ML models to protect voices in conversation. Models are trained to predict whether text comments are likely to be abusive along a variety of dimensions including toxicity, insult, and sexual explicitness.

The primary use case for models such as these is content moderation. If a model penalizes certain types of messages in a systematic way (e.g., often marks comments as toxic when they are not, leading to a high false positive rate), those voices will be silenced. In the case study, we investigated the false positive rate on subgroups sliced by gender identity keywords present in the dataset, using a combination of tools (Fairness Indicators, TFDV, and WIT) to detect, diagnose, and take steps toward remediating the underlying problem.

What’s next?
Fairness Indicators is only the first step. We plan to expand vertically by supporting more metrics, such as metrics that let you evaluate classifiers without thresholds, and horizontally by creating remediation libraries that use methods such as active learning and min-diff. Because we believe it is important to learn through real examples, we hope to ground our work in more case studies to be released over the next few months, as more features become available.

To get started, see the Fairness Indicators GitHub repo. For more information on how to think about fairness evaluation in the context of your use case, see this link.

We would love to partner with you to understand where Fairness Indicators is most useful, and where added functionality would be valuable. Please reach out at [email protected] to provide any feedback on your experience!

Acknowledgements
The core team behind this work includes Christina Greer, Manasi Joshi, Huanming Fang, Shivam Jindal, Karan Shukla, Osman Aka, Sanders Kleinfeld, Alicia Chang, Alex Hanna, and Dan Nanas. We would also like to thank James Wexler, Mahima Pushkarna, Meg Mitchell and Ben Hutchinson for their contributions to the project.

Source: Google AI Blog


Object Detection and Tracking using MediaPipe

Posted by Ming Guang Yong, Product Manager for MediaPipe

MediaPipe in 2019

MediaPipe is a framework for building cross-platform multimodal applied ML pipelines that consist of fast ML inference, classic computer vision, and media processing (e.g., video decoding). MediaPipe was open sourced at CVPR in June 2019 as v0.5.0. Since our first open source version, we have released various ML pipeline examples like

In this blog, we will introduce another MediaPipe example: Object Detection and Tracking. We first describe our newly released box tracking solution, then we explain how it can be connected with Object Detection to provide an Object Detection and Tracking system.

Box Tracking in MediaPipe

In MediaPipe v0.6.7.1, we are excited to release a box tracking solution that has been powering real-time tracking in Motion Stills, YouTube’s privacy blur, and Google Lens for several years and that leverages classic computer vision approaches. Pairing tracking with ML inference results in valuable and efficient pipelines. In this blog, we pair box tracking with object detection to create an object detection and tracking pipeline. With tracking, this pipeline offers several advantages over running detection per frame:

  • It provides instance-based tracking, i.e., the object ID is maintained across frames.
  • Detection does not have to run every frame. This enables running heavier detection models that are more accurate while keeping the pipeline lightweight and real-time on mobile devices.
  • Object localization is temporally consistent with the help of tracking, meaning less jitter is observable across frames.

Our general box tracking solution takes as input image frames from a video or camera stream, along with timestamped starting box positions that indicate 2D regions of interest to track, and computes the tracked box positions for each frame. In this specific use case, the starting box positions come from object detection, but they can also be provided manually by the user or by another system. Our solution consists of three main components: a motion analysis component, a flow packager component, and a box tracking component. Each component is encapsulated as a MediaPipe calculator, and the box tracking solution as a whole is represented as the MediaPipe subgraph shown below.


MediaPipe Box Tracking Subgraph

The MotionAnalysis calculator extracts features (e.g., high-gradient corners) across the image, tracks those features over time, classifies them into foreground and background features, and estimates both local motion vectors and the global motion model. The FlowPackager calculator packs the estimated motion metadata into an efficient format. The BoxTracker calculator takes this motion metadata from the FlowPackager calculator, together with the positions of the starting boxes, and tracks the boxes over time. Using solely the motion data produced by the MotionAnalysis calculator (without needing the RGB frames), the BoxTracker calculator tracks individual objects or regions while discriminating them from others. To track an input region, we first use the motion data corresponding to this region and employ iteratively reweighted least squares (IRLS) to fit a parametric model to the region’s weighted motion vectors. Each region has a tracking state that includes its prior, mean velocity, set of inlier and outlier feature IDs, and the region centroid. See the figure below for a visualization of the tracking state, with green arrows indicating motion vectors of inliers and red arrows indicating motion vectors of outliers. Note that by relying only on feature IDs, we implicitly capture the region’s appearance, since each feature’s patch intensity stays roughly constant over time. Additionally, by decomposing a region’s motion into camera motion and individual object motion, we can even track featureless regions.

Visualization of Tracking State for Each Box
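The IRLS step described above can be illustrated with a small sketch. It assumes an affine motion model for the region (v ≈ A p + b for a feature at position p with motion vector v) and Cauchy-style reweighting; both are illustrative choices, not necessarily those used by the BoxTracker calculator. Features whose final residual stays small are treated as inliers and the rest as outliers, mirroring the green and red arrows in the visualization.

```python
import numpy as np


def irls_affine_motion(points, vectors, iterations=10, scale=1.0):
    """Fit motion vectors v ≈ A p + b with iteratively reweighted least squares.

    points: (n, 2) feature positions; vectors: (n, 2) motion vectors.
    Returns (params, residuals, weights); params is (3, 2): rows 0-1 hold A^T, row 2 holds b.
    """
    points = np.asarray(points, dtype=float)
    vectors = np.asarray(vectors, dtype=float)
    design = np.hstack([points, np.ones((len(points), 1))])  # rows [x, y, 1]
    weights = np.ones(len(points))
    for _ in range(iterations):
        # Weighted least squares: scale each row by sqrt(weight).
        sw = np.sqrt(weights)[:, None]
        params, *_ = np.linalg.lstsq(design * sw, vectors * sw, rcond=None)
        residuals = np.linalg.norm(vectors - design @ params, axis=1)
        # Cauchy-style reweighting: large residuals get small weights.
        weights = 1.0 / (1.0 + (residuals / scale) ** 2)
    return params, residuals, weights


def split_inliers(residuals, threshold=1.0):
    """Boolean mask: True for inlier features, False for outliers."""
    return np.asarray(residuals) <= threshold
```

The reweighting is what lets a few features on a different moving object (outliers) stop influencing the fit after the first iterations, while still being remembered by ID.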

An advantage of our architecture is that by separating motion analysis into a dedicated MediaPipe calculator and tracking features over the whole image, we enable great flexibility and a constant computational cost independent of the number of regions tracked! Because our tracking solution does not rely on the RGB frames during tracking, it can cache the metadata across a batch of frames. Caching enables tracking regions both backward and forward in time, or even syncing directly to a specified timestamp for random-access tracking.
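Why caching enables random access can be seen in a simplified sketch: if the motion between consecutive frames is stored, a box can be moved from any frame to any other frame by accumulating that motion forward or undoing it backward, without revisiting the RGB frames. The sketch assumes a single translation per frame pair and a hypothetical `cached_motion` list; the real metadata is richer (per-feature motion vectors and a global motion model).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Box:
    x: float
    y: float
    w: float
    h: float


def track_to(box, start, target, cached_motion):
    """Move `box` from frame `start` to frame `target` using cached motion.

    cached_motion[t] is a hypothetical (dx, dy) translation from frame t to t + 1.
    """
    x, y = box.x, box.y
    if target >= start:
        # Forward in time: apply the cached motion frame by frame.
        for t in range(start, target):
            dx, dy = cached_motion[t]
            x, y = x + dx, y + dy
    else:
        # Backward in time: undo the cached motion in reverse order.
        for t in range(start - 1, target - 1, -1):
            dx, dy = cached_motion[t]
            x, y = x - dx, y - dy
    return Box(x, y, box.w, box.h)
```

For example, with cached motion `[(1.0, 0.0), (2.0, 0.0), (0.0, 3.0)]`, a box at (0, 0) in frame 0 lands at (3, 3) in frame 3, and tracking from frame 3 back to frame 0 returns it to (0, 0).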

Object Detection and Tracking

A MediaPipe example graph for object detection and tracking is shown below. It consists of 4 compute nodes: a PacketResampler calculator, an ObjectDetection subgraph released previously in the MediaPipe object detection example, an ObjectTracking subgraph that wraps around the BoxTracking subgraph discussed above, and a Renderer subgraph that draws the visualization.

MediaPipe Example Graph for Object Detection and Tracking. Boxes in purple are subgraphs.

In general, the ObjectDetection subgraph (which performs ML model inference internally) runs only upon request, e.g. at an arbitrary frame rate or triggered by specific signals. More specifically, in this example PacketResampler temporally subsamples the incoming video frames to 0.5 fps before they are passed into ObjectDetection. This frame rate can be configured differently as an option in PacketResampler.
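The effect of temporal subsampling can be sketched as follows. This is a simplified stand-in for PacketResampler, assuming timestamps in microseconds and keeping the first frame that falls into each 1/frame_rate period; the actual calculator has its own options and selection rules.

```python
def resample(timestamps_us, frame_rate):
    """Keep at most one frame per 1/frame_rate seconds (the first frame in each period)."""
    period_us = 1_000_000 / frame_rate
    kept = []
    next_period = 0
    for ts in timestamps_us:
        period = int(ts // period_us)
        if period >= next_period:
            kept.append(ts)
            next_period = period + 1
    return kept
```

At 0.5 fps, a 5-second clip captured at 30 fps (150 frames) passes only 3 frames to ObjectDetection, while tracking still runs on all 150.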

The ObjectTracking subgraph runs in real time on every incoming frame to track the detected objects. It extends the BoxTracking subgraph described above with additional functionality: when new detections arrive, it uses IoU (Intersection over Union) to associate the currently tracked objects/boxes with the new detections and remove obsolete or duplicated boxes.
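A minimal sketch of IoU-based association, assuming axis-aligned boxes given as (x_min, y_min, x_max, y_max) and a greedy matching policy: a new detection that overlaps a tracked box by at least `iou_threshold` takes over that track's ID (so the stale tracked box is not kept as a duplicate), unmatched detections start new tracks, and tracks that no detection confirms are dropped as obsolete. The threshold value and this policy are illustrative assumptions, not MediaPipe's exact logic.

```python
def iou(a, b):
    """Intersection over Union of two (x_min, y_min, x_max, y_max) boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def associate(tracks, detections, next_id, iou_threshold=0.5):
    """Greedily match new detections to tracked boxes.

    tracks: dict mapping object ID -> box. Returns (updated_tracks, next_id).
    """
    updated = {}
    unmatched = dict(tracks)
    for det in detections:
        best_id, best_iou = None, iou_threshold
        for track_id, box in unmatched.items():
            overlap = iou(box, det)
            if overlap >= best_iou:
                best_id, best_iou = track_id, overlap
        if best_id is None:
            updated[next_id] = det  # new object: assign a fresh ID
            next_id += 1
        else:
            updated[best_id] = det  # same object: keep its ID, replace the stale box
            del unmatched[best_id]
    # Tracks left in `unmatched` were not confirmed by any detection and are dropped.
    return updated, next_id
```

Keeping the ID of the matched track is what lets the pipeline report stable object IDs across frames even though detection only runs occasionally.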

A sample result of this object detection and tracking example can be found below. The left image is the result of running object detection per frame. The right image is the result of running object detection and tracking. Note that the result with tracking is much more stable with less temporal jitter. It also maintains object IDs across frames.

Comparison Between Object Detection Per Frame and Object Detection and Tracking

Follow MediaPipe

This is our first Google Developer blog post for MediaPipe. We look forward to publishing new blog posts related to new MediaPipe ML pipeline examples and features. Please follow the MediaPipe tag on the Google Developer blog and the Google Developer Twitter account (@googledevs).

Acknowledgements

We would like to thank Fan Zhang, Genzhi Ye, Jiuqiang Tang, Jianing Wei, Chuo-Ling Chang, Ming Guang Yong, and Matthias Grundmann for building the object detection and tracking solution in MediaPipe and contributing to this blog post.

Lessons Learned from Developing ML for Healthcare



Machine learning (ML) methods are not new in medicine: traditional techniques, such as decision trees and logistic regression, were commonly used to derive established clinical decision rules (for example, the TIMI Risk Score for estimating patient risk after a coronary event). In recent years, however, there has been a tremendous surge in leveraging ML for a variety of medical applications, such as predicting adverse events from complex medical records and improving the accuracy of genomic sequencing. In addition to detecting known diseases, ML models can tease out previously unknown signals, such as cardiovascular risk factors and refractive error from retinal fundus photographs.

Beyond developing these models, it’s important to understand how they can be incorporated into medical workflows. Previous research indicates that doctors assisted by ML models can be more accurate than either doctors or models alone in grading diabetic eye disease and diagnosing metastatic breast cancer. Similarly, doctors are able to leverage ML-based tools in an interactive fashion to search for similar medical images, providing further evidence that doctors can work effectively with ML-based assistive tools.

In an effort to improve guidance for research at the intersection of ML and healthcare, we have written a pair of articles, published in Nature Materials and the Journal of the American Medical Association (JAMA). The first is for ML practitioners who want to better understand how to develop ML solutions for healthcare, and the second is for doctors who want a better understanding of whether ML could help improve their clinical work.

How to Develop Machine Learning Models for Healthcare
In “How to develop machine learning models for healthcare” (pdf), published in Nature Materials, we discuss the importance of ensuring that the needs specific to the healthcare environment inform the development of ML models for that setting. This should be done throughout the process of developing technologies for healthcare applications, from problem selection, data collection and ML model development to validation and assessment, deployment and monitoring.

The first consideration is how to identify a healthcare problem for which there is both an urgent clinical need and for which predictions based on ML models will provide actionable insight. For example, ML for detecting diabetic eye disease can help alleviate the screening workload in parts of the world where diabetes is prevalent and the number of medical specialists is insufficient. Once the problem has been identified, one must be careful with data curation to ensure that the ground truth labels, or “reference standard”, applied to the data are reliable and accurate. This can be accomplished by validating labels via comparison to expert interpretation of the same data, such as retinal fundus photographs, or through an orthogonal procedure, such as a biopsy to confirm radiologic findings. This is particularly important since a high-quality reference standard is essential both for training useful models and for accurately measuring model performance. Therefore, it is critical that ML practitioners work closely with clinical experts to ensure the rigor of the reference standard used for training and evaluation.

Validation of model performance is also substantially different in healthcare, because the problem of distributional shift can be pronounced. In contrast to typical ML studies where a single random test split is common, the medical field values validation using multiple independent evaluation datasets, each with different patient populations that may exhibit differences in demographics or disease subtypes. Because the specifics depend on the problem, ML practitioners should work closely with clinical experts to design the study, with particular care in ensuring that the model validation and performance metrics are appropriate for the clinical setting.

Deploying the resulting assistive tools also requires thoughtful design to ensure they fit seamlessly into clinical workflows, with consideration for measuring their impact on diagnostic accuracy and workflow efficiency. Importantly, there is substantial value in prospective study of these tools in real patient care to better understand their real-world impact.

Finally, even after validation and workflow integration, the journey towards deployment is just beginning: regulatory approval and continued monitoring for unexpected error modes or adverse events in real use remain ahead.
Two examples of the translational process of developing, validating, and implementing ML models for healthcare based on our work in detecting diabetic eye disease and metastatic breast cancer.
Empowering Doctors to Better Understand Machine Learning for Healthcare
In “Users’ Guide to the Medical Literature: How to Read Articles that use Machine Learning,” published in JAMA, we summarize key ML concepts to help doctors evaluate ML studies for suitability of inclusion in their workflow. The goal of this article is to demystify ML and to help doctors who need to use ML systems understand their basic functionality, when to trust them, and their potential limitations.

The central questions doctors ask when evaluating any study, whether ML or not, remain: Was the reference standard reliable? Was the evaluation unbiased, such as assessing for both false positives and false negatives, and performing a fair comparison with clinicians? Does the evaluation apply to the patient population that I see? How does the ML model help me in taking care of my patients?

In addition to these questions, ML models should also be scrutinized to determine whether the hyperparameters used in their development were tuned on a dataset independent of that used for final model evaluation. This is particularly important, since inappropriate tuning can lead to substantial overestimation of performance, e.g., a sufficiently sophisticated model can be trained to completely memorize the training dataset and generalize poorly to new data. Ensuring that tuning was done appropriately requires being mindful of ambiguities in dataset naming, and in particular, using the terminology with which the audience is most familiar:
The intersection of two fields, ML and healthcare, creates ambiguity in the term “validation dataset”. In ML, a validation set typically refers to the dataset used for hyperparameter tuning, whereas a “clinical” validation set typically refers to the dataset used for final evaluation. To reduce confusion, we have opted to refer to the (ML) validation set as the “tuning” set.
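The discipline this terminology is meant to protect can be sketched in code: hyperparameters are chosen using the tuning set only, and the held-out test set is used once, for the final estimate. The ridge-regression model, split fractions, and helper names below are illustrative assumptions standing in for whatever model and data a real study uses.

```python
import numpy as np


def split_indices(n, fractions=(0.7, 0.15), seed=0):
    """Shuffle 0..n-1 into disjoint train, tuning, and test index sets;
    whatever remains after the train and tuning fractions becomes the test set."""
    order = np.random.default_rng(seed).permutation(n)
    n_train = int(fractions[0] * n)
    n_tune = int(fractions[1] * n)
    return order[:n_train], order[n_train:n_train + n_tune], order[n_train + n_tune:]


def fit_ridge(X, y, alpha):
    """Closed-form ridge regression weights."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + alpha * np.eye(d), X.T @ y)


def mse(X, y, w):
    return float(np.mean((X @ w - y) ** 2))


def select_and_evaluate(X, y, alphas):
    train, tune, test = split_indices(len(y))
    # The hyperparameter is chosen on the tuning set only.
    best_alpha = min(
        alphas, key=lambda a: mse(X[tune], y[tune], fit_ridge(X[train], y[train], a))
    )
    w = fit_ridge(X[train], y[train], best_alpha)
    # The test set is touched exactly once, for the reported performance.
    return best_alpha, mse(X[test], y[test], w)
```

Selecting `alpha` by its test-set error instead would leak information from the final evaluation into model development and overstate performance, which is the failure mode the article warns readers to check for.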
Future outlook
It is an exciting time to work on AI for healthcare. The “bench-to-bedside” path is a long one that requires researchers and experts from multiple disciplines to work together in this translational process. We hope that these two articles will promote mutual understanding of what is important for ML practitioners developing models for healthcare and what is emphasized by doctors evaluating these models, thus driving further collaborations between the fields and towards eventual positive impact on patient care.

Acknowledgements
Key contributors to these projects include Yun Liu, Po-Hsuan Cameron Chen, Jonathan Krause, and Lily Peng. The authors would like to acknowledge Greg Corrado and Avinash Varadarajan for their advice, and the Google Health team for their support.

Source: Google AI Blog


Understanding Transfer Learning for Medical Imaging



As deep neural networks are applied to an increasingly diverse set of domains, transfer learning has emerged as a highly popular technique in developing deep learning models. In transfer learning, the neural network is trained in two stages: 1) pretraining, where the network is generally trained on a large-scale benchmark dataset representing a wide diversity of labels/categories (e.g., ImageNet); and 2) fine-tuning, where the pretrained network is further trained on the specific target task of interest, which may have fewer labeled examples than the pretraining dataset. The pretraining step helps the network learn general features that can be reused on the target task.
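The two-stage recipe can be illustrated with a deliberately small sketch: a logistic-regression model is pretrained on a large source dataset, then fine-tuned on a small target dataset starting from the pretrained weights, alongside a from-scratch baseline trained on the same target data with the same budget. Logistic regression, plain gradient descent, and all hyperparameters here are assumptions chosen for brevity, not the deep convolutional networks used in practice.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def train_logreg(X, y, w_init, lr=0.5, steps=300):
    """Full-batch gradient descent on the logistic loss, starting from w_init."""
    w = np.array(w_init, dtype=float)  # copy, so w_init is never modified
    for _ in range(steps):
        w -= lr * X.T @ (sigmoid(X @ w) - y) / len(y)
    return w


def pretrain_and_finetune(X_source, y_source, X_target, y_target, finetune_steps=50):
    d = X_source.shape[1]
    # Stage 1: pretraining on the large source dataset.
    w_pretrained = train_logreg(X_source, y_source, np.zeros(d))
    # Stage 2: fine-tuning on the target task, initialized from the pretrained weights.
    w_finetuned = train_logreg(X_target, y_target, w_pretrained, steps=finetune_steps)
    # Baseline: training from scratch (zero initialization suffices for a linear model).
    w_scratch = train_logreg(X_target, y_target, np.zeros(d), steps=finetune_steps)
    return w_pretrained, w_finetuned, w_scratch
```

Comparing `w_finetuned` against `w_scratch` on held-out target data is the small-scale analogue of the question studied below: how much does pretraining actually help on the target task?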

This kind of two-stage paradigm has become extremely popular in many settings, and particularly so in medical imaging. In the context of transfer learning, standard architectures designed for ImageNet with corresponding pretrained weights are fine-tuned on medical tasks ranging from interpreting chest x-rays and identifying eye diseases, to early detection of Alzheimer’s disease. Despite its widespread use, however, the precise effects of transfer learning are not yet well understood. While recent work challenges many common assumptions, including the effects on performance improvement, contribution of the underlying architecture and impact of pretraining dataset type and size, these results are all in the natural image setting, and leave many questions open for specialized domains, such as medical images.

In our NeurIPS 2019 paper, “Transfusion: Understanding Transfer Learning for Medical Imaging,” we investigate these central questions for transfer learning in medical imaging tasks. Through both a detailed performance evaluation and analysis of neural network hidden representations, we uncover many surprising conclusions, such as the limited benefits of transfer learning for performance on the tested medical imaging tasks, a detailed characterization of how representations evolve through the training process across different models and hidden layers, and feature-independent benefits of transfer learning for convergence speed.

Performance Evaluation
We first performed a thorough study on the effect of transfer learning on model performance. We compared models trained from random initialization directly on the tasks to models pretrained on ImageNet that leverage transfer learning for the same tasks. We looked at two large-scale medical imaging tasks: diagnosing diabetic retinopathy from fundus photographs and identifying five different diseases from chest x-rays. We evaluated various neural network architectures, including both standard architectures popularly used for medical imaging (ResNet50, Inception-v3) and a family of simple, lightweight convolutional neural networks that consist of four or five layers of the standard convolution-batchnorm-ReLU progression, or CBRs.

The results from evaluating all of these models on the different tasks with and without transfer learning give us four main takeaways:
  • Surprisingly, transfer learning does not significantly affect performance on medical imaging tasks, with models trained from scratch performing nearly as well as standard ImageNet transferred models.
  • On the medical imaging tasks, the much smaller CBR models perform at a level comparable to the standard ImageNet architectures.
  • As the CBR models are much smaller and shallower than the standard ImageNet models, they perform much worse on ImageNet classification, highlighting that ImageNet performance is not indicative of performance on medical tasks.
  • The two medical tasks are much smaller in size than ImageNet (~200k vs ~1.2m training images), but in the very small data regime, there may only be a few thousand training examples. We evaluated transfer learning in this very small data regime. For large models (ResNet), there was a larger gap in performance between transfer and training from scratch, but this was not true for smaller models (CBRs). This suggests that the large models designed for ImageNet might be too overparameterized for the very small data regime.
Representation Analysis
We next study the degree to which transfer learning affects the kinds of features and representations learned by the neural networks. Given the similar performance, does transfer learning result in different representations from random initialization? Is knowledge from the pretraining step reused, and if so, where? To answer these questions, we analyze and compare the hidden representations (i.e., representations learned in the latent layers of the network) of the different neural networks trained to solve these tasks. This quantitative analysis can be challenging due to the complexity and lack of alignment in different hidden layers. However, a recent method, singular vector canonical correlation analysis (SVCCA; code and tutorials are publicly available), helps overcome these challenges. It is based on canonical correlation analysis (CCA) and can be used to calculate a similarity score between a pair of hidden representations.
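The SVCCA computation described above can be sketched in a few lines of NumPy:

1. Center each layer's activations.
2. Keep the top singular directions that explain most of the variance (99% here, an assumed threshold).
3. Compute canonical correlations between the two reduced subspaces and average them.

This is a simplified illustration, not the reference SVCCA implementation. Activations are assumed to be (num_examples, num_neurons) matrices collected on the same inputs, and the random "activations" below are hypothetical.

```python
import numpy as np


def svd_reduce(acts, keep=0.99):
    """Center activations (examples x neurons) and keep the top singular
    directions that together explain a `keep` fraction of the variance."""
    acts = acts - acts.mean(axis=0)
    u, s, _ = np.linalg.svd(acts, full_matrices=False)
    explained = np.cumsum(s**2) / np.sum(s**2)
    k = min(int(np.searchsorted(explained, keep)) + 1, len(s))
    return u[:, :k] * s[:k]


def canonical_correlations(x, y):
    """Canonical correlations of two centered, full-column-rank matrices.

    They equal the singular values of Qx^T Qy, where Qx and Qy are
    orthonormal bases of the column spaces of x and y."""
    qx, _ = np.linalg.qr(x)
    qy, _ = np.linalg.qr(y)
    return np.clip(np.linalg.svd(qx.T @ qy, compute_uv=False), 0.0, 1.0)


def svcca_similarity(acts1, acts2, keep=0.99):
    """Mean canonical correlation between SVD-reduced representations."""
    rho = canonical_correlations(svd_reduce(acts1, keep), svd_reduce(acts2, keep))
    return float(rho.mean())


rng = np.random.default_rng(0)
layer_a = rng.normal(size=(500, 50))  # hypothetical activations: 500 inputs, 50 neurons
rotation, _ = np.linalg.qr(rng.normal(size=(50, 50)))
print(svcca_similarity(layer_a, layer_a @ rotation))  # ~1.0: same subspace
print(svcca_similarity(layer_a, rng.normal(size=(500, 50))))  # noticeably lower
```

Because CCA compares subspaces rather than individual neurons, a rotated copy of the same representation scores as essentially identical. This is what makes it usable across networks whose hidden units are not aligned.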

Similarity scores are computed for some of the hidden representations from the top latent layers of the networks (closer to the output) between networks trained from random initialization and networks trained from pretrained ImageNet weights. As a baseline, we also compute similarity scores of representations learned from different random initializations. For large models, representations learned from random initialization are much more similar to each other than those learned from transfer learning. For smaller models, there is greater overlap between representation similarity scores.
Representation similarity scores between networks trained from random initialization and networks trained from pretrained ImageNet weights (orange), and baseline similarity scores of representations trained from two different random initializations (blue). Higher values indicate greater similarity. For larger models, representations learned from random initialization are much more similar to each other than those learned through transfer. This is not the case for smaller models.
The reason for this difference between large and small models becomes clear with further investigation into the hidden representations. Large models change less through training, even from random initialization. We perform multiple experiments that illustrate this, from simple filter visualizations to tracking changes between different layers through fine-tuning.

When we combine the results of all the experiments from the paper, we can assemble a table summarizing how much representations change through training on the medical task across (i) transfer learning, (ii) model size and (iii) lower/higher layers.
Effects on Convergence: Feature Independent Benefits and Hybrid Approaches
One consistent effect of transfer learning was a significant speedup in the time taken for the model to converge. But having seen the mixed results for feature reuse from our representational study, we looked into whether there were other properties of the pretrained weights that might contribute to this speedup. Surprisingly, we found a feature-independent benefit of pretraining: the weight scaling.

We initialized the weights of the neural network as independent and identically distributed (iid), just like random initialization, but using the mean and variance of the pretrained weights. We called this initialization the Mean Var Init, which keeps the pretrained weight scaling but destroys all the features. This Mean Var Init offered significant speedups over random initialization across model architectures and tasks, suggesting that the pretraining process of transfer learning also helps with good weight conditioning.
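Here is a minimal sketch of the Mean Var Init idea. It assumes per-layer statistics and a Gaussian sampling distribution; the paragraph above only specifies iid weights with the pretrained mean and variance. The layer names and the synthetic stand-in for "pretrained" weights are hypothetical.

```python
import numpy as np


def mean_var_init(pretrained, seed=0):
    """For each layer, draw iid Gaussian weights with the same mean and standard
    deviation as the pretrained layer. The scale is kept; learned features are not."""
    rng = np.random.default_rng(seed)
    return {
        name: rng.normal(loc=w.mean(), scale=w.std(), size=w.shape)
        for name, w in pretrained.items()
    }


# Hypothetical stand-in for pretrained weights; in practice these would be
# loaded from an ImageNet checkpoint.
rng = np.random.default_rng(42)
pretrained = {
    "conv1": rng.normal(0.01, 0.12, size=(7, 7, 3, 64)),
    "fc": rng.normal(0.0, 0.03, size=(2048, 5)),
}
init = mean_var_init(pretrained)
for name, w in init.items():
    # Same shape, and a std ratio close to 1.
    print(name, w.shape, round(w.std() / pretrained[name].std(), 3))
```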
Filter visualization of weights initialized according to pretrained ImageNet weights, Random Init, and Mean Var Init. Only the ImageNet Init filters have pretrained (Gabor-like) structure, as Rand Init and Mean Var weights are iid.
Recall that our earlier experiments suggested that feature reuse primarily occurs in the lowest layers. To examine this further, we performed weight transfusion experiments. In these, only a subset of the pretrained weights (corresponding to a contiguous set of layers) is transferred, and the remaining weights are randomly initialized. Comparing the convergence speeds of these transfused networks with full transfer learning further supports the conclusion that feature reuse happens primarily in the lowest layers.
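Weight transfusion can be expressed as a simple merge of two initializations. The sketch below makes two assumptions:

- weights are stored in a dictionary ordered from the lowest layer to the highest;
- the transferred contiguous block starts at the input.

The layer names and placeholder arrays are hypothetical.

```python
import numpy as np


def transfuse(pretrained, random_init, up_to):
    """Take layers from the input up to and including `up_to` from the
    pretrained weights and all later layers from the random initialization.

    Both dicts must list the same layers in the same (lowest-to-highest) order."""
    names = list(pretrained)
    if names != list(random_init):
        raise ValueError("initializations must have the same layers in the same order")
    cutoff = names.index(up_to) + 1  # raises ValueError if `up_to` is unknown
    return {
        name: (pretrained if i < cutoff else random_init)[name].copy()
        for i, name in enumerate(names)
    }


layer_names = ["conv1", "block1", "block2", "block3", "block4", "fc"]
pretrained = {n: np.ones((2, 2)) for n in layer_names}    # stand-in for ImageNet weights
random_init = {n: np.zeros((2, 2)) for n in layer_names}  # stand-in for random init
mixed = transfuse(pretrained, random_init, up_to="block2")
print({n: int(w[0, 0]) for n, w in mixed.items()})
# {'conv1': 1, 'block1': 1, 'block2': 1, 'block3': 0, 'block4': 0, 'fc': 0}
```

Sweeping `up_to` from the first layer to the last yields a family of transfused networks, whose convergence speeds can then be compared against full transfer learning.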
Learning curves comparing the convergence speed with AUC on the test set. Using only the scaling of the pretrained weights (Mean Var Init) helps with convergence speed. The figures compare the standard transfer learning and the Mean Var initialization scheme to training from random initialization.
This suggests hybrid approaches to transfer learning. Instead of reusing the full neural network architecture, we can recycle its lowest layers and redesign the upper layers to better suit the target task. This gives us most of the benefits of transfer learning while further enabling flexible model design. In the figure below, we show the effect of reusing pretrained weights up to Block2 in ResNet50, halving the number of channels in the remaining layers, initializing those layers randomly, and then training end-to-end. This matches the performance and convergence of full transfer learning.
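The Slim variant can be described as a change to the layer-width configuration. The sketch below uses ResNet50's stage output widths: 64 for the stem and 256, 512, 1024 and 2048 for the four residual stages.

Several parts are assumptions for illustration:

- the block names;
- the mapping of "Block2" onto the second residual stage;
- tracking only each stage's output width rather than its inner bottleneck widths.

The function only builds a configuration, not a network.

```python
def slim_config(widths, keep_through, factor=0.5):
    """Keep pretrained blocks (and their widths) up to and including
    `keep_through`; shrink every later block by `factor` and mark it for
    random initialization. `widths` lists (block_name, out_channels) from
    the input upward."""
    names = [name for name, _ in widths]
    if keep_through not in names:
        raise ValueError(f"unknown block: {keep_through}")
    cutoff = names.index(keep_through)
    config = []
    for i, (name, channels) in enumerate(widths):
        if i <= cutoff:
            config.append((name, channels, "pretrained"))
        else:
            config.append((name, max(1, int(channels * factor)), "random"))
    return config


resnet50_widths = [("stem", 64), ("block1", 256), ("block2", 512),
                   ("block3", 1024), ("block4", 2048)]
for row in slim_config(resnet50_widths, keep_through="block2"):
    print(row)
# ('stem', 64, 'pretrained')
# ('block1', 256, 'pretrained')
# ('block2', 512, 'pretrained')
# ('block3', 512, 'random')
# ('block4', 1024, 'random')
```

After building a network from such a configuration, the whole model would be trained end-to-end, as described above.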
Hybrid approaches to transfer learning on ResNet50 (left) and CBR models (right): reusing a subset of the weights and slimming the remainder of the network (Slim), and using mathematically synthesized Gabors for conv1 (Synthetic Gabor).
The figure above also shows the results of an extreme version of this partial reuse, transferring only the very first convolutional layer with mathematically synthesized Gabor filters (pictured below). Using just these (synthetic) weights offers significant speedups, and hints at many other creative hybrid approaches.
Synthetic Gabor filters used to initialize the first layer of neural networks in some of the experiments in this paper. The Gabor filters are generated as grayscale images and repeated across the RGB channels. Left: Low frequencies. Right: High frequencies.
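Synthetic Gabor filters like those described above can be generated directly from the standard Gabor formula: a Gaussian envelope multiplied by an oriented cosine. As in the caption, each grayscale filter is repeated across the three RGB input channels.

Everything else in this sketch is an assumed value:

- the kernel size;
- the wavelengths and orientations;
- the envelope width and aspect ratio;
- the per-filter normalization.

```python
import numpy as np


def gabor_filter(size, wavelength, theta, sigma, gamma=0.5, phase=0.0):
    """Real Gabor function sampled on a size x size grid centred at zero."""
    half = (size - 1) / 2.0
    coords = np.linspace(-half, half, size)
    x, y = np.meshgrid(coords, coords)
    # Rotate coordinates so the cosine carrier varies along orientation theta.
    x_rot = x * np.cos(theta) + y * np.sin(theta)
    y_rot = -x * np.sin(theta) + y * np.cos(theta)
    envelope = np.exp(-(x_rot**2 + (gamma * y_rot) ** 2) / (2.0 * sigma**2))
    carrier = np.cos(2.0 * np.pi * x_rot / wavelength + phase)
    return envelope * carrier


def gabor_bank(size=7, num_orientations=8, wavelengths=(8.0, 4.0), sigma=2.0):
    """Grayscale Gabor filters repeated across RGB, shaped like conv weights
    (size, size, 3, num_filters). Longer wavelengths mean lower frequencies."""
    filters = []
    for wavelength in wavelengths:
        for k in range(num_orientations):
            g = gabor_filter(size, wavelength, np.pi * k / num_orientations, sigma)
            g = (g - g.mean()) / (g.std() + 1e-8)  # zero mean, unit std per filter
            filters.append(np.repeat(g[:, :, None], 3, axis=2))  # grayscale -> RGB
    return np.stack(filters, axis=-1)


bank = gabor_bank()
print(bank.shape)  # (7, 7, 3, 16)
```

Choosing, for example, four wavelengths and sixteen orientations would give 64 filters, matching the 7x7x3x64 shape of ResNet50's first convolution. Whether and how to rescale the filters to a particular weight magnitude is left open here.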
Conclusion and Open Questions
Transfer learning is a central technique for many domains. In this paper we provide insights on some of its fundamental properties in the medical imaging context, studying performance, feature reuse, the effect of different architectures, convergence and hybrid approaches. Many interesting open questions remain: How much of the original task has the model forgotten? Why do large models change less? Can we get further gains by matching higher-order moments of the pretrained weight statistics? Are the results similar for other tasks, such as segmentation? We look forward to tackling these questions in future work!

Acknowledgements
Special thanks to Samy Bengio and Jon Kleinberg, who are co-authors on this work. Thanks also to Geoffrey Hinton for helpful feedback.

Source: Google AI Blog


Developing Deep Learning Models for Chest X-rays with Adjudicated Image Labels



With millions of diagnostic examinations performed annually, chest X-rays are an important and accessible clinical imaging tool for the detection of many diseases. However, their usefulness can be limited by challenges in interpretation, which requires rapid and thorough evaluation of a two-dimensional image depicting complex, three-dimensional organs and disease processes. Indeed, early-stage lung cancers or pneumothoraces (collapsed lungs) can be missed on chest X-rays, leading to serious adverse outcomes for patients.

Advances in machine learning (ML) present an exciting opportunity to create new tools to help experts interpret medical images. Recent efforts have shown promise in improving lung cancer detection in radiology, prostate cancer grading in pathology, and differential diagnoses in dermatology. For chest X-ray images in particular, large, de-identified public image sets are available to researchers across disciplines, and have facilitated several valuable efforts to develop deep learning models for X-ray interpretation. However, obtaining accurate clinical labels for the very large image sets needed for deep learning can be difficult. Most efforts have either applied rule-based natural language processing (NLP) to radiology reports or relied on image review by individual readers, both of which may introduce inconsistencies or errors that can be especially problematic during model evaluation. Another challenge involves assembling datasets that represent an adequately diverse spectrum of cases (i.e., ensuring inclusion of both “hard” cases and “easy” cases that represent the full spectrum of disease presentation). Finally, some chest X-ray findings are non-specific and depend on clinical information about the patient to fully understand their significance. As such, establishing labels that are clinically meaningful and have consistent definitions can be a challenging component of developing machine learning models that use only the image as input. Without standardized and clinically meaningful datasets as well as rigorous reference standard methods, successful application of ML to interpretation of chest X-rays will be hindered.

To help address these issues, we recently published “Chest Radiograph Interpretation with Deep Learning Models: Assessment with Radiologist-adjudicated Reference Standards and Population-adjusted Evaluation” in the journal Radiology. In this study we developed deep learning models to classify four clinically important findings on chest X-rays: pneumothorax, nodules and masses, fractures, and airspace opacities. These target findings were selected in consultation with radiologists and clinical colleagues. The goal was to focus on conditions that are critical for patient care and for which chest X-ray images alone are an important and accessible first-line imaging study. Selecting these findings also allowed model evaluation using only de-identified images without additional clinical data.

Models were evaluated using thousands of held-out images from each dataset, for which we collected high-quality labels using a panel-based adjudication process among board-certified radiologists. Four separate radiologists also independently reviewed the held-out images so that radiologist accuracy could be compared to that of the deep learning models, using the panel-based image labels as the reference standard. For all four findings and across both datasets, the deep learning models demonstrated radiologist-level performance. To facilitate additional research, we are sharing the adjudicated labels for the publicly available data.

Data Overview
This work leveraged over 600,000 images sourced from two de-identified datasets. The first dataset was developed in collaboration with co-authors at the Apollo Hospitals, and consists of a diverse set of chest X-rays obtained over several years from multiple locations across the Apollo Hospitals network. The second dataset is the publicly available ChestX-ray14 image set released by the National Institutes of Health (NIH). This second dataset has served as an important resource for many machine learning efforts, yet has limitations stemming from issues with the accuracy and clinical interpretation of the currently available labels.
Chest X-ray depicting an upper left lobe pneumothorax identified by the model and the adjudication panel, but missed by the individual radiologist readers. Left: The original image. Right: The same image with the most important regions for the model prediction highlighted in orange.
Training Set Labels Using Deep Learning and Visual Image Review
For very large datasets consisting of hundreds of thousands of images, such as those needed to train highly accurate deep learning models, it is impractical to manually assign image labels. As such, we developed a separate, text-based deep learning model to extract image labels using the de-identified radiology reports associated with each X-ray. This NLP model was then applied to provide labels for over 560,000 images from the Apollo Hospitals dataset used for training the computer vision models.

To reduce noise from any errors introduced by the text-based label extraction and also to provide the relevant labels for a substantial number of the ChestX-ray14 images, approximately 37,000 images across the two datasets were visually reviewed by radiologists. These were separate from the NLP-based labels and helped to ensure high quality labels across such a large, diverse set of training images.

Creating and Sharing Improved Reference Standard Labels
To generate high-quality reference standard labels for model evaluation, we utilized a panel-based adjudication process, whereby three radiologists reviewed all final tune and test set images and resolved disagreements through discussion. This often allowed difficult findings that were initially only detected by a single radiologist to be identified and documented appropriately. To reduce the risk of bias based on any individual radiologist’s personality or seniority, the discussions took place anonymously via an online discussion and adjudication system.
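One step of a panel-based process can be sketched as a triage rule. Images on which all independent reads agree get a label immediately, while any disagreement is routed to discussion. This is a hypothetical simplification: the study used an anonymous online discussion and adjudication system whose details are not described here.

Even so, it shows why requiring unanimity rather than a majority vote matters: a majority vote would silently discard a finding seen by only one of three readers.

```python
def triage_for_adjudication(reads):
    """reads maps image_id -> list of independent labels (e.g. 1 = finding present).

    Returns the labels agreed on unanimously and the image ids whose
    disagreements must be resolved by panel discussion."""
    agreed, needs_discussion = {}, []
    for image_id, labels in reads.items():
        if len(set(labels)) == 1:
            agreed[image_id] = labels[0]
        else:
            needs_discussion.append(image_id)
    return agreed, needs_discussion


reads = {
    "img_001": [1, 1, 1],
    "img_002": [0, 1, 0],  # finding seen by a single reader: discuss, don't outvote
    "img_003": [0, 0, 0],
}
agreed, to_discuss = triage_for_adjudication(reads)
print(agreed)      # {'img_001': 1, 'img_003': 0}
print(to_discuss)  # ['img_002']
```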

Because the lack of available adjudicated labels was a significant initial barrier to our work, we are sharing with the research community all of the adjudicated labels for the publicly available ChestX-ray14 dataset, including 2,412 training/validation set images and 1,962 test set images (4,374 images in total). We hope that these labels will facilitate future machine learning efforts and enable better apples-to-apples comparisons between machine learning models for chest X-ray interpretation.

Future Outlook
This work presents several contributions: (1) releasing adjudicated labels for images from a publicly available dataset; (2) a method to scale accurate labeling of training data using a text-based deep learning model; (3) evaluation using a diverse set of images with expert-adjudicated reference standard labels; and ultimately (4) radiologist-level performance of deep learning models for clinically important findings on chest X-rays.

However, with regard to model performance, achieving expert-level accuracy on average is just a part of the story. Overall accuracy for the deep learning models was consistently similar to that of radiologists for any given finding, but performance for both varied across datasets. For example, the sensitivity for detecting pneumothorax among radiologists was approximately 79% for the ChestX-ray14 images but only 52% for the same radiologists on the other dataset. This suggests a more difficult collection of cases in the latter. It highlights the importance of validating deep learning tools on multiple, diverse datasets and, eventually, across the patient populations and clinical settings in which any model is intended to be used.
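Sensitivity is the fraction of reference-standard positives that a reader (human or model) flags: TP / (TP + FN). The sketch below computes it, together with specificity, against adjudicated labels. The toy labels are made up for illustration and are not data from the study.

```python
def sensitivity_specificity(predicted, reference):
    """Compare binary predictions with reference-standard labels (1 = positive).

    Assumes the reference contains at least one positive and one negative."""
    if len(predicted) != len(reference):
        raise ValueError("predicted and reference must have the same length")
    pairs = list(zip(predicted, reference))
    tp = sum(1 for p, r in pairs if p == 1 and r == 1)
    fn = sum(1 for p, r in pairs if p == 0 and r == 1)
    tn = sum(1 for p, r in pairs if p == 0 and r == 0)
    fp = sum(1 for p, r in pairs if p == 1 and r == 0)
    return tp / (tp + fn), tn / (tn + fp)


reference = [1, 1, 1, 1, 0, 0, 0, 0]  # toy adjudicated labels
reader = [1, 1, 1, 0, 0, 0, 1, 0]     # toy reads from one reader
print(sensitivity_specificity(reader, reference))  # (0.75, 0.75)
```

Sensitivity depends only on the positive cases. The same reader can therefore score very differently on two datasets whose positives differ in difficulty, which is why comparisons need a shared reference set.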

The performance differences between datasets also emphasize the need for standardized evaluation image sets with accurate reference standards in order to allow comparison across studies. For example, if two different models for the same finding were evaluated using different datasets, comparing performance would be of minimal value without knowing additional details such as the case mix, model error modes, or radiologist performance on the same cases.

Finally, the model often identified findings that were consistently missed by radiologists, and vice versa. As such, strategies that combine the unique “skills” of both the deep learning systems and human experts are likely to hold the most promise for realizing the potential of AI applications in medical image interpretation.

Acknowledgements
Key contributors to this project at Google include Sid Mittal, Gavin Duggan, Anna Majkowska, Scott McKinney, Andrew Sellergren, David Steiner, Krish Eswaran, Po-Hsuan Cameron Chen, Yun Liu, Shravya Shetty, and Daniel Tse. Significant contributions and input were also made by radiologist collaborators Joshua Reicher, Alexander Ding, and Sreenivasa Raju Kalidindi. The authors would also like to acknowledge many members of the Google Health radiology team including Jonny Wong, Diego Ardila, Zvika Ben-Haim, Rory Sayres, Shahar Jamshy, Shabir Adeel, Mikhail Fomitchev, Akinori Mitani, Quang Duong, William Chen and Sahar Kazemzadeh. Sincere appreciation also goes to the many radiologists who enabled this work through their expert image interpretation efforts throughout the project.

Source: Google AI Blog