Tag Archives: deep learning

Improved Grading of Prostate Cancer Using Deep Learning

Approximately 1 in 9 men in the United States will develop prostate cancer in their lifetime, making it the most common cancer in males. Despite being common, prostate cancers are frequently non-aggressive, making it challenging to determine if the cancer poses a significant enough risk to the patient to warrant treatment such as surgical removal of the prostate (prostatectomy) or radiation therapy. A key factor that helps in the “risk stratification” of prostate cancer patients is the Gleason grade, which classifies the cancer cells based on how closely they resemble normal prostate glands when viewed on a slide under a microscope.

However, despite its widely recognized clinical importance, Gleason grading of prostate cancer is complex and subjective, as evidenced by studies reporting inter-pathologist disagreement ranging from 30% to 53% [1][2]. Furthermore, there are not enough specialty-trained pathologists to meet the global demand for prostate cancer pathology, especially outside the United States. Recent guidelines also recommend that pathologists report the percentage of tumor of each Gleason pattern in their final report, which adds to the workload and is yet another subjective challenge for the pathologist [3]. Overall, these issues suggest an opportunity to improve the diagnosis and clinical management of prostate cancer using deep learning–based models, similar to how Google and others have used such techniques to demonstrate the potential to improve metastatic breast cancer detection.

In “Development and Validation of a Deep Learning Algorithm for Improving Gleason Scoring of Prostate Cancer”, we explore whether deep learning could improve the accuracy and objectivity of Gleason grading of prostate cancer in prostatectomy specimens. We developed a deep learning system (DLS) that mirrors a pathologist’s workflow by first categorizing each region in a slide into a Gleason pattern, with lower patterns corresponding to tumors that more closely resemble normal prostate glands. The DLS then summarizes an overall Gleason grade group based on the two most common Gleason patterns present. The higher the grade group, the greater the risk of further cancer progression and the more likely the patient is to benefit from treatment.
Visual examples of Gleason patterns, which are used in the Gleason system for grading prostate cancer. Individual cancer patches are assigned a Gleason pattern based on how closely the cancer resembles normal prostate tissue, with lower numbers corresponding to better-differentiated tumors. Image Source: National Institutes of Health.
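To make the summarization step concrete, here is a minimal Python sketch. It is an illustration rather than the DLS itself: it assumes per-patch Gleason pattern predictions are already available (a hypothetical `patch_patterns` list, with non-tumor patches labeled 0), computes the percentage of tumor in each pattern, takes the two most common patterns as primary and secondary, and maps them to a grade group using the standard grade-group table. Clinical reporting rules such as tertiary patterns and minimum-area thresholds are deliberately left out.

```python
from collections import Counter

def pattern_percentages(patch_patterns):
    """Percentage of tumor patches assigned to each Gleason pattern (3, 4 or 5).

    Patches labeled with anything else (e.g. 0 for non-tumor) are ignored.
    """
    counts = Counter(p for p in patch_patterns if p in (3, 4, 5))
    total = sum(counts.values())
    return {p: 100.0 * n / total for p, n in counts.items()} if total else {}

def grade_group(primary, secondary):
    """Map primary + secondary Gleason patterns to a grade group from 1 to 5."""
    score = primary + secondary
    if score <= 6:
        return 1
    if score == 7:
        return 2 if primary == 3 else 3  # 3+4 is less aggressive than 4+3
    if score == 8:
        return 4
    return 5  # Gleason scores 9 and 10

def summarize_slide(patch_patterns):
    """Return (percentages, primary, secondary, grade_group), or None if no tumor."""
    pct = pattern_percentages(patch_patterns)
    if not pct:
        return None
    # Rank patterns by area; break ties toward the higher (more aggressive) pattern.
    ranked = sorted(pct, key=lambda p: (pct[p], p), reverse=True)
    primary = ranked[0]
    secondary = ranked[1] if len(ranked) > 1 else primary
    return pct, primary, secondary, grade_group(primary, secondary)
```

For example, `summarize_slide([3] * 60 + [4] * 30 + [5] * 10 + [0] * 50)` returns `({3: 60.0, 4: 30.0, 5: 10.0}, 3, 4, 2)`: a 3+4 slide falls in grade group 2, whereas 4+3 would fall in grade group 3.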
To develop and validate the DLS, we collected de-identified images of prostatectomy samples, which contain a greater amount and diversity of prostate cancer than needle core biopsies, even though the latter are the more common clinical procedure. For the training data, a cohort of 32 pathologists provided detailed annotations of Gleason patterns (resulting in over 112 million annotated image patches) and an overall Gleason grade group for each image. To overcome the previously referenced variability in Gleason grading, each slide in the validation set was independently graded by 3 to 5 general pathologists (selected from a cohort of 29 pathologists) and then assigned a final Gleason grade by a genitourinary-specialist pathologist, which served as the ground-truth label for that slide.

In the paper, we show that our DLS achieved an overall accuracy of 70%, compared to an average accuracy of 61% achieved by US board-certified general pathologists in our study. Of 10 high-performing individual general pathologists who graded every slide in the validation set, the DLS was more accurate than 8. The DLS was also more accurate than the average pathologist at Gleason pattern quantitation. These improvements in Gleason grading translated into better clinical risk stratification: the DLS better identified patients at higher risk for disease recurrence after surgery than the average general pathologist, potentially enabling doctors to use this information to better match patients to therapy.
Comparison of scoring performance of the DLS with pathologists. a: Accuracy of the DLS (in red) compared with the mean accuracy among a cohort-of-29 pathologists (in green). Error bars indicate 95% confidence intervals. b: Comparison of risk stratification provided by the DLS, the cohort-of-29 pathologists, and the genitourinary specialist pathologists. Patients are divided into low and high risk groups based on their Gleason grade group, where a larger separation between the Kaplan-Meier curves of these risk groups indicates better stratification.
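The Kaplan-Meier curves in panel b come from the standard product-limit estimator, S(t) = product over recurrence times t_i <= t of (1 - d_i / n_i), where d_i is the number of recurrences at t_i and n_i is the number of patients still at risk just before t_i. A minimal Python sketch follows; the toy follow-up data in the test are hypothetical and not drawn from the study.

```python
def kaplan_meier(times, events):
    """Product-limit estimate of recurrence-free survival.

    times: follow-up time for each patient.
    events: 1 if recurrence was observed at that time, 0 if the patient was censored.
    Returns a list of (time, survival probability) at each time with a recurrence.
    """
    data = sorted(zip(times, events))
    at_risk = len(data)
    survival = 1.0
    curve = []
    i = 0
    while i < len(data):
        t = data[i][0]
        recurrences = removed = 0
        # Group all patients whose follow-up ends at time t.
        while i < len(data) and data[i][0] == t:
            recurrences += data[i][1]
            removed += 1
            i += 1
        if recurrences:
            survival *= 1.0 - recurrences / at_risk
            curve.append((t, survival))
        at_risk -= removed  # censored patients leave the risk set too
    return curve
```

Computing one curve per risk group (for example, patients whose grade group is at or above some cutoff versus the rest; the cutoff is a hypothetical choice here) and comparing them gives the kind of separation the caption describes.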
We also found that the DLS was able to characterize tissue morphology that appeared to lie at the cusp of two Gleason patterns. Such borderline tissue is one reason pathologists disagree on Gleason grades, and the DLS's ability to characterize it suggests the possibility of finer-grained “precision grading” of prostate cancer. While the clinical significance of these intermediate patterns (e.g. Gleason pattern 3.3 or 3.7) is not known, the increased precision of the DLS could enable further research into this question.
Assessing the region-level classification of the DLS. a: Annotations from 3 pathologists compared to DLS predictions. The pathologists show general concordance on the location and the extent of tumor areas, but poor agreement in classifying Gleason patterns. The DLS’s precision Gleason pattern for each region is represented by interpolating between the DLS’s predictions for Gleason patterns 3 (green), 4 (yellow), and 5 (red). b: DLS prediction patterns compared to the distribution of pathologists’ Gleason pattern classifications on 41 million annotated image patches from the test dataset. On patches where pathologists are discordant, where the tissue is more likely to be on the cusp of two patterns, the DLS reflects this ambiguity in its prediction scores.
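The post does not spell out how the DLS interpolates between its pattern predictions. One plausible reading, offered here only as an assumption, is a probability-weighted average of the pattern numbers, which naturally produces intermediate values like the 3.3 and 3.7 mentioned above:

```python
def precision_pattern(p3, p4, p5):
    """Probability-weighted Gleason pattern, a value between 3 and 5.

    p3, p4, p5: the model's (non-negative) scores for patterns 3, 4 and 5
    in one tumor region. This interpolation scheme is an assumed illustration.
    """
    total = p3 + p4 + p5
    if total <= 0:
        raise ValueError("scores must have a positive sum")
    return (3 * p3 + 4 * p4 + 5 * p5) / total
```

With scores of 0.7 for pattern 3 and 0.3 for pattern 4, `precision_pattern(0.7, 0.3, 0.0)` gives about 3.3; swapping the two scores gives about 3.7.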
While these initial results are encouraging, there is much more work to be done before systems like our DLS can be used to improve the care of prostate cancer patients. First, the accuracy of the model can be further improved with additional training data and should be validated on independent cohorts containing a larger and more diverse group of patients. In addition, we are actively working on refining our DLS to work on diagnostic needle core biopsies, which occur prior to the decision to undergo surgery and where Gleason grading therefore has a significantly greater impact on clinical decision-making. Further work will be needed to assess how best to integrate our DLS into the pathologist’s diagnostic workflow and the impact of such artificial intelligence-based assistance on the overall efficiency, accuracy, and prognostic ability of Gleason grading in clinical practice. Nonetheless, we are excited about the potential of technologies like this to significantly improve cancer diagnostics and patient care.

This work involved the efforts of a multidisciplinary team of software engineers, researchers, clinicians and logistics support staff. Key contributors to this project include Kunal Nagpal, Davis Foote, Yun Liu, Po-Hsuan (Cameron) Chen, Ellery Wulczyn, Fraser Tan, Niels Olson, Jenny L. Smith, Arash Mohtashamian, James H. Wren, Greg S. Corrado, Robert MacDonald, Lily H. Peng, Mahul B. Amin, Andrew J. Evans, Ankur R. Sangoi, Craig H. Mermel, Jason D. Hipp and Martin C. Stumpe. We would also like to thank Tim Hesterberg, Michael Howell, David Miller, Alvin Rajkomar, Benny Ayalew, Robert Nagle, Melissa Moran, Krishna Gadepalli, Aleksey Boyko, and Christopher Gammage. Lastly, this work would not have been possible without the aid of the pathologists who annotated data for this study.

  1. Interobserver Variability in Histologic Evaluation of Radical Prostatectomy Between Central and Local Pathologists: Findings of TAX 3501 Multinational Clinical Trial, Netto, G. J., Eisenberger, M., Epstein, J. I. & TAX 3501 Trial Investigators, Urology 77, 1155–1160 (2011).
  2. Phase 3 Study of Adjuvant Radiotherapy Versus Wait and See in pT3 Prostate Cancer: Impact of Pathology Review on Analysis, Bottke, D., Golz, R., Störkel, S., Hinke, A., Siegmann, A., Hertle, L., Miller, K., Hinkelbein, W., Wiegel, T., Eur. Urol. 64, 193–198 (2013).
  3. Clinical Utility of Quantitative Gleason Grading in Prostate Biopsies and Prostatectomy Specimens, Sauter, G., Steurer, S., Clauditz, T. S., Krech, T., Wittmer, C., Lutz, F., Lennartz, M., Janssen, T., Hakimi, N., Simon, R., von Petersdorff-Campen, M., Jacobsen, F., von Loga, K., Wilczak, W., Minner, S., Tsourlakis, M. C., Chirico, V., Haese, A., Heinzer, H., Beyer, B., Graefen, M., Michl, U., Salomon, G., Steuber, T., Budäus, L. H., Hekeler, E., Malsy-Mink, J., Kutzera, S., Fraune, C., Göbel, C., Huland, H., Schlomm, T., Eur. Urol. 69, 592–598 (2016).

Source: Google AI Blog

Accurate Online Speaker Diarization with Supervised Learning

Speaker diarization, the process of partitioning an audio stream with multiple people into homogeneous segments associated with each individual, is an important part of speech recognition systems. By solving the problem of “who spoke when”, speaker diarization has applications in many important scenarios, such as understanding medical conversations, video captioning and more. However, training these systems with supervised learning methods is challenging: unlike standard supervised classification tasks, a robust diarization model must be able to associate speech segments with new individuals who were not seen during training. This limits the quality of both online and offline diarization systems. Online systems usually suffer more, since they must produce diarization results in real time.
Online speaker diarization on streaming audio input. Different colors in the bottom axis indicate different speakers.
In “Fully Supervised Speaker Diarization”, we describe a new model that seeks to make more effective use of supervised speaker labels. Here “fully” implies that all components in the speaker diarization system, including the estimation of the number of speakers, are trained in a supervised way, so that they can benefit from increasing amounts of labeled data. On the NIST SRE 2000 CALLHOME benchmark, our diarization error rate (DER) is as low as 7.6%, compared to 8.8% DER from our previous clustering-based method and 9.9% from deep neural network embedding methods. Moreover, our method achieves this lower error rate with online decoding, making it particularly suitable for real-time applications. We are also open sourcing the core algorithms in our paper to accelerate research in this direction.
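DER is the fraction of reference speech time that is missed, falsely detected as speech, or attributed to the wrong speaker, after hypothesis speakers are optimally matched to reference speakers. The frame-level Python sketch below assumes no overlapping speech and only a few speakers (it brute-forces the matching); standard scoring tools work on time segments and usually allow a forgiveness collar around segment boundaries, which is omitted here. The function name is hypothetical.

```python
from itertools import permutations

def frame_der(reference, hypothesis):
    """Frame-level diarization error rate (DER).

    reference, hypothesis: equal-length lists holding one speaker label per frame,
    or None for non-speech. Assumes no overlapping speech.
    """
    if len(reference) != len(hypothesis):
        raise ValueError("reference and hypothesis must have the same length")
    ref_speech = sum(r is not None for r in reference)
    if ref_speech == 0:
        raise ValueError("reference contains no speech")
    pairs = list(zip(reference, hypothesis))
    missed = sum(r is not None and h is None for r, h in pairs)
    false_alarm = sum(r is None and h is not None for r, h in pairs)
    both = [(r, h) for r, h in pairs if r is not None and h is not None]

    # Hypothesis labels are arbitrary, so find the one-to-one mapping onto
    # reference labels (or onto no speaker) that maximizes agreement.
    ref_labels = list(dict.fromkeys(r for r in reference if r is not None))
    hyp_labels = list(dict.fromkeys(h for h in hypothesis if h is not None))
    targets = ref_labels + [None] * len(hyp_labels)
    best = 0
    for assignment in permutations(targets, len(hyp_labels)):
        mapping = dict(zip(hyp_labels, assignment))
        best = max(best, sum(mapping[h] == r for r, h in both))
    confusion = len(both) - best
    return (missed + false_alarm + confusion) / ref_speech
```

For example, `frame_der(['A', 'A', 'B', 'B', None, 'A'], ['x', 'x', 'y', 'y', 'y', 'y'])` gives 0.4: one false-alarm frame plus one confused frame, over five frames of reference speech.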

Clustering versus Interleaved-state RNN
Modern speaker diarization systems are usually based on clustering algorithms such as k-means or spectral clustering. Since these clustering methods are unsupervised, they cannot make good use of the supervised speaker labels available in the data. Moreover, online clustering algorithms usually produce lower-quality results in real-time diarization applications with streaming audio input. The key difference between our model and common clustering algorithms is that in our method, all speakers’ embeddings are modeled by a parameter-sharing recurrent neural network (RNN), and we distinguish different speakers using different RNN states, interleaved in the time domain.
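The idea of one shared RNN with a separate, interleaved state per speaker can be sketched in a few lines of NumPy. This is a toy model, not the released implementation: a plain tanh cell with random weights stands in for the trained recurrent cell, and the class and method names are hypothetical. The key point is that `step` advances only the active speaker's state, while every speaker uses the same parameters and starts from the same initial state.

```python
import numpy as np

class InterleavedRNN:
    """Toy illustration: one shared RNN cell, one hidden state per speaker."""

    def __init__(self, embed_dim, hidden_dim, seed=0):
        rng = np.random.default_rng(seed)
        # Parameters shared by every speaker.
        self.W = rng.normal(scale=0.1, size=(hidden_dim, embed_dim))
        self.U = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
        self.b = np.zeros(hidden_dim)
        self.h0 = np.zeros(hidden_dim)  # common initial state for all speakers
        self.states = {}  # speaker label -> that speaker's latest hidden state

    def step(self, embedding, speaker):
        """Advance only the given speaker's state with a new segment embedding."""
        h_prev = self.states.get(speaker, self.h0)
        h = np.tanh(self.W @ embedding + self.U @ h_prev + self.b)
        self.states[speaker] = h
        return h
```

In the sequence blue, blue, yellow, blue, the third call creates yellow's state from the shared initial state, and the fourth call resumes blue from where it left off.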

To understand how this works, consider the example below in which there are four possible speakers: blue, yellow, pink and green (this is arbitrary, and in fact there may be more — our model uses the Chinese restaurant process to accommodate the unknown number of speakers). Each speaker starts with its own RNN instance (with a common initial state shared among all speakers) and keeps updating the RNN state given the new embeddings from this speaker. In the example below, the blue speaker keeps updating its RNN state until a different speaker, yellow, comes in. If blue speaks again later, it resumes updating its RNN state. (This is just one of the possibilities for speech segment y7 in the figure below. If new speaker green enters, it will start with a new RNN instance.)
The generative process of our model. Colors indicate labels for speaker segments.
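How a Chinese restaurant process lets the number of speakers grow can be illustrated with a simplified generative sketch. The parameterization here is an assumption for illustration (a fixed speaker-change probability `change_prob` and a concentration parameter `alpha`, with existing speakers weighted by their number of previous turns) and may differ from the paper's exact formulation:

```python
import random

def sample_speaker_sequence(num_segments, change_prob, alpha, seed=0):
    """Sample one speaker label per segment from a CRP-style prior.

    At each step the current speaker keeps talking with probability
    1 - change_prob. On a change, an existing speaker (other than the current
    one) is chosen with weight equal to its number of previous turns, and a
    brand-new speaker with weight alpha (> 0), so the number of speakers is
    never fixed in advance. Speakers are numbered 0, 1, 2, ... as they appear.
    """
    if num_segments <= 0:
        return []
    rng = random.Random(seed)
    labels = [0]
    turns = {0: 1}  # number of speaking turns per speaker
    for _ in range(num_segments - 1):
        current = labels[-1]
        if rng.random() >= change_prob:
            labels.append(current)  # same speaker continues
            continue
        candidates = [k for k in turns if k != current]
        weights = [turns[k] for k in candidates]
        candidates.append(len(turns))  # label for a new speaker
        weights.append(alpha)
        chosen = rng.choices(candidates, weights=weights)[0]
        turns[chosen] = turns.get(chosen, 0) + 1
        labels.append(chosen)
    return labels
```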
Representing speakers as RNN states enables us to learn the high-level knowledge shared across different speakers and utterances in the RNN parameters, so the model can benefit from additional labeled data. In contrast, common clustering algorithms almost always process each utterance independently, making it difficult to benefit from a large amount of labeled data.

The upshot of all this is that given time-stamped speaker labels (i.e. we know who spoke when), we can train the model with standard stochastic gradient descent algorithms. A trained model can be used for speaker diarization on new utterances from unheard speakers. Furthermore, the use of online decoding makes it more suitable for latency-sensitive applications.

Future Work
Although we've already achieved strong diarization performance with this system, there are still many exciting directions we are currently exploring. First, we are refining our model so it can easily integrate contextual information to perform offline decoding. Offline decoding will likely reduce the DER further, which matters for applications that are not latency-sensitive. Second, we would like to model acoustic features directly instead of using d-vectors (speaker embeddings produced by a separate neural network). In this way, the entire speaker diarization system can be trained end to end.

To learn more about this work, please see our paper. To download the core algorithm of this system, please visit the GitHub page.

This work was done as a close collaboration between Google AI and Speech & Assistant teams. Contributors include Aonan Zhang (intern), Quan Wang, Zhengyao Zhu and Chong Wang.


Open Sourcing BERT: State-of-the-Art Pre-training for Natural Language Processing

One of the biggest challenges in natural language processing (NLP) is the shortage of training data. Because NLP is a diversified field with many distinct tasks, most task-specific datasets contain only a few thousand or a few hundred thousand human-labeled training examples. However, modern deep learning-based NLP models see benefits from much larger amounts of data, improving when trained on millions, or billions, of annotated training examples. To help close this gap in data, researchers have developed a variety of techniques for training general purpose language representation models using the enormous amount of unannotated text on the web (known as pre-training). The pre-trained model can then be fine-tuned on small-data NLP tasks like question answering and sentiment analysis, resulting in substantial accuracy improvements compared to training on these datasets from scratch.

This week, we open sourced a new technique for NLP pre-training called Bidirectional Encoder Representations from Transformers, or BERT. With this release, anyone in the world can train their own state-of-the-art question answering system (or a variety of other models) in about 30 minutes on a single Cloud TPU, or in a few hours using a single GPU. The release includes source code built on top of TensorFlow and a number of pre-trained language representation models. In our associated paper, we demonstrate state-of-the-art results on 11 NLP tasks, including the very competitive Stanford Question Answering Dataset (SQuAD v1.1).

What Makes BERT Different?
BERT builds upon recent work in pre-training contextual representations, including Semi-supervised Sequence Learning, Generative Pre-Training, ELMo, and ULMFit. However, unlike these previous models, BERT is the first deeply bidirectional, unsupervised language representation, pre-trained using only a plain text corpus (in this case, Wikipedia and BooksCorpus).

Why does this matter? Pre-trained representations can either be context-free or contextual, and contextual representations can further be unidirectional or bidirectional. Context-free models such as word2vec or GloVe generate a single word embedding representation for each word in the vocabulary. For example, the word “bank” would have the same context-free representation in “bank account” and “bank of the river.” Contextual models instead generate a representation of each word that is based on the other words in the sentence. For example, in the sentence “I accessed the bank account,” a unidirectional contextual model would represent “bank” based on “I accessed the” but not “account.” However, BERT represents “bank” using both its previous and next context — “I accessed the ... account” — starting from the very bottom of a deep neural network, making it deeply bidirectional.

A visualization of BERT’s neural network architecture compared to previous state-of-the-art contextual pre-training methods is shown below. The arrows indicate the information flow from one layer to the next. The green boxes at the top indicate the final contextualized representation of each input word:
BERT is deeply bidirectional, OpenAI GPT is unidirectional, and ELMo is shallowly bidirectional.
The Strength of Bidirectionality
If bidirectionality is so powerful, why hasn’t it been done before? To understand why, consider that unidirectional models are efficiently trained by predicting each word conditioned on the previous words in the sentence. However, it is not possible to train bidirectional models by simply conditioning each word on its previous and next words, since this would allow the word that’s being predicted to indirectly “see itself” in a multi-layer model.

To solve this problem, we use the straightforward technique of masking out some of the words in the input and then conditioning each word bidirectionally to predict the masked words.
While this idea has been around for a very long time, BERT is the first time it was successfully used to pre-train a deep neural network.
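The masking step can be sketched in Python. The proportions below follow the BERT paper, which selects 15% of tokens and, of those, replaces 80% with [MASK], 10% with a random token, and leaves 10% unchanged. Whitespace-separated words stand in for BERT's WordPiece tokens, and the function name is hypothetical:

```python
import random

MASK = "[MASK]"

def mask_tokens(tokens, vocab, mask_rate=0.15, seed=0):
    """Return (inputs, labels) for masked language model pre-training.

    labels[i] holds the original token at each selected position and None
    elsewhere, so the loss is only computed on the selected positions.
    """
    rng = random.Random(seed)
    inputs = list(tokens)
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_rate:
            continue  # position not selected for prediction
        labels[i] = tok
        r = rng.random()
        if r < 0.8:
            inputs[i] = MASK  # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token
    return inputs, labels
```

Only positions with a non-None label contribute to the pre-training loss, and the model has to infer them from the context on both sides.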

BERT also learns to model relationships between sentences by pre-training on a very simple task that can be generated from any text corpus: given two sentences A and B, is B the actual next sentence that comes after A in the corpus, or just a random sentence?
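Training pairs for this task can be generated mechanically from any corpus. A minimal Python sketch, following the 50/50 split between actual next sentences and random sentences described in the BERT paper; drawing the random sentence from a different document is a choice made here to avoid accidentally sampling the true next sentence, and the function name is hypothetical:

```python
import random

def make_nsp_examples(documents, seed=0):
    """Build (sentence_a, sentence_b, is_next) pairs for next sentence prediction.

    documents: list of documents, each a list of sentences in corpus order.
    About half of the pairs use the true next sentence (is_next=True); the rest
    pair sentence A with a random sentence from a different document.
    """
    rng = random.Random(seed)
    examples = []
    for d, doc in enumerate(documents):
        # Sentences from every other document are candidates for a random B.
        others = [s for j, other in enumerate(documents) if j != d for s in other]
        for i in range(len(doc) - 1):
            if not others or rng.random() < 0.5:
                examples.append((doc[i], doc[i + 1], True))
            else:
                examples.append((doc[i], rng.choice(others), False))
    return examples
```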
Training with Cloud TPUs
Everything that we’ve described so far might seem fairly straightforward, so what’s the missing piece that made it work so well? Cloud TPUs. Cloud TPUs gave us the freedom to quickly experiment, debug, and tweak our models, which was critical in allowing us to move beyond existing pre-training techniques. The Transformer model architecture, developed by researchers at Google in 2017, also gave us the foundation we needed to make BERT successful. The Transformer is implemented in our open source release, as well as the tensor2tensor library.

Results with BERT
To evaluate performance, we compared BERT to other state-of-the-art NLP systems. Importantly, BERT achieved all of its results with almost no task-specific changes to the neural network architecture. On SQuAD v1.1, BERT achieves an F1 score (a measure of accuracy) of 93.2%, surpassing the previous state-of-the-art score of 91.6% and the human-level score of 91.2%.
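The F1 score on SQuAD measures token overlap between a predicted answer span and a reference answer. The sketch below follows the normalization used by the official SQuAD evaluation script (lowercasing and removing punctuation, the articles a/an/the, and extra whitespace), but it is a simplified re-implementation: the official script also takes the maximum F1 over all reference answers for a question before averaging over questions.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)

def normalize_answer(text):
    """Lowercase, then strip punctuation, the articles a/an/the, and extra whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def token_f1(prediction, reference):
    """Harmonic mean of token precision and recall between two answer strings."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
```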
BERT also improves the state of the art by 7.6% absolute on the very challenging GLUE benchmark, a set of 9 diverse Natural Language Understanding (NLU) tasks. The amount of human-labeled training data in these tasks ranges from 2,500 examples to 400,000 examples, and BERT substantially improves upon the state-of-the-art accuracy on all of them.
Making BERT Work for You
The models that we are releasing can be fine-tuned on a wide variety of NLP tasks in a few hours or less. The open source release also includes code to run pre-training, although we believe the majority of NLP researchers who use BERT will never need to pre-train their own models from scratch. The BERT models that we are releasing today are English-only, but we hope to release models which have been pre-trained on a variety of languages in the near future.

The open source TensorFlow implementation and pointers to pre-trained BERT models can be found at http://goo.gl/language/bert. Alternatively, you can get started using BERT through Colab with the notebook “BERT FineTuning with Cloud TPUs.”

You can also read our paper "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" for more details.


Acoustic Detection of Humpback Whales Using a Convolutional Neural Network

Over the last several years, Google AI Perception teams have developed techniques for audio event analysis that have been applied on YouTube for non-speech captions, video categorizations, and indexing. Furthermore, we have published the AudioSet evaluation set and open-sourced some model code in order to further spur research in the community. Recently, we’ve become increasingly aware that many conservation organizations were collecting large quantities of acoustic data, and wondered whether it might be possible to apply these same technologies to that data in order to assist wildlife monitoring and conservation.

As part of our AI for Social Good program, and in partnership with the Pacific Islands Fisheries Science Center of the U.S. National Oceanic and Atmospheric Administration (NOAA), we developed algorithms to identify humpback whale calls in 15 years of underwater recordings from a number of locations in the Pacific. The results of this research provide new and important information about humpback whale presence, seasonality, daily calling behavior, and population structure. This is especially important in remote, uninhabited islands, about which scientists have had no information until now. Additionally, because the dataset spans a large period of time, knowing when and where humpback whales are calling will provide information on whether or not the animals have changed their distribution over the years, especially in relation to increasing human ocean activity. That information will be a key ingredient for effective mitigation of anthropogenic impacts on humpback whales.
HARP deployment locations. Green: sites with currently active recorders. Red: previous recording sites.
Passive Acoustic Monitoring and the NOAA HARP Dataset
Passive acoustic monitoring is the process of listening to marine mammals with underwater microphones called hydrophones, which can be used to record signals so that detection, classification, and localization tasks can be done offline. This has some advantages over ship-based visual surveys, including the ability to detect submerged animals, longer detection ranges and longer monitoring periods. Since 2005, NOAA has collected recordings from ocean-bottom hydrophones at 12 sites in the Pacific Island region, a winter breeding and calving destination for certain populations of humpback whales.

The data was recorded on devices called high-frequency acoustic recording packages, or HARPs (Wiggins and Hildebrand, 2007). In total, NOAA provided about 15 years of audio, or 9.2 terabytes after decimation from 200 kHz to 10 kHz. (Since most of the sound energy in humpback vocalizations is in the 100 Hz to 2000 Hz range, little is lost in using the lower sample rate.)
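Going from 200 kHz to 10 kHz is a factor-of-20 decimation. The new Nyquist frequency of 5 kHz is still well above the humpback band, but energy above 5 kHz must be filtered out before discarding samples, or it would alias into lower frequencies. A minimal NumPy sketch of low-pass filtering followed by downsampling; the filter length, the Hamming-windowed design, and the 4.5 kHz cutoff are illustrative choices, not NOAA's actual processing:

```python
import numpy as np

def lowpass_fir(num_taps, cutoff_hz, fs_hz):
    """Windowed-sinc (Hamming) low-pass FIR filter with unit gain at DC."""
    n = np.arange(num_taps) - (num_taps - 1) / 2
    fc = cutoff_hz / fs_hz  # cutoff in cycles per sample
    h = 2 * fc * np.sinc(2 * fc * n) * np.hamming(num_taps)
    return h / h.sum()

def decimate_fir(x, fs_in, fs_out, num_taps=401):
    """Remove content above the new Nyquist frequency, then keep every q-th sample."""
    q = fs_in // fs_out
    if q * fs_out != fs_in:
        raise ValueError("fs_in must be an integer multiple of fs_out")
    # Cut off a little below fs_out / 2 to leave room for the filter's transition band.
    h = lowpass_fir(num_taps, cutoff_hz=0.45 * fs_out, fs_hz=fs_in)
    filtered = np.convolve(x, h, mode="same")
    return filtered[::q]
```

A 1 kHz tone (inside the humpback band) passes through essentially unchanged, while an 8 kHz tone, which would otherwise alias to 2 kHz, is strongly attenuated.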

From a research perspective, identifying species of interest in such large volumes of data is an important first stage that provides input for higher-level population abundance, behavioral or oceanographic analyses. However, manually marking humpback whale calls, even with the aid of currently available computer-assisted methods, is extremely time-consuming.

Supervised Learning: Optimizing an Image Model for Humpback Detection
We made the common choice of treating audio event detection as an image classification problem, where the image is a spectrogram — a histogram of sound power plotted on time-frequency axes.
Example spectrograms of audio events found in the dataset, with time on the x-axis and frequency on the y-axis. Left: a humpback whale call (in particular, a tonal unit), Center: narrow-band noise from an unknown source, Right: hard disk noise from the HARP
A spectrogram is a good representation for an image classifier, whose goal is to discriminate between classes, because the different spectra (frequency decompositions) and their variations over time, which are characteristic of distinct sound types, appear in the spectrogram as visually dissimilar patterns. For the image model itself, we used ResNet-50, a convolutional neural network architecture typically used for image classification that has shown success at classifying non-speech audio. This is a supervised learning setup, in which only manually labeled data (0.2% of the entire dataset) could be used for training; in the next section, we describe an approach that makes use of the unlabeled data.

The process of going from waveform to spectrogram involves choices of parameters and gain-scaling functions. Common default choices (one of which was logarithmic compression) were a good starting point, but some domain-specific tuning was needed to optimize the detection of whale calls. Humpback vocalizations are varied, but sustained, frequency-modulated, tonal units occur frequently in time.
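The waveform-to-spectrogram step, using logarithmic compression as the gain-scaling function, can be sketched with NumPy. The Hann window, frame length, and hop size are illustrative assumptions, not the parameters used for the humpback model:

```python
import numpy as np

def log_spectrogram(x, fs, frame_len=1024, hop=512, eps=1e-10):
    """Log-compressed power spectrogram from a short-time Fourier transform.

    Returns (freqs_hz, log_power); log_power has shape
    (num_frames, frame_len // 2 + 1), i.e. time along rows, frequency along columns.
    """
    x = np.asarray(x, dtype=float)
    if len(x) < frame_len:
        raise ValueError("signal is shorter than one frame")
    window = np.hanning(frame_len)
    num_frames = 1 + (len(x) - frame_len) // hop
    frames = np.stack([x[i * hop : i * hop + frame_len] * window for i in range(num_frames)])
    power = np.abs(np.fft.rfft(frames, axis=1)) ** 2  # sound power per time-frequency cell
    freqs = np.fft.rfftfreq(frame_len, d=1.0 / fs)
    return freqs, np.log(power + eps)  # eps avoids log(0) in silent cells
```

The returned array has time along its first axis, so transpose it to plot time on the x-axis and frequency on the y-axis as in the figures. A steady tone concentrates its energy in the same frequency bin in every frame, which is the horizontal bar described next.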

If the frequency didn't vary at all, a tonal unit would appear in the spectrogram as a horizontal bar. Since the calls are frequency-modulated, we actually see arcs instead of bars, but parts of the arcs are close to horizontal.

A challenge particular to this dataset was narrow-band noise, most often caused by nearby boats and the equipment itself. In a spectrogram it appears as horizontal lines, and early versions of the model would confuse it with humpback calls. This motivated us to try per-channel energy normalization (PCEN), which allows the suppression of stationary, narrow-band noise. This proved to be critical, providing a 24% reduction in error rate of whale call detection.
Spectrograms of the same 5-unit excerpt from humpback whale song, beginning at 0:06 in the example recording. Top: PCEN. Bottom: log of squared magnitude. The dark blue horizontal bar along the bottom under log compression becomes much lighter relative to the whale call when using PCEN.
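In its usual formulation, PCEN divides each filterbank energy E(t, f) by a smoothed version of itself, M(t, f) = (1 - s) M(t - 1, f) + s E(t, f), raised to a power alpha, and then applies root compression: PCEN(t, f) = (E(t, f) / (eps + M(t, f))^alpha + delta)^r - delta^r. Because M tracks the slowly varying level of each frequency channel separately, a persistent narrow-band hum is divided by roughly its own level, while a brief whale call stands out against its background. A minimal NumPy sketch; the parameter values are typical defaults, not the values tuned for this model:

```python
import numpy as np

def pcen(energy, s=0.025, alpha=0.98, delta=2.0, r=0.5, eps=1e-6):
    """Per-channel energy normalization of non-negative filterbank energies.

    energy: array of shape (num_frames, num_channels), time along rows.
    """
    energy = np.asarray(energy, dtype=float)
    smoother = np.empty_like(energy)
    m = energy[0]  # initialize the smoother with the first frame
    for t in range(energy.shape[0]):
        # First-order IIR low-pass along time, separately for every channel.
        m = (1 - s) * m + s * energy[t]
        smoother[t] = m
    # Adaptive gain control followed by root compression.
    return (energy / (eps + smoother) ** alpha + delta) ** r - delta ** r
```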
Aside from PCEN, averaging predictions over a longer period of time led to much better precision. This same effect happens for general audio event detection, but for humpback calls the increase in precision was surprisingly large. A likely explanation is that the vocalizations in our dataset are mainly in the context of whale song, a structured sequence of units that can last over 20 minutes. At the end of one unit in a song, there is a good chance another unit begins within two seconds. The input to the image model covers a short time window, but because the song is so long, model outputs from more distant time windows give extra information useful for making the correct prediction for the current time window.
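The post does not say how predictions were averaged or over how long. A centered moving average over neighboring windows is one simple possibility, sketched here as an assumption; `radius` (how many windows on each side to include) is a hypothetical parameter:

```python
def smooth_scores(scores, radius):
    """Centered moving average of per-window detection scores.

    Each output averages the scores within `radius` windows on either side
    (fewer at the edges of the recording).
    """
    if radius < 0:
        raise ValueError("radius must be non-negative")
    smoothed = []
    for i in range(len(scores)):
        lo, hi = max(0, i - radius), min(len(scores), i + radius + 1)
        smoothed.append(sum(scores[lo:hi]) / (hi - lo))
    return smoothed
```

An isolated false detection such as [0.1, 0.1, 0.9, 0.1, 0.1] is pulled down to 0.26 at its center with `radius=2`, while a short dip between song units, as in [0.9, 0.8, 0.2, 0.9, 0.85] with `radius=1`, is lifted to about 0.63.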

Overall, evaluating on our test set of 75-second audio clips, the model identifies whether a clip contains humpback calls at over 90% precision and 90% recall. However, one should interpret these results with care; training and test data come from similar equipment and environmental conditions. That said, preliminary checks against some non-NOAA sources look promising.

Unsupervised Learning: Representation for Finding Similar Song Units
A different way to approach the question, "Where are all the humpback sounds in this data?", is to start with several examples of humpback sound and, for each of these, find more in the dataset that are similar to that example. The definition of similar here can be learned by the same ResNet we used when this was framed as a supervised problem. There, we used the labels to learn a classifier on top of the ResNet output. Here, we encourage a pair of ResNet output vectors to be close in Euclidean distance when the corresponding audio examples are close in time. With that distance function, we can retrieve many more examples of audio similar to a given one. In the future, this may be useful input for a classifier that distinguishes different humpback unit types from each other.

To learn the distance function, we used a method described in "Unsupervised Learning of Semantic Audio Representations", based on the idea that closeness in time is related to closeness in meaning. It randomly samples triplets, where each triplet is defined to consist of an anchor, a positive, and a negative. The positive and the anchor are sampled so that they start around the same time. An example of a triplet in our application would be a humpback unit (anchor), a probable repeat of the same unit by the same whale (positive) and background noise from some other month (negative). Passing the 3 samples through the ResNet (with tied weights) represents them as 3 vectors. Minimizing a loss that forces the anchor-negative distance to exceed the anchor-positive distance by a margin learns a distance function faithful to semantic similarity.
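The triplet objective and the time-based sampling can be sketched as follows. This is an illustration, not the code from the cited paper: embeddings are plain NumPy arrays rather than ResNet outputs, the margin is arbitrary, and `sample_triplet` (a hypothetical helper) draws negatives uniformly from windows far from the anchor rather than specifically from another month:

```python
import random
import numpy as np

def triplet_loss(anchor, positive, negative, margin=0.1):
    """Mean hinge loss over a batch of embedding triplets, each of shape (batch, dim).

    Penalizes triplets where the squared anchor-negative distance does not exceed
    the squared anchor-positive distance by at least `margin`.
    """
    d_pos = np.sum((anchor - positive) ** 2, axis=1)
    d_neg = np.sum((anchor - negative) ** 2, axis=1)
    return float(np.mean(np.maximum(0.0, d_pos - d_neg + margin)))

def sample_triplet(num_windows, rng, max_offset=2):
    """Pick (anchor, positive, negative) window indices: the positive starts within
    `max_offset` windows of the anchor, the negative farther away."""
    if max_offset < 1 or num_windows < 2 * max_offset + 2:
        raise ValueError("need max_offset >= 1 and enough windows for a negative")
    anchor = rng.randrange(num_windows)
    lo = max(0, anchor - max_offset)
    hi = min(num_windows - 1, anchor + max_offset)
    positive = rng.choice([i for i in range(lo, hi + 1) if i != anchor])
    while True:  # rejection sampling; a valid negative always exists
        negative = rng.randrange(num_windows)
        if abs(negative - anchor) > max_offset:
            return anchor, positive, negative
```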

Principal component analysis (PCA) on a sample of labeled points lets us visualize the results. Separation between humpback and non-humpback is apparent. Explore for yourself using the TensorFlow Embedding Projector: try setting Color by to class_label and then to site. Also, try switching from PCA to t-SNE in the projector for a visualization that prioritizes preserving local neighborhood structure rather than sample variance.
A sample of 5000 data points in the unsupervised representation. (Orange: humpback. Blue: not humpback.)
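A two-dimensional PCA projection like the one in the projector can be computed from the embedding matrix with a singular value decomposition. A minimal NumPy sketch, assuming the embeddings are stacked one per row; the function name is hypothetical:

```python
import numpy as np

def pca_2d(embeddings):
    """Project embeddings of shape (num_points, dim) onto their top two principal components."""
    x = np.asarray(embeddings, dtype=float)
    centered = x - x.mean(axis=0)
    # Rows of vt are principal directions, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T
```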
Given individual "query" units, we retrieved the nearest neighbors in the entire corpus using Euclidean distance between embedding vectors. In some cases we found hundreds more instances of the same unit with good precision.
Manually chosen query units (boxed) and nearest neighbors using the unsupervised representation.
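Query-by-example retrieval reduces to finding the embeddings closest to the query's embedding. The brute-force NumPy sketch below is illustrative only; at the scale of the full corpus an approximate nearest-neighbor index would normally be used instead, and the function name is hypothetical:

```python
import numpy as np

def nearest_neighbors(query, corpus, k=5):
    """Indices of the k corpus embeddings closest to `query`, nearest first.

    corpus: array of shape (num_items, dim); query: array of shape (dim,).
    """
    corpus = np.asarray(corpus, dtype=float)
    if k < 1:
        raise ValueError("k must be at least 1")
    dists = np.linalg.norm(corpus - np.asarray(query, dtype=float), axis=1)
    k = min(k, len(corpus))
    candidates = np.argpartition(dists, k - 1)[:k]  # k smallest distances, unordered
    return candidates[np.argsort(dists[candidates])]
```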
We intend to use these in the future to build a training set for a classifier that discriminates between song units. We could also use them to expand the training set used for learning a humpback detector.

Predictions from the Supervised Classifier on the Entire Dataset
We plotted summaries of the model output grouped by time and location. Not all sites had deployments in all years. Duty cycling (for example, 5 minutes on, 15 minutes off) allows longer deployments on limited battery power, but the schedule can vary. To deal with these sources of variability, we consider the ratio of the recorded time in which humpback calling was detected to the total time recorded in each month:
Time density of presence on year / month axes for the Kona and Saipan sites.
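This normalization can be sketched as a small aggregation. The input format, one record per analyzed window holding the year, month, seconds of audio actually recorded, and whether humpback calling was detected, is a hypothetical choice for illustration. Only recorded time enters the denominator, so duty cycling and deployment gaps do not bias the ratio, and months without recordings are simply absent:

```python
from collections import defaultdict

def monthly_presence(windows):
    """windows: iterable of (year, month, seconds_recorded, humpback_detected).

    Returns {(year, month): fraction of recorded time with humpback calling}.
    """
    recorded = defaultdict(float)
    detected = defaultdict(float)
    for year, month, seconds, is_humpback in windows:
        recorded[(year, month)] += seconds
        if is_humpback:
            detected[(year, month)] += seconds
    return {key: detected[key] / total for key, total in recorded.items() if total > 0}
```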
The apparent seasonal variation is consistent with a known pattern in which humpback populations spend summers feeding near Alaska and then migrate to the vicinity of the Hawaiian Islands to breed and give birth. This is a nice sanity check for the model.

We hope the predictions for the full dataset will equip experts at NOAA to reach deeper insights into the status of these populations and into the degree of any anthropogenic impacts on them. We also hope this is one of the first in a series of successes as Google works to accelerate the application of machine learning to the world's biggest humanitarian and environmental challenges.

We would like to thank Ann Allen (NOAA Fisheries) for providing the bulk of the ground truth data, for many useful rounds of feedback, and for some of the words in this post. Karlina Merkens (NOAA affiliate) provided further useful guidance. We also thank the NOAA Pacific Islands Fisheries Science Center as a whole for collecting and sharing the acoustic data.

Within Google, Jiayang Liu, Julie Cattiau, Aren Jansen, Rif A. Saurous, and Lauren Harrell contributed to this work. Special thanks go to Lauren, who designed the plots in the analysis section and implemented them using ggplot.

Source: Google AI Blog

Applying Deep Learning to Metastatic Breast Cancer Detection

A pathologist’s microscopic examination of a tumor in patients is considered the gold standard for cancer diagnosis, and has a profound impact on prognosis and treatment decisions. One important but laborious aspect of the pathologic review involves detecting cancer that has spread (metastasized) from the primary site to nearby lymph nodes. Detection of nodal metastasis is relevant for most cancers, and forms one of the foundations of the widely-used TNM cancer staging.

In breast cancer in particular, nodal metastasis influences treatment decisions regarding radiation therapy, chemotherapy, and the potential surgical removal of additional lymph nodes. As such, the accuracy and timeliness of identifying nodal metastases have a significant impact on clinical care. However, studies have shown that about 1 in 4 metastatic lymph node staging classifications would be changed upon second pathologic review, and detection sensitivity of small metastases on individual slides can be as low as 38% when reviewed under time constraints.

Last year, we described LYmph Node Assistant (LYNA), our deep learning–based approach to improving diagnostic accuracy, as applied to the 2016 ISBI Camelyon Challenge, which provided gigapixel-sized pathology slides of lymph nodes from breast cancer patients for researchers to develop computer algorithms to detect metastatic cancer. While LYNA achieved significantly higher cancer detection rates (Liu et al. 2017) than had been previously reported, an accurate algorithm alone is insufficient to improve pathologists’ workflow or outcomes for breast cancer patients. For patient safety, these algorithms must be tested in a variety of settings to understand their strengths and weaknesses. Furthermore, the actual benefits to pathologists of using these algorithms had not been explored previously, and must be assessed to determine whether an algorithm actually improves efficiency or diagnostic accuracy.

In “Artificial Intelligence Based Breast Cancer Nodal Metastasis Detection: Insights into the Black Box for Pathologists” (Liu et al. 2018), published in the Archives of Pathology and Laboratory Medicine and “Impact of Deep Learning Assistance on the Histopathologic Review of Lymph Nodes for Metastatic Breast Cancer” (Steiner, MacDonald, Liu et al. 2018) published in The American Journal of Surgical Pathology, we present a proof-of-concept pathologist assistance tool based on LYNA, and investigate these factors.

In the first paper, we applied our algorithm to de-identified pathology slides from both the Camelyon Challenge and an independent dataset provided by our co-authors at the Naval Medical Center San Diego. Because this additional dataset consisted of pathology samples from a different lab using different processes, it improved the representation of the diversity of slides and artifacts seen in routine clinical practice. LYNA proved robust to image variability and numerous histological artifacts, and achieved similar performance on both datasets without additional development.
Left: sample view of a slide containing lymph nodes, with multiple artifacts: the dark zone on the left is an air bubble, the white streaks are cutting artifacts, the red hue across some regions indicates hemorrhage (blood), the tissue is necrotic (decaying), and the processing quality was poor. Right: LYNA identifies the tumor region in the center (red), and correctly classifies the surrounding artifact-laden regions as non-tumor (blue).
In both datasets, LYNA was able to correctly distinguish a slide with metastatic cancer from a slide without cancer 99% of the time. Further, LYNA was able to accurately pinpoint the location of both cancers and other suspicious regions within each slide, some of which were too small to be consistently detected by pathologists. As such, we reasoned that one potential benefit of LYNA could be to highlight these areas of concern for pathologists to review and determine the final diagnosis.
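To make the "highlight areas of concern" idea concrete, here is a minimal sketch that assumes the model has already produced a grid of per-tile tumor probabilities (a heatmap) for a slide. It derives a slide-level score by taking the maximum tile probability and lists the tiles above a review threshold. Both the max aggregation and the 0.5 threshold are illustrative assumptions, not LYNA's published post-processing.

```python
from typing import List, Tuple

# Tile-level tumor probabilities for one slide, indexed [row][col].
Heatmap = List[List[float]]


def slide_score(heatmap: Heatmap) -> float:
    """Slide-level score: the highest tile probability (assumed aggregation)."""
    return max(max(row) for row in heatmap)


def tiles_to_review(heatmap: Heatmap, threshold: float = 0.5) -> List[Tuple[int, int]]:
    """(row, col) of every tile at or above a hypothetical review threshold."""
    return [
        (r, c)
        for r, row in enumerate(heatmap)
        for c, p in enumerate(row)
        if p >= threshold
    ]


heatmap = [
    [0.01, 0.02, 0.03],
    [0.02, 0.91, 0.64],
    [0.01, 0.05, 0.02],
]
print(slide_score(heatmap))      # 0.91
print(tiles_to_review(heatmap))  # [(1, 1), (1, 2)]
```

In an assistive tool, the flagged tiles would be outlined for the pathologist, who still makes the final diagnosis.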

In our second paper, 6 board-certified pathologists completed a simulated diagnostic task in which they reviewed lymph nodes for metastatic breast cancer both with and without the assistance of LYNA. For the often laborious task of detecting small metastases (termed micrometastases), the use of LYNA made the task subjectively “easier” (according to pathologists’ self-reported diagnostic difficulty) and halved average slide review time, requiring about one minute instead of two minutes per slide.
Left: sample views of a slide containing lymph nodes with a small metastatic breast tumor at progressively higher magnifications. Right: the same views when shown with algorithmic “assistance” (LYmph Node Assistant, LYNA) outlining the tumor in cyan.
This suggests the intriguing potential for assistive technologies such as LYNA to reduce the burden of repetitive identification tasks and to allow more time and energy for pathologists to focus on other, more challenging clinical and diagnostic tasks. In terms of diagnostic accuracy, pathologists in this study were able to more reliably detect micrometastases with LYNA, reducing the rate of missed micrometastases by a factor of two. Encouragingly, pathologists with LYNA assistance were more accurate than either unassisted pathologists or the LYNA algorithm itself, suggesting that people and algorithms can work together effectively to perform better than either alone.

With these studies, we have made progress in demonstrating the robustness of our LYNA algorithm to support one component of breast cancer TNM staging, and assessing its impact in a proof-of-concept diagnostic setting. While encouraging, the bench-to-bedside journey to help doctors and patients with these types of technologies is a long one. These studies have important limitations, such as limited dataset sizes and a simulated diagnostic workflow which examined only a single lymph node slide for every patient instead of the multiple slides that are common for a complete clinical case. Further work will be needed to assess the impact of LYNA on real clinical workflows and patient outcomes. However, we remain optimistic that carefully validated deep learning technologies and well-designed clinical tools can help improve both the accuracy and availability of pathologic diagnosis around the world.

Source: Google AI Blog

Open Sourcing Active Question Reformulation with Reinforcement Learning

Natural language understanding is a significant ongoing focus of Google’s AI research, with application to machine translation, syntactic and semantic parsing, and much more. Importantly, as conversational technology increasingly requires the ability to directly answer users’ questions, one of the most active areas of research we pursue is question answering (QA), a fundamental building block of human dialogue.

Because open sourcing code is a critical component of reproducible research, we are releasing a TensorFlow package for Active Question Answering (ActiveQA), a research project that investigates using reinforcement learning to train artificial agents for question answering. Introduced for the first time in our ICLR 2018 paper “Ask the Right Questions: Active Question Reformulation with Reinforcement Learning”, ActiveQA interacts with QA systems using natural language with the goal of providing better answers.

Active Question Answering
In traditional QA, supervised learning techniques are used in combination with labeled data to train a system that answers arbitrary input questions. While this is effective, such a system cannot deal with uncertainty the way humans do: by reformulating questions, issuing multiple searches, and evaluating and aggregating responses. Inspired by humans’ ability to "ask the right questions", ActiveQA introduces an agent that repeatedly consults the QA system. In doing so, the agent may reformulate the original question multiple times in order to find the best possible answer. We call this approach active because the agent engages in a dynamic interaction with the QA system, with the goal of improving the quality of the answers returned.

For example, consider the question “When was Tesla born?”. The agent reformulates the question in two different ways, “When is Tesla’s birthday?” and “Which year was Tesla born?”, and retrieves answers to both questions from the QA system. Using all of this information, it decides to return “July 10 1856”.
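The interaction in this example can be sketched as a simple loop. Every component here is a hypothetical stand-in: `reformulate` plays the role of the learned reformulation model, `qa_system` the QA environment, and `score` the answer selector. The toy scorer below simply prefers more specific (longer) answers; it is not how the released CNN-based selector works.

```python
from typing import Callable, Dict, List


def active_qa(
    question: str,
    reformulate: Callable[[str], List[str]],
    qa_system: Callable[[str], str],
    score: Callable[[str, str, str], float],
) -> str:
    """Query the QA system with the original question and its reformulations, then
    return the answer whose (question, reformulation, answer) triple scores highest."""
    candidates = [question] + reformulate(question)
    answers = [(q, qa_system(q)) for q in candidates]
    _, best_answer = max(answers, key=lambda qa: score(question, qa[0], qa[1]))
    return best_answer


# Toy stand-ins for the three components.
toy_answers: Dict[str, str] = {
    "When was Tesla born?": "1856",
    "When is Tesla's birthday?": "July 10 1856",
    "Which year was Tesla born?": "1856",
}


def toy_reformulate(q: str) -> List[str]:
    return ["When is Tesla's birthday?", "Which year was Tesla born?"]


def toy_score(original: str, reformulation: str, answer: str) -> float:
    return float(len(answer.split()))  # prefer more specific answers


print(active_qa("When was Tesla born?", toy_reformulate, toy_answers.get, toy_score))
# July 10 1856
```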
What characterizes an ActiveQA system is that it learns to ask questions that lead to good answers. However, because training data in the form of question pairs, with an original question and a more successful variant, is not readily available, ActiveQA uses reinforcement learning, an approach to machine learning concerned with training agents so that they take actions that maximize a reward, while interacting with an environment.

The learning takes place as the ActiveQA agent interacts with the QA system: each question reformulation is evaluated in terms of how good the corresponding answer is, which constitutes the reward. If the answer is good, the learning algorithm adjusts the model’s parameters so that the reformulation that led to it is more likely to be generated again; if the answer is bad, it makes that reformulation less likely.
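The update described above is the idea behind policy gradient methods such as REINFORCE. As a deliberately tiny sketch, assume the "policy" is just a softmax over a fixed list of three candidate reformulations, each with a fixed, made-up reward; the real agent is a sequence-to-sequence model that generates reformulations token by token. Sampling a reformulation and then nudging its log-probability up or down by its reward relative to a running baseline shifts probability toward the reformulation that earns the best answers.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce(rewards: List[float], steps: int = 2000, lr: float = 0.1,
              seed: int = 0) -> List[float]:
    """Train a softmax policy over fixed candidate reformulations with REINFORCE."""
    rng = random.Random(seed)
    logits = [0.0] * len(rewards)
    baseline = 0.0
    for _ in range(steps):
        probs = softmax(logits)
        k = rng.choices(range(len(rewards)), weights=probs)[0]  # sample a reformulation
        advantage = rewards[k] - baseline  # was this answer better than usual?
        baseline = 0.9 * baseline + 0.1 * rewards[k]  # moving-average baseline
        for i in range(len(logits)):
            # Gradient of log p(k) with respect to logit i is 1[i == k] - p_i.
            logits[i] += lr * advantage * ((1.0 if i == k else 0.0) - probs[i])
    return softmax(logits)


# Hypothetical answer-quality rewards for three candidate reformulations.
probs = reinforce([0.2, 0.9, 0.4])
print([round(p, 3) for p in probs])
```

The baseline does not change the expected update; it only reduces its variance, which is why it is a common companion to REINFORCE.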

In our paper, we show that it is possible to train such agents to outperform the underlying QA system, the one used to provide answers to reformulations, by asking better questions. This is an important result, as the QA system is already trained with supervised learning to solve the same task. Another compelling finding of our research is that the ActiveQA agent can learn a fairly sophisticated, and still somewhat interpretable, reformulation strategy (the policy in reinforcement learning). The learned policy uses well-known information retrieval techniques such as tf-idf query term re-weighting, the process by which more informative terms are weighted more than generic ones, and word stemming.
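For readers unfamiliar with tf-idf re-weighting, the sketch below shows the basic computation on a made-up three-document corpus: a query term's weight is its count in the query times log(N / df), where N is the number of documents and df is the number containing the term. Terms that appear everywhere, such as "when", get weight zero. This illustrates the retrieval technique itself; it is not code from the ActiveQA release, whose agent arrived at this behavior through learning.

```python
import math
from collections import Counter
from typing import Dict, List


def idf(corpus: List[List[str]]) -> Dict[str, float]:
    """Inverse document frequency log(N / df) for every term in the corpus."""
    n_docs = len(corpus)
    df = Counter(term for doc in corpus for term in set(doc))
    return {term: math.log(n_docs / count) for term, count in df.items()}


def weight_query(query: List[str], idf_table: Dict[str, float]) -> Dict[str, float]:
    """tf-idf weight of each query term; terms unseen in the corpus get 0 (an assumption)."""
    tf = Counter(query)
    return {term: tf[term] * idf_table.get(term, 0.0) for term in tf}


corpus = [
    ["when", "was", "tesla", "born"],
    ["when", "was", "edison", "born"],
    ["when", "did", "the", "war", "end"],
]
weights = weight_query(["when", "was", "tesla", "born"], idf(corpus))
print(max(weights, key=weights.get))  # tesla
print(weights["when"])                # 0.0
```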

Build Your Own ActiveQA System
The TensorFlow ActiveQA package we are releasing consists of three main components, and contains all the code necessary to train and run the ActiveQA agent.
  • A pretrained sequence to sequence model that takes as input a question and returns its reformulations. This task is similar to machine translation, translating from English to English, and indeed the initial model can be used for general paraphrasing. For its implementation we use and customize the TensorFlow Neural Machine Translation Tutorial code. We adapted the code to support training with reinforcement learning, using policy gradient methods.*
  • An answer selection model. The answer selector uses a convolutional neural network and assigns a score to each triplet of original question, reformulation and answer. The selector uses pre-trained, publicly available word embeddings (GloVe).
  • A question answering system (the environment). For this purpose we use BiDAF, a popular question answering system, described in Seo et al. (2017).
We also provide pointers to checkpoints for all the trained models.

Google’s mission is to organize the world's information and make it universally accessible and useful, and we believe that ActiveQA is an important step in realizing that mission. We envision that this research will help us design systems that provide better and more interpretable answers, and hope it will help others develop systems that can interact with the world using natural language.

Contributors to this research and release include Alham Fikri Aji, Christian Buck, Jannis Bulian, Massimiliano Ciaramita, Wojciech Gajewski, Andrea Gesmundo, Alexey Gronskiy, Neil Houlsby, Yannic Kilcher, and Wei Wang.

* The system we reported on in our paper used the TensorFlow sequence-to-sequence code used in Britz et al. (2017). Later, an open source version of the Google Translation model (GNMT) was published as a tutorial. The ActiveQA version released today is based on this more recent and actively developed implementation. For this reason, the released system differs slightly from the paper’s. Nevertheless, its performance and behavior are qualitatively and quantitatively comparable.

Source: Google AI Blog

Google’s Next Generation Music Recognition

In 2017 we launched Now Playing on the Pixel 2, using deep neural networks to bring low-power, always-on music recognition to mobile devices. In developing Now Playing, our goal was to create a small, efficient music recognizer which requires a very small fingerprint for each track in the database, allowing music recognition to be run entirely on-device without an internet connection. As it turns out, Now Playing was not only useful for an on-device music recognizer, but also greatly exceeded the accuracy and efficiency of our then-current server-side system, Sound Search, which was built before the widespread use of deep neural networks. Naturally, we wondered if we could bring the same technology that powers Now Playing to the server-side Sound Search, with the goal of making Google’s music recognition capabilities the best in the world.

Recently, we introduced a new version of Sound Search that is powered by some of the same technology used by Now Playing. You can use it through the Google Search app or the Google Assistant on any Android phone. Just start a voice query, and if there’s music playing near you, a “What’s this song?” suggestion will pop up for you to press. Otherwise, you can just ask, “Hey Google, what’s this song?” With this latest version of Sound Search, you’ll get faster, more accurate results than ever before!
Now Playing versus Sound Search
Now Playing miniaturized music recognition technology such that it was small and efficient enough to be run continuously on a mobile device without noticeable battery impact. To do this we developed an entirely new system using convolutional neural networks to turn a few seconds of audio into a unique “fingerprint.” This fingerprint is then compared against an on-device database holding tens of thousands of songs, which is regularly updated to add newly released tracks and remove those that are no longer popular. In contrast, the server-side Sound Search system is very different, having to match against ~1000x as many songs as Now Playing. Making Sound Search both faster and more accurate with a substantially larger musical library presented several unique challenges. But before we go into that, a few details on how Now Playing works.

The Core Matching Process of Now Playing
Now Playing generates the musical “fingerprint” by projecting the musical features of an eight-second portion of audio into a sequence of low-dimensional embeddings, one for each of seven two-second clips taken at 1-second intervals.
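The window arithmetic is easy to check: an eight-second buffer split into two-second clips with a one-second hop gives (8 - 2) / 1 + 1 = 7 clips. The small helper below (a hypothetical name, not part of any Google API) enumerates the clip boundaries for any buffer length, clip length, and hop.

```python
from typing import List, Tuple


def clip_windows(buffer_s: float = 8.0, clip_s: float = 2.0,
                 hop_s: float = 1.0) -> List[Tuple[float, float]]:
    """(start, end) times in seconds of each clip that fits entirely in the buffer."""
    n_clips = int((buffer_s - clip_s) / hop_s) + 1
    return [(i * hop_s, i * hop_s + clip_s) for i in range(n_clips)]


print(len(clip_windows()))           # 7
print(clip_windows()[-1])            # (6.0, 8.0)
print(len(clip_windows(hop_s=0.5)))  # 13
```

Halving the hop to 0.5 s, as the Sound Search changes later in this post describe, yields 13 clips from the same buffer, roughly doubling the number of embeddings available for matching.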
Now Playing then searches the on-device song database, which was generated by processing popular music with the same neural network, for similar embedding sequences. The database search uses a two-phase algorithm to identify matching songs: the first phase uses a fast but inaccurate algorithm that searches the whole song database to find a few likely candidates, and the second phase does a detailed analysis of each candidate to work out which song, if any, is the right one.
  • Matching, phase 1: Finding good candidates: For every embedding, Now Playing performs a nearest neighbor search on the on-device database of songs for similar embeddings. The database uses a hybrid of spatial partitioning and vector quantization to efficiently search through millions of embedding vectors. Because the audio buffer is noisy, this search is approximate, and not every embedding will find a nearby match in the database for the correct song. However, over the whole clip, the chances of finding several nearby embeddings for the correct song are very high, so the search is narrowed to a small set of songs which got multiple hits.
  • Matching, phase 2: Final matching: Because the database search used above is approximate, Now Playing may not find song embeddings which are nearby to some embeddings in our query. Therefore, in order to calculate an accurate similarity score, Now Playing retrieves all embeddings for each song in the database which might be relevant to fill in the “gaps”. Then, given the sequence of embeddings from the audio buffer and another sequence of embeddings from a song in the on-device database, Now Playing estimates their similarity pairwise and adds up the estimates to get the final matching score.
It’s critical to the accuracy of Now Playing to use a sequence of embeddings rather than a single embedding. The fingerprinting neural network is not accurate enough to allow identification of a song from a single embedding alone — each embedding will generate a lot of false positive results. However, combining the results from multiple embeddings allows the false positives to be easily removed, as the correct song will be a match to every embedding, while false positive matches will only be close to one or two embeddings from the input audio.
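The two phases and the reason for using a sequence can be sketched in a few lines. This is a simplified illustration under stated assumptions: a brute-force cosine-similarity scan stands in for the hybrid spatial-partitioning and vector-quantization index, phase 2 aligns the query against each song at every offset and sums the pairwise similarities (the alignment is our assumption), and the 0.9 similarity and 2-hit thresholds are made up.

```python
import math
from collections import Counter
from typing import Dict, List, Optional, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def phase1_candidates(query: List[Vector], index: Dict[str, List[Vector]],
                      sim_threshold: float = 0.9, min_hits: int = 2) -> List[str]:
    """Count query embeddings with a close match in each song; keep songs with several hits."""
    hits = Counter()
    for q in query:
        for song, embeddings in index.items():
            if any(cosine(q, e) >= sim_threshold for e in embeddings):
                hits[song] += 1
    return [song for song, n in hits.items() if n >= min_hits]


def phase2_score(query: List[Vector], song: List[Vector]) -> float:
    """Sum pairwise similarities at each alignment offset and keep the best total."""
    best = -math.inf
    for offset in range(len(song) - len(query) + 1):
        total = sum(cosine(q, song[offset + i]) for i, q in enumerate(query))
        best = max(best, total)
    return best


def match(query: List[Vector], index: Dict[str, List[Vector]],
          sim_threshold: float = 0.9, min_hits: int = 2) -> Optional[str]:
    candidates = phase1_candidates(query, index, sim_threshold, min_hits)
    if not candidates:
        return None
    return max(candidates, key=lambda song: phase2_score(query, index[song]))


index = {
    "song_a": [(1.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, -1.0)],
    "song_b": [(1.0, 0.0), (-1.0, 0.0), (0.0, -1.0), (-1.0, -1.0)],
}
noisy_query = [(0.1, 1.0), (1.0, 0.9), (1.0, -1.1)]  # song_a at offset 1, plus noise
print(match(noisy_query, index))   # song_a
print(match([(1.0, 0.0)], index))  # None
```

The second call shows the point of the paragraph above: a single embedding matches both songs equally well, so it cannot reach the hit threshold for either, while the three-embedding query accumulates hits only for the correct song.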

Scaling up Now Playing for the Sound Search server
So far, we’ve gone into some detail of how Now Playing matches songs to an on-device database. The biggest challenge in going from Now Playing, with tens of thousands of songs, to Sound Search, with tens of millions, is that there are a thousand times as many songs which could give a false positive result. To compensate for this without any other changes, we would have to increase the recognition threshold, which would mean needing more audio to get a confirmed match. However, the goal of the new Sound Search server was to be able to match faster, not slower, than Now Playing, so we didn’t want people to wait 10+ seconds for a result.

As Sound Search is a server-side system, it isn’t limited by processing and storage constraints in the same way Now Playing is. Therefore, we made two major changes to how we do fingerprinting, both of which increased accuracy at the expense of server resources:
  • We quadrupled the size of the neural network used, and increased each embedding from 96 to 128 dimensions, which reduces the amount of work the neural network has to do to pack the high-dimensional input audio into a low-dimensional embedding. This is critical in improving the quality of phase two, which is very dependent on the accuracy of the raw neural network output.
  • We doubled the density of our embeddings — it turns out that fingerprinting audio every 0.5s instead of every 1s doesn’t reduce the quality of the individual embeddings very much, and gives us a huge boost by doubling the number of embeddings we can use for the match.
We also decided to weight our index by song popularity: in effect, we lower the matching threshold for popular songs and raise it for obscure ones. Overall, this means that we can keep adding more (obscure) songs to our database almost indefinitely without slowing our recognition speed too much.
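One way to picture popularity weighting is a per-song acceptance threshold that rises from popular to obscure songs. The function below is purely illustrative: the log-rank interpolation, the normalized [0, 1] score scale, and the 0.80 and 0.95 endpoints are all assumptions, not Sound Search's actual values.

```python
import math


def match_threshold(popularity_rank: int, catalog_size: int,
                    low: float = 0.80, high: float = 0.95) -> float:
    """Hypothetical threshold on a normalized [0, 1] match score: 'low' for the most
    popular song (rank 1), rising with log-rank to 'high' for the least popular."""
    if not 1 <= popularity_rank <= catalog_size or catalog_size < 2:
        raise ValueError("need 1 <= popularity_rank <= catalog_size and catalog_size >= 2")
    fraction = math.log(popularity_rank) / math.log(catalog_size)
    return low + (high - low) * fraction


def accept(score: float, popularity_rank: int, catalog_size: int) -> bool:
    """Accept a candidate only if its score clears its popularity-dependent threshold."""
    return score >= match_threshold(popularity_rank, catalog_size)


catalog = 10_000_000
print(accept(0.85, popularity_rank=1, catalog_size=catalog))        # True
print(accept(0.85, popularity_rank=catalog, catalog_size=catalog))  # False
```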

With Now Playing, we originally set out to use machine learning to create a robust audio fingerprint compact enough to run entirely on a phone. It turned out that we had, in fact, created a very good all-round audio fingerprinting system, and the ideas developed there carried over very well to the server-side Sound Search system, even though the challenges of Sound Search are quite different.

We still think there’s room for improvement, though: we don’t always match when music is very quiet or in very noisy environments, and we believe we can make the system even faster. We are continuing to work on these challenges with the goal of providing the next generation in music recognition. We hope you’ll try it the next time you want to find out what song is playing! You can also put a Sound Search shortcut on your home screen.
We would like to thank Micha Riser, Mihajlo Velimirovic, Marvin Ritter, Ruiqi Guo, Sanjiv Kumar, Stephen Wu, Diego Melendo Casado, Katia Naliuka, Jason Sanders, Beat Gfeller, Christian Frank, Dominik Roblek, Matt Sharifi and Blaise Aguera y Arcas.

Source: Google AI Blog

Introducing the Unrestricted Adversarial Examples Challenge

Machine learning is being deployed in more and more real-world applications, including medicine, chemistry and agriculture. When it comes to deploying machine learning in safety-critical contexts, significant challenges remain. In particular, all known machine learning algorithms are vulnerable to adversarial examples: inputs that an attacker has intentionally designed to cause the model to make a mistake. While previous research on adversarial examples has mostly focused on investigating mistakes caused by small modifications in order to develop improved models, real-world adversarial agents are often not subject to the “small modification” constraint. Furthermore, machine learning algorithms can often make confident errors when faced with an adversary. Developing classifiers that make no confident mistakes, even against an adversary that can submit arbitrary inputs to try to fool the system, is therefore an important open problem.

Today we're announcing the Unrestricted Adversarial Examples Challenge, a community-based challenge to incentivize and measure progress towards the goal of zero confident classification errors in machine learning models. While previous research has focused on adversarial examples that are restricted to small changes to pre-labeled data points (allowing researchers to assume the image should have the same label after a small perturbation), this challenge allows unrestricted inputs, allowing participants to submit arbitrary images from the target classes to develop and test models on a wider variety of adversarial examples.
Adversarial examples can be generated through a variety of means, including by making small modifications to the input pixels, but also using spatial transformations, or simple guess-and-check to find misclassified inputs.
Structure of the Challenge
Participants can submit entries in one of two roles: as a defender, by submitting a classifier designed to be difficult to fool, or as an attacker, by submitting arbitrary inputs to try to fool the defenders' models. In a “warm-up” period before the challenge, we will present a set of fixed attacks for participants to design networks to defend against. After the community can conclusively beat those fixed attacks, we will launch the full two-sided challenge with prizes for both attacks and defenses.

For the purposes of this challenge, we have created a simple “bird-or-bicycle” classification task, where a classifier must answer the following: “Is this an unambiguous picture of a bird, a bicycle, or is it ambiguous / not obvious?” We selected this task because telling birds and bicycles apart is very easy for humans, but all known machine learning techniques struggle at the task when in the presence of an adversary.

The defender's goal is to correctly label a clean test set of birds and bicycles with high accuracy, while also making no confident errors on any attacker-provided bird or bicycle image. The attacker's goal is to find an image of a bird that the defending classifier confidently labels as a bicycle (or vice versa). We want to make the challenge as easy as possible for the defenders, so we discard all images that are ambiguous (such as a bird riding a bicycle) or not obvious (such as an aerial view of a park, or random noise).
Examples of ambiguous and unambiguous images. Defenders must make no confident mistakes on unambiguous bird or bicycle images. We discard all images that humans find ambiguous or not obvious. All images under CC licenses 1, 2, 3, 4.
Attackers may submit absolutely any image of a bird or a bicycle in an attempt to fool the defending classifier. For example, an attacker could take photographs of birds, use 3D rendering software, make image composites using image editing software, produce novel bird images with a generative model, or use any other technique.

In order to validate new attacker-provided images, we ask an ensemble of humans to label the image. This procedure lets us allow attackers to submit arbitrary images, not just test set images modified in small ways. If the defending classifier confidently classifies as "bird" any attacker-provided image which the human labelers unanimously labeled as a bicycle, the defending model has been broken. You can learn more details about the structure of the challenge in our paper.
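The success condition above can be sketched as two small functions. This is an illustrative sketch only: the function names and the 0.8 confidence cutoff for "confident" are assumptions, abstaining is modeled as returning `None` (the "ambiguous / not obvious" answer), and the defender's separate accuracy requirement on the clean test set is not modeled. The authoritative evaluation pipeline is the one released with the challenge.

```python
from typing import Dict, List, Optional

CLASSES = ("bird", "bicycle")


def defender_label(probs: Dict[str, float], confidence: float = 0.8) -> Optional[str]:
    """Return a confident label, or None to abstain ("ambiguous / not obvious")."""
    label = max(probs, key=probs.get)
    return label if probs[label] >= confidence else None


def breaks_defender(prediction: Optional[str], human_labels: List[str]) -> bool:
    """An attack succeeds only on an image that humans unanimously label as a bird
    or a bicycle, and only if the defender confidently predicts the other class."""
    if not human_labels or len(set(human_labels)) != 1:
        return False  # not unanimous: the image is discarded
    truth = human_labels[0]
    if truth not in CLASSES:
        return False  # ambiguous or not-obvious images do not count
    return prediction is not None and prediction != truth


pred = defender_label({"bird": 0.93, "bicycle": 0.07})
print(pred)                                                      # bird
print(breaks_defender(pred, ["bicycle", "bicycle", "bicycle"]))  # True
print(breaks_defender(None, ["bicycle", "bicycle", "bicycle"]))  # False
```

Abstaining on a clean bird or bicycle is never a confident error, which is why a defender can trade some clean accuracy for robustness.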

How to Participate
If you’re interested in participating, guidelines for getting started can be found on the project's GitHub page. We’ve already released our dataset, the evaluation pipeline, and baseline attacks for the warm-up, and we’ll be keeping an up-to-date leaderboard with the best defenses from the community. We look forward to your entries!

The team behind the Unrestricted Adversarial Examples Challenge includes Tom Brown, Catherine Olsson, Nicholas Carlini, Chiyuan Zhang, and Ian Goodfellow from Google, and Paul Christiano from OpenAI.

Source: Google AI Blog

Moving Beyond Translation with the Universal Transformer

Last year we released the Transformer, a new machine learning model that showed remarkable success over existing algorithms for machine translation and other language understanding tasks. Before the Transformer, most neural network based approaches to machine translation relied on recurrent neural networks (RNNs) which operate sequentially (e.g. translating words in a sentence one-after-the-other) using recurrence (i.e. the output of each step feeds into the next). While RNNs are very powerful at modeling sequences, their sequential nature means that they are quite slow to train, as longer sentences need more processing steps, and their recurrent structure also makes them notoriously difficult to train properly.

In contrast to RNN-based approaches, the Transformer used no recurrence, instead processing all words or symbols in the sequence in parallel while making use of a self-attention mechanism to incorporate context from words farther away. By processing all words in parallel and letting each word attend to other words in the sentence over multiple processing steps, the Transformer was much faster to train than recurrent models. Remarkably, it also yielded much better translation results than RNNs. However, on smaller and more structured language understanding tasks, or even simple algorithmic tasks such as copying a string (e.g. to transform an input of “abc” to “abcabc”), the Transformer does not perform very well. In contrast, models that perform well on these tasks, like the Neural GPU and Neural Turing Machine, fail on large-scale language understanding tasks like translation.

In “Universal Transformers” we extend the standard Transformer to be computationally universal (Turing complete) using a novel, efficient flavor of parallel-in-time recurrence which yields stronger results across a wider range of tasks. We built on the parallel structure of the Transformer to retain its fast training speed, but we replaced the Transformer’s fixed stack of different transformation functions with several applications of a single, parallel-in-time recurrent transformation function (i.e. the same learned transformation function is applied to all symbols in parallel over multiple processing steps, where the output of each step feeds into the next). Crucially, where an RNN processes a sequence symbol-by-symbol (left to right), the Universal Transformer processes all symbols at the same time (like the Transformer), but then refines its interpretation of every symbol in parallel over a variable number of recurrent processing steps using self-attention. This parallel-in-time recurrence mechanism is both faster than the serial recurrence used in RNNs, and also makes the Universal Transformer more powerful than the standard feedforward Transformer.
The Universal Transformer repeatedly refines a series of vector representations (shown as h1 to hm) for each position of the sequence in parallel, by combining information from different positions using self-attention and applying a recurrent transition function. Arrows denote dependencies between operations.
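The parallel-in-time recurrence described above can be sketched as one shared block applied repeatedly to every position at once. This toy version is a simplification under stated assumptions: it omits learned query/key/value projections, multiple heads, layer normalization, position and timestep embeddings, and the decoder, and the weights are made up. What it does show is the key structural idea: the same weight matrix `W` is reused at every step, whereas a standard Transformer would use a different set of weights per layer.

```python
import math
from typing import List

Vec = List[float]
Mat = List[List[float]]


def softmax(xs: Vec) -> Vec:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(H: Mat) -> Mat:
    """Every position attends to every position (no learned Q/K/V projections here)."""
    d = len(H[0])
    out = []
    for q in H:
        weights = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in H])
        out.append([sum(w * h[j] for w, h in zip(weights, H)) for j in range(d)])
    return out


def transition(W: Mat, x: Vec) -> Vec:
    """Shared transition function: tanh(W x)."""
    return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in W]


def universal_transformer_encode(H: Mat, W: Mat, steps: int) -> Mat:
    """Apply the SAME attention + transition block `steps` times to all positions in parallel."""
    for _ in range(steps):
        A = self_attention(H)
        # Residual connection: each position's state is refined, not replaced.
        H = [[h + t for h, t in zip(h_row, transition(W, a_row))]
             for h_row, a_row in zip(H, A)]
    return H


H0 = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # three positions, d = 2
W = [[0.5, -0.2], [0.1, 0.3]]              # one weight matrix shared across all steps
H3 = universal_transformer_encode(H0, W, steps=3)
print(len(H3), len(H3[0]))  # 3 2
```

Because each step is the same function, the number of steps is just a loop bound, which is what makes a variable or per-position number of steps possible.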
At each step, information is communicated from each symbol (e.g. word in the sentence) to all other symbols using self-attention, just like in the original Transformer. However, now the number of times this transformation is applied to each symbol (i.e. the number of recurrent steps) can either be manually set ahead of time (e.g. to some fixed number or to the input length), or it can be decided dynamically by the Universal Transformer itself. To achieve the latter, we added an adaptive computation mechanism to each position which can allocate more processing steps to symbols that are more ambiguous or require more computations.

As an intuitive example of how this could be useful, consider the sentence “I arrived at the bank after crossing the river”. In this case, more context is required to infer the most likely meaning of the word “bank” compared to the less ambiguous meaning of “I” or “river”. When we encode this sentence using the standard Transformer, the same amount of computation is applied unconditionally to each word. However, the Universal Transformer’s adaptive mechanism allows the model to spend increased computation only on the more ambiguous words, e.g. to use more steps to integrate the additional contextual information needed to disambiguate the word “bank”, while spending potentially fewer steps on less ambiguous words.
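The per-position halting decision can be illustrated with a simplified version of adaptive computation time. In this sketch, each position is assumed to emit a halting probability at every step (in the real model these come from a learned unit; the numbers below are invented to mirror the "bank" example), and the position stops once the cumulative probability reaches 1 - eps. The full mechanism also uses these probabilities to weight the intermediate states, which is omitted here.

```python
from typing import Dict, List


def ponder_steps(halting_probs: List[float], eps: float = 0.01) -> int:
    """Number of recurrent steps a position takes: stop once the cumulative halting
    probability reaches 1 - eps, or when the step budget (len(halting_probs)) runs out."""
    total = 0.0
    for step, p in enumerate(halting_probs, start=1):
        total += p
        if total >= 1.0 - eps:
            return step
    return len(halting_probs)


# Made-up halting probabilities per step, as a learned halting unit might emit them.
halting: Dict[str, List[float]] = {
    "I": [0.995],
    "river": [0.6, 0.5],
    "bank": [0.1, 0.2, 0.2, 0.3, 0.4],
}
steps = {word: ponder_steps(p) for word, p in halting.items()}
print(steps)  # {'I': 1, 'river': 2, 'bank': 5}
```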

At first it might seem restrictive to allow the Universal Transformer to only apply a single learned function repeatedly to process its input, especially when compared to the standard Transformer which learns to apply a fixed sequence of distinct functions. But learning how to apply a single function repeatedly means the number of applications (processing steps) can now be variable, and this is the crucial difference. Beyond allowing the Universal Transformer to apply more computation to more ambiguous symbols, as explained above, it further allows the model to scale the number of function applications with the overall size of the input (more steps for longer sequences), or to decide dynamically how often to apply the function to any given part of the input based on other characteristics learned during training. This makes the Universal Transformer more powerful in a theoretical sense, as it can effectively learn to apply different transformations to different parts of the input. This is something that the standard Transformer cannot do, as it consists of fixed stacks of learned transformation blocks applied only once.

But while increased theoretical power is desirable, we also care about empirical performance. Our experiments confirm that Universal Transformers are indeed able to learn from examples how to copy and reverse strings and how to perform integer addition much better than a Transformer or an RNN (although not quite as well as Neural GPUs). Furthermore, on a diverse set of challenging language understanding tasks, the Universal Transformer generalizes significantly better and achieves a new state of the art on the bAbI linguistic reasoning task and the challenging LAMBADA language modeling task. But perhaps of most interest is that the Universal Transformer also improves translation quality by 0.9 BLEU¹ over a base Transformer with the same number of parameters, trained in the same way on the same training data. Putting things in perspective, this adds almost another 50% relative improvement (0.9 / 2.0 = 45%) on top of the previous 2.0 BLEU improvement that the original Transformer showed over earlier models when it was released last year.

The Universal Transformer thus closes the gap between practical sequence models competitive on large-scale language understanding tasks such as machine translation, and computationally universal models such as the Neural Turing Machine or the Neural GPU, which can be trained using gradient descent to perform arbitrary algorithmic tasks. We are enthusiastic about recent developments on parallel-in-time sequence models, and in addition to adding computational capacity and recurrence in processing depth, we hope that further improvements to the basic Universal Transformer presented here will help us build learning algorithms that are more powerful and more data efficient, and that generalize beyond the current state of the art.

If you’d like to try this for yourself, the code used to train and evaluate Universal Transformers is available in the open-source Tensor2Tensor repository.

This research was conducted by Mostafa Dehghani, Stephan Gouws, Oriol Vinyals, Jakob Uszkoreit, and Łukasz Kaiser. Additional thanks go to Ashish Vaswani, Douglas Eck, and David Dohan for their fruitful comments and inspiration.

¹ BLEU is a translation quality metric widely used in the machine translation community; here it is computed on the standard WMT newstest2014 English-to-German translation test set.

Source: Google AI Blog

MnasNet: Towards Automating the Design of Mobile Machine Learning Models

Convolutional neural networks (CNNs) have been widely used in image classification, face recognition, object detection and many other domains. Unfortunately, designing CNNs for mobile devices is challenging because mobile models need to be small and fast, yet still accurate. Although significant effort has been made to design and improve mobile models, such as MobileNet and MobileNetV2, manually creating efficient models remains challenging when there are so many possibilities to consider. Inspired by recent progress in AutoML neural architecture search, we wondered if the design of mobile CNN models could also benefit from an AutoML approach.

In “MnasNet: Platform-Aware Neural Architecture Search for Mobile”, we explore an automated neural architecture search approach for designing mobile models using reinforcement learning. To deal with mobile speed constraints, we explicitly incorporate speed information into the main reward function of the search algorithm, so that the search can identify models that achieve a good trade-off between accuracy and speed. In doing so, MnasNet is able to find models that run 1.5x faster than the state-of-the-art hand-crafted MobileNetV2 and 2.4x faster than NASNet, while reaching the same ImageNet top-1 accuracy.
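The paragraph above does not spell out the reward, so the following Python sketch assumes one common way to fold measured latency into a single scalar: multiply accuracy by a soft latency factor of the form (LAT / T)^w, where T is a target latency and the exponent w switches between two values depending on whether the target is met. The function name, parameters, and default exponents are illustrative assumptions, not tuned values from the paper.

```python
def latency_aware_reward(accuracy: float, latency_ms: float, target_ms: float,
                         alpha: float = -0.07, beta: float = -0.07) -> float:
    """Assumed multiplicative reward: accuracy * (latency / target) ** w.

    w = alpha when the model meets the latency target, beta otherwise.
    With negative exponents, slower models are penalized smoothly and faster
    models get a modest bonus. Default exponents are illustrative only.
    """
    if accuracy < 0 or latency_ms <= 0 or target_ms <= 0:
        raise ValueError("accuracy must be >= 0 and latencies must be positive")
    w = alpha if latency_ms <= target_ms else beta
    return accuracy * (latency_ms / target_ms) ** w
```

With negative exponents, a model slower than the target loses reward gradually instead of being discarded outright. Setting `alpha=0.0, beta=-1.0` gives no bonus for beating the target and a much steeper penalty for missing it.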

Unlike previous architecture search approaches, where model speed is considered via a proxy (e.g., FLOPS), our approach directly measures model speed by executing the model on a particular platform, such as the Pixel phones used in this research study. In this way, we can directly measure what is achievable in real-world practice, given that each type of mobile device has its own software and hardware idiosyncrasies and may require different architectures for the best trade-off between accuracy and speed.
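The key point is that latency is measured rather than estimated from a proxy. A minimal Python sketch of such a measurement is below. It assumes a hypothetical zero-argument callable `run_inference` that performs one forward pass, for example a wrapper around a TensorFlow Lite `Interpreter`'s `invoke()` after `allocate_tensors()` has been called. To reflect a real phone, the code would have to run on that phone; on a workstation it only measures the workstation.

```python
import statistics
import time
from typing import Callable


def measure_latency_ms(run_inference: Callable[[], object],
                       warmup: int = 5, runs: int = 50) -> float:
    """Return the median wall-clock latency of `run_inference` in milliseconds.

    `run_inference` is a hypothetical callable that executes one forward pass
    on the target device. Warm-up calls are discarded so one-time costs such as
    memory allocation do not skew the result.
    """
    if runs < 1:
        raise ValueError("runs must be at least 1")
    for _ in range(warmup):
        run_inference()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        run_inference()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)
```

Taking the median rather than the mean keeps occasional slow runs (for example, from background activity on the device) from dominating the measurement.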

The overall flow of our approach consists of three main components: an RNN-based controller for learning and sampling model architectures, a trainer that builds and trains models to obtain their accuracy, and an inference engine for measuring model speed on real mobile phones using TensorFlow Lite. We formulate a multi-objective optimization problem that aims to achieve both high accuracy and high speed, and use a reinforcement learning algorithm with a customized reward function to find Pareto-optimal solutions (i.e., models whose accuracy cannot be improved without worsening their speed).
Overall flow of our automated neural architecture search approach for mobile.
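To illustrate what "Pareto-optimal" means for these two objectives, the Python sketch below filters a set of already-measured candidates down to those that no other candidate beats on both accuracy and latency. This is only a post-hoc filter over hypothetical numbers; in the approach described above, the reinforcement learning controller is steered toward such models by the reward rather than enumerating and filtering them.

```python
from typing import List, NamedTuple


class Candidate(NamedTuple):
    name: str
    accuracy: float    # e.g., ImageNet top-1; higher is better
    latency_ms: float  # measured on device; lower is better


def dominates(a: Candidate, b: Candidate) -> bool:
    """True if `a` is no worse than `b` on both objectives and strictly better on one."""
    no_worse = a.accuracy >= b.accuracy and a.latency_ms <= b.latency_ms
    strictly_better = a.accuracy > b.accuracy or a.latency_ms < b.latency_ms
    return no_worse and strictly_better


def pareto_front(candidates: List[Candidate]) -> List[Candidate]:
    """Keep candidates that no other candidate dominates, sorted by latency."""
    front = [c for c in candidates
             if not any(dominates(other, c) for other in candidates)]
    return sorted(front, key=lambda c: c.latency_ms)
```

For example, a candidate with the same accuracy as another but higher latency is dropped, while a slower candidate survives only if it is also more accurate than every faster one.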
In order to strike the right balance between search flexibility and search space size, we propose a novel factorized hierarchical search space, which factorizes a convolutional neural network into a sequence of blocks and then uses a hierarchical search space to determine the layer architecture for each block. In this way, our approach allows different layers to use different operations and connections. Meanwhile, we force all layers in each block to share the same structure, reducing the search space size by orders of magnitude compared to a flat per-layer search space.
Our MnasNet network, sampled from the novel factorized hierarchical search space, illustrating the layer diversity throughout the network architecture.
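The factorization can be sketched in Python as follows. The option lists in `BLOCK_CHOICES` are hypothetical and simplified; the point is only that one configuration is sampled per block and then shared by every layer in that block. The size comparison follows from the same idea: with S choices per block, B blocks, and N layers per block, the factorized space has S^B members, while choosing independently for every layer gives S^(B*N).

```python
import math
import random
from typing import Dict, List, Tuple

# Hypothetical, simplified per-block options (not the exact MnasNet search space).
BLOCK_CHOICES: Dict[str, list] = {
    "op": ["conv", "depthwise_conv", "mbconv_expand3", "mbconv_expand6"],
    "kernel_size": [3, 5],
    "se_ratio": [0.0, 0.25],
    "skip": ["none", "identity"],
    "num_layers": [1, 2, 3, 4],
}


def sample_architecture(num_blocks: int, rng: random.Random) -> List[dict]:
    """Sample one configuration per block from BLOCK_CHOICES."""
    return [{key: rng.choice(values) for key, values in BLOCK_CHOICES.items()}
            for _ in range(num_blocks)]


def expand_to_layers(blocks: List[dict]) -> List[dict]:
    """Expand block configs into layers; all layers in a block share one structure."""
    layers = []
    for block_index, config in enumerate(blocks):
        shared = {k: v for k, v in config.items() if k != "num_layers"}
        layers.extend(dict(shared, block=block_index)
                      for _ in range(config["num_layers"]))
    return layers


def log10_space_sizes(choices_per_block: int, num_blocks: int,
                      layers_per_block: int) -> Tuple[float, float]:
    """log10 of the factorized (S^B) and flat per-layer (S^(B*N)) space sizes."""
    per_choice = math.log10(choices_per_block)
    return num_blocks * per_choice, num_blocks * layers_per_block * per_choice
```

With illustrative values S = 432, B = 5, N = 3, `log10_space_sizes(432, 5, 3)` gives roughly 13.2 and 39.5, i.e. about 10^13 architectures for the factorized space versus about 10^39 for the flat per-layer space.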
We tested the effectiveness of our approach on ImageNet classification and COCO object detection. Our models achieve new state-of-the-art accuracy under typical mobile speed constraints. In particular, the figure below shows the results on ImageNet.
ImageNet Accuracy and Inference Latency comparison. MnasNets are our models.
With the same accuracy, our MnasNet model runs 1.5x faster than the hand-crafted state-of-the-art MobileNetV2, and 2.4x faster than NASNet, which also used architecture search. After applying the squeeze-and-excitation optimization, our MnasNet+SE models achieve ResNet-50-level top-1 accuracy of 76.1%, with 19x fewer parameters and 10x fewer multiply-add operations. On COCO object detection, our model family achieves both higher accuracy and higher speed than MobileNet, and achieves accuracy comparable to the SSD300 model with 35x lower computational cost.

We are pleased to see that our automated approach can achieve state-of-the-art performance on multiple complex mobile vision tasks. In the future, we plan to incorporate more operations and optimizations into our search space, and to apply it to more mobile vision tasks such as semantic segmentation.

Special thanks to the co-authors of the paper, Bo Chen, Quoc V. Le, Ruoming Pang and Vijay Vasudevan. We’d also like to thank Andrew Howard, Barret Zoph, Dmitry Kalenichenko, Guiheng Zhou, Jeff Dean, Mark Sandler, Megan Kacholia, Sheng Li, Vishy Tirumalashetty, Wen Wang, Xiaoqiang Zheng and Yifeng Lu for their help, and the TensorFlow Lite and Google Brain teams.

Source: Google AI Blog