Posted by Stephan Hoyer, Software Engineer, Google Research
The world’s fastest supercomputers were designed for modeling physical phenomena, yet they still are not fast enough to robustly predict the impacts of climate change, to design controls for airplanes based on airflow or to accurately simulate a fusion reactor. All of these phenomena are modeled by partial differential equations (PDEs), the class of equations that describe everything smooth and continuous in the physical world, and the most common class of simulation problems in science and engineering. To solve these equations, we need faster simulations, but in recent years, Moore’s law has been slowing. At the same time, we’ve seen huge breakthroughs in machine learning (ML) along with faster hardware optimized for it. What does this new paradigm offer for scientific computing?
For most real-world problems, closed-form solutions to PDEs don’t exist. Instead, one must find discrete equations (“discretizations”) that a computer can solve to approximate the continuous PDE. Typical approaches to solve PDEs represent equations on a grid, e.g., using finite differences. To achieve convergence, the mesh spacing of the grid needs to be smaller than the smallest feature size of the solutions. This often isn’t feasible because of an unfortunate scaling law: achieving 10x higher resolution requires 10,000x more compute, because the grid must be scaled in four dimensions—three spatial dimensions and time. Instead, in our paper we show that ML can be used to learn better representations for PDEs on coarser grids.
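The scaling law above is simple arithmetic, and a small sketch makes it concrete. The function name and its parameters are illustrative only, and the sketch assumes the time step is refined by the same factor as the grid spacing, which is typical for explicit schemes but not universal.

```python
def compute_cost_factor(resolution_factor, spatial_dims=3, time_dims=1):
    """Growth in the number of space-time grid points (and so, roughly,
    in compute) when every dimension is refined by the same factor."""
    return resolution_factor ** (spatial_dims + time_dims)


# 10x finer resolution in three spatial dimensions plus time.
factor_3d = compute_cost_factor(10)

# The same refinement for a one-dimensional problem is far cheaper.
factor_1d = compute_cost_factor(10, spatial_dims=1)

print(factor_3d, factor_1d)
```

Read the other way, a method that stays accurate on a grid that is only a few times coarser could reduce compute by a large factor.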
The challenge is to retain the accuracy of high-resolution simulations while still using the coarsest grid possible. In our work we’re able to improve upon existing schemes by replacing heuristics based on deep human insight (e.g., “solutions to a PDE should always be smooth away from discontinuities”) with optimized rules based on machine learning. The rules our ML models recover are complex, and we don’t entirely understand them, but they incorporate sophisticated physical principles like the idea of “upwinding”: to accurately model what’s coming towards you in a fluid flow, you should look upstream in the direction the wind is coming from. An example of our results on a simple model of fluid dynamics is shown below:
Simulations of Burgers’ equation, a model for shock waves in fluids, solved with either a standard finite volume method (left) or our neural network based method (right). The orange squares represent simulations with each method on low resolution grids. These points are fed back into the model at each time step, which then predicts how they should change. Blue lines show the exact simulations used for training. The neural network solution is much better, even on a 4x coarser grid, as indicated by the orange squares smoothly tracing the blue line.
Our research also illustrates a broader lesson about how to effectively combine machine learning and physics. Rather than attempting to learn physics from scratch, we combined neural networks with components from traditional simulation methods, including the known form of the equations we’re solving and finite volume methods. As a result, laws such as conservation of momentum are exactly satisfied by construction, and our machine learning models can focus on what they do best: learning optimal rules for interpolation in complex, high-dimensional spaces.
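To see why a finite volume formulation conserves quantities by construction, consider the minimal sketch below for Burgers’ equation on a periodic grid. It uses the classical Godunov upwind flux, not the learned interpolation from the paper, and the function names, grid size and initial condition are illustrative assumptions. Each cell changes only through fluxes across its two faces, and each face flux is added to one cell and subtracted from its neighbor, so the total is unchanged whatever rule computes the flux.

```python
import math


def godunov_flux(u_left, u_right):
    """Upwind (Godunov) flux for Burgers' equation, where f(u) = u**2 / 2.
    Information is taken from the side the flow is coming from."""
    return max(0.5 * max(u_left, 0.0) ** 2, 0.5 * min(u_right, 0.0) ** 2)


def finite_volume_step(u, dx, dt):
    """Advance cell averages by one time step on a periodic grid."""
    n = len(u)
    # flux[i] is the flux through the face between cell i and cell i + 1.
    flux = [godunov_flux(u[i], u[(i + 1) % n]) for i in range(n)]
    # flux[i - 1] wraps around to the last face when i == 0 (periodic domain).
    return [u[i] - dt / dx * (flux[i] - flux[i - 1]) for i in range(n)]


n = 64
dx = 2 * math.pi / n
u = [1.0 + 0.5 * math.sin(i * dx) for i in range(n)]
dt = 0.5 * dx / max(abs(v) for v in u)  # time step chosen to satisfy a CFL condition

total_before = sum(u) * dx
for _ in range(100):
    u = finite_volume_step(u, dx, dt)
total_after = sum(u) * dx

print(abs(total_after - total_before))  # zero up to floating-point rounding
```

A learned method can replace the rule that reconstructs face values while keeping this flux-difference update, which is how the conservation guarantee survives.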
Next Steps We are focused on scaling up the techniques outlined in our paper to solve larger scale simulation problems with real-world impacts, such as weather and climate prediction. We’re excited about the broad potential of blending machine learning into the complex algorithms of scientific computing.
Acknowledgments Thanks to co-authors Yohai Bar-Sinai, Jason Hickey and Michael Brenner; and Google collaborators Peyman Milanfar, Pascal Getreuer, Ignacio Garcia Dorado, Dmitrii Kochkov, Jiawei Zhuang and Anton Geraschenko.
Posted by Narayan Hegde, Software Engineer, Google Health and Carrie J. Cai, Research Scientist, Google Research
Advances in machine learning (ML) have shown great promise for assisting in the work of healthcare professionals, such as aiding the detection of diabetic eye disease and metastatic breast cancer. Though high-performing algorithms are necessary to gain the trust and adoption of clinicians, they are not always sufficient—what information is presented to doctors and how doctors interact with that information can be crucial determinants in the utility that ML technology ultimately has for users.
The medical specialty of anatomic pathology, which is the gold standard for the diagnosis of cancer and many other diseases through microscopic analysis of tissue samples, can greatly benefit from applications of ML. Though diagnosis through pathology is traditionally done on physical microscopes, there has been a growing adoption of “digital pathology,” where high-resolution images of pathology samples can be examined on a computer. With this movement comes the potential to much more easily look up information, as is needed when pathologists tackle the diagnosis of difficult cases or rare diseases, when “general” pathologists approach specialist cases, and when trainee pathologists are learning. In these situations, a common question arises, “What is this feature that I’m seeing?” The traditional solution is for doctors to ask colleagues, or to laboriously browse reference textbooks or online resources, hoping to find an image with similar visual characteristics. The general computer vision solution to problems like this is termed content-based image retrieval (CBIR), one example of which is the “reverse image search” feature in Google Images, in which users can search for similar images by using another image as input.
SMILY Design The first step in developing SMILY (Similar Medical Images Like Yours) was to apply a deep learning model, trained using 5 billion natural, non-pathology images (e.g., dogs, trees, man-made objects, etc.), to compress images into a “summary” numerical vector, called an embedding. The network learned during the training process to distinguish similar images from dissimilar ones by computing and comparing their embeddings. This model is then used to create a database of image patches and their associated embeddings using a corpus of de-identified slides from The Cancer Genome Atlas. When a query image patch is selected in the SMILY tool, the query patch’s embedding is similarly computed and compared with the database to retrieve the image patches with the most similar embeddings.
Schematic of the steps in building the SMILY database and the process by which input image patches are used to perform the similar image search.
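The database-and-query flow described above can be sketched in a few lines. This is only an illustration: `toy_embed` is a hypothetical stand-in for the deep network, the brute-force Euclidean search stands in for whatever index a production system would use, and all names and data are made up. The magnification filter reflects the post’s statement that SMILY searches images at the same magnification as the input.

```python
import math


def toy_embed(pixels):
    """Hypothetical stand-in for the embedding network: summarizes a patch
    by its mean intensity and its intensity range."""
    return [sum(pixels) / len(pixels), max(pixels) - min(pixels)]


def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def build_database(patches, embed):
    """Pre-compute one embedding per patch: (patch_id, magnification, embedding)."""
    return [(pid, mag, embed(pixels)) for pid, mag, pixels in patches]


def search(database, query_embedding, magnification, k=2):
    """Return the ids of the k nearest patches at the query's magnification."""
    scored = [
        (euclidean(query_embedding, emb), pid)
        for pid, mag, emb in database
        if mag == magnification
    ]
    return [pid for _, pid in sorted(scored)[:k]]


patches = [
    ("a", 10, [0.1, 0.2, 0.1, 0.2]),
    ("b", 10, [0.9, 0.8, 0.9, 0.8]),
    ("c", 10, [0.2, 0.1, 0.2, 0.3]),
    ("d", 40, [0.1, 0.2, 0.1, 0.2]),  # looks like "a" but at another magnification
]
database = build_database(patches, toy_embed)
results = search(database, toy_embed([0.1, 0.2, 0.2, 0.1]), magnification=10)
print(results)
```

Because the embeddings are computed once when the database is built, each query costs one forward pass plus a nearest-neighbor lookup.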
The tool allows a user to select a region of interest, and obtain visually-similar matches. We tested SMILY’s ability to retrieve images along a pre-specified axis of similarity (e.g. histologic feature or tumor grade), using images of tissue from the breast, colon, and prostate (3 of the most common cancer sites). We found that SMILY demonstrated promising results despite not being trained specifically on pathology images or using any labeled examples of histologic features or tumor grades.
Example of selecting a small region in a slide and using SMILY to retrieve similar images. SMILY efficiently searches a database of billions of cropped images in a few seconds. Because pathology images can be viewed at different magnifications (zoom levels), SMILY automatically searches images at the same magnification as the input image.
Second example of using SMILY, this time searching for a lobular carcinoma, a specific subtype of breast cancer.
Refinement tools for SMILY However, a problem emerged when we observed how pathologists interacted with SMILY. Specifically, users were trying to answer the nebulous question of “What looks similar to this image?” so that they could learn from past cases containing similar images. Yet, there was no way for the tool to understand the intent of the search: Was the user trying to find images that have a similar histologic feature, glandular morphology, overall architecture, or something else? In other words, users needed the ability to guide and refine the search results on a case-by-case basis in order to actually find what they were looking for. Furthermore, we observed that this need for iterative search refinement was rooted in how doctors often perform “iterative diagnosis”—by generating hypotheses, collecting data to test these hypotheses, exploring alternative hypotheses, and revisiting or retesting previous hypotheses in an iterative fashion. It became clear that, for SMILY to meet real user needs, it would need to support a different approach to user interaction.
Through careful human-centered research described in our second paper, we designed and augmented SMILY with a suite of interactive refinement tools that enable end-users to express what similarity means on-the-fly: 1) refine-by-region allows pathologists to crop a region of interest within the image, limiting the search to just that region; 2) refine-by-example gives users the ability to pick a subset of the search results and retrieve more results like those; and 3) refine-by-concept sliders can be used to specify that more or less of a clinical concept be present in the search results (e.g., fused glands). Rather than requiring that these concepts be built into the machine learning model, we instead developed a method that enables end-users to create new concepts post-hoc, customizing the search algorithm towards concepts they find important for each specific use case. This enables new explorations via post-hoc tools after a machine learning model has already been trained, without needing to re-train the original model for each concept or application of interest.
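The post does not spell out how these refinement tools operate on embeddings, so the following is only one plausible sketch under stated assumptions: refine-by-example is modeled as searching from the centroid of the selected results, and refine-by-concept as shifting the query along a direction estimated from user-provided examples with and without the concept. All function names are hypothetical.

```python
def mean_vector(vectors):
    """Component-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def refine_by_example(selected_embeddings):
    """Assumed behavior: the new query is the centroid of the picked results."""
    return mean_vector(selected_embeddings)


def concept_direction(with_concept, without_concept):
    """Assumed behavior: a concept is the direction from the mean of the
    examples lacking it to the mean of the examples showing it."""
    positive = mean_vector(with_concept)
    negative = mean_vector(without_concept)
    return [p - q for p, q in zip(positive, negative)]


def refine_by_concept(query, direction, slider):
    """Move the query along the concept direction; a negative slider value
    asks for less of the concept."""
    return [q + slider * d for q, d in zip(query, direction)]


direction = concept_direction(
    with_concept=[[1.0, 0.0], [3.0, 0.0]],
    without_concept=[[0.0, 0.0], [0.0, 2.0]],
)
more_of_concept = refine_by_concept([1.0, 1.0], direction, slider=0.5)
centroid_query = refine_by_example([[1.0, 2.0], [3.0, 4.0]])
print(direction, more_of_concept, centroid_query)
```

In this sketch a new concept needs only a handful of example embeddings, so the embedding model itself is never retrained, which matches the post-hoc property described above.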
Through our user study with pathologists, we found that the tool-based SMILY not only increased the clinical usefulness of search results, but also significantly increased users’ trust and likelihood of adoption, compared to a conventional version of SMILY without these tools. Interestingly, these refinement tools appeared to have supported pathologists’ decision-making process in ways beyond simply performing better on similarity searches. For example, pathologists used the observed changes to their results from iterative searches as a means of progressively tracking the likelihood of a hypothesis. When search results were surprising, many re-purposed the tools to test and understand the underlying algorithm, for example, by cropping out regions they thought were interfering with the search or by adjusting the concept sliders to increase the presence of concepts they suspected were being ignored. Beyond being passive recipients of ML results, doctors were empowered with the agency to actively test hypotheses and apply their expert domain knowledge, while simultaneously leveraging the benefits of automation.
With these interactive tools enabling users to tailor each search experience to their desired intent, we are excited for SMILY’s potential to assist with searching large databases of digitized pathology images. One potential application of this technology is to index textbooks of pathology images with descriptive captions, and enable medical students or pathologists in training to search these textbooks using visual search, speeding up the educational process. Another application is for cancer researchers interested in studying the correlation of tumor morphologies with patient outcomes, to accelerate the search for similar cases. Finally, pathologists may be able to leverage tools like SMILY to locate all occurrences of a feature (e.g. signs of active cell division, or mitosis) in the same patient’s tissue sample to better understand the severity of the disease to inform cancer therapy decisions. Importantly, our findings add to the body of evidence that sophisticated machine learning algorithms need to be paired with human-centered design and interactive tooling in order to be most useful.
Acknowledgements This work would not have been possible without Jason D. Hipp, Yun Liu, Emily Reif, Daniel Smilkov, Michael Terry, Craig H. Mermel, Martin C. Stumpe and members of Google Health and PAIR. Preprints of the two papers are available here and here.
Posted by Fadi Biadsy, Research Scientist and Ron Weiss, Software Engineer, Google Research
Most people take for granted that when they speak, they will be heard and understood. But for the millions who live with speech impairments caused by physical or neurological conditions, trying to communicate with others can be difficult and lead to frustration. While there have been a great number of recent advances in automatic speech recognition (ASR; a.k.a. speech-to-text) technologies, these interfaces can be inaccessible for those with speech impairments. Further, applications that rely on speech recognition as input for text-to-speech synthesis (TTS) can exhibit word substitution, deletion, and insertion errors. Critically, in today’s technological environment, limited access to speech interfaces, such as digital assistants that depend on directly understanding one's speech, means being excluded from state-of-the-art tools and experiences, widening the gap between what those with and without speech impairments can access.
Project Euphonia has demonstrated that speech recognition models can be significantly improved to better transcribe a variety of atypical and dysarthric speech. Today, we are presenting Parrotron, an ongoing research project that continues and extends our effort to build speech technologies to help those with impaired or atypical speech to be understood by both people and devices. Parrotron consists of a single end-to-end deep neural network trained to convert speech from a speaker with atypical speech patterns directly into fluent synthesized speech, without an intermediate step of generating text, skipping speech recognition altogether. Parrotron’s approach is speech-centric, looking at the problem only from the point of view of speech signals (e.g., without visual cues such as lip movements). Through this work, we show that Parrotron can help people with a variety of atypical speech patterns, including those with ALS, deafness, and muscular dystrophy, to be better understood in both human-to-human interactions and by ASR engines.
The Parrotron Speech Conversion Model Parrotron is an attention-based sequence-to-sequence model trained in two phases using parallel corpora of input/output speech pairs. First, we build a general speech-to-speech conversion model for standard fluent speech, followed by a personalization phase that adjusts the model parameters to the atypical speech patterns from the target speaker. The primary challenge in such a configuration lies in the collection of the parallel training data needed for supervised training, which consists of utterances spoken by many speakers and mapped to the same output speech content spoken by a single speaker. Since it is impractical to have a single speaker record the many hours of training data needed to build a high quality model, Parrotron uses parallel data automatically derived with a TTS system. This allows us to make use of a pre-existing anonymized, transcribed speech recognition corpus to obtain training targets.
The first training phase uses a corpus of ~30,000 hours that consists of millions of anonymized utterance pairs. Each pair includes a natural utterance paired with an automatically synthesized speech utterance that results from running our state-of-the-art Parallel WaveNet TTS system on the transcript of the first. This dataset includes utterances from thousands of speakers spanning hundreds of dialects/accents and acoustic conditions, allowing us to model a large variety of voices, linguistic and non-linguistic contents, accents, and noise conditions with “typical” speech all in the same language. The resulting conversion model projects away all non-linguistic information, including speaker characteristics, and retains only what is being said, not who, where, or how it is said. This base model is used to seed the second personalization phase of training.
The second training phase utilizes a corpus of utterance pairs generated in the same manner as the first dataset. In this case, however, the corpus is used to adapt the network to the acoustic/phonetic, phonotactic and language patterns specific to the input speaker, which might include, for example, learning how the target speaker alters, substitutes, and reduces or removes certain vowels or consonants. To model ALS speech characteristics in general, we use utterances taken from an ALS speech corpus derived from Project Euphonia. If instead we want to personalize the model for a particular speaker, then the utterances are contributed by that person. The larger this corpus is, the better the model is likely to be at correctly converting to fluent speech. Using this second smaller and personalized parallel corpus, we run the neural-training algorithm, updating the parameters of the pre-trained base model to generate the final personalized model.
We found that training the model with a multitask objective to predict the target phonemes while simultaneously generating spectrograms of the target speech led to significant quality improvements. Such a multitask trained encoder can be thought of as learning a latent representation of the input that maintains information about the underlying linguistic content.
Overview of the Parrotron model architecture. An input speech spectrogram is passed through encoder and decoder neural networks to generate an output spectrogram in a new voice.
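The multitask objective mentioned above can be pictured as a weighted sum of two losses. The sketch below is a simplified assumption rather than Parrotron’s actual training code: it uses mean squared error for the spectrogram term, cross-entropy for the phoneme term, and a hypothetical `phoneme_weight` parameter; the post does not state which losses or weights were used.

```python
import math


def spectrogram_loss(predicted, target):
    """Mean squared error over all time-frequency bins (assumed loss)."""
    total, count = 0.0, 0
    for predicted_frame, target_frame in zip(predicted, target):
        for p, t in zip(predicted_frame, target_frame):
            total += (p - t) ** 2
            count += 1
    return total / count


def phoneme_loss(predicted_probs, target_ids):
    """Mean cross-entropy of the correct phoneme at each step (assumed loss)."""
    log_likelihood = sum(
        math.log(probs[i]) for probs, i in zip(predicted_probs, target_ids)
    )
    return -log_likelihood / len(target_ids)


def multitask_loss(pred_spec, target_spec, phoneme_probs, phoneme_ids,
                   phoneme_weight=1.0):
    """Spectrogram loss plus a weighted auxiliary phoneme loss."""
    return (spectrogram_loss(pred_spec, target_spec)
            + phoneme_weight * phoneme_loss(phoneme_probs, phoneme_ids))


loss = multitask_loss(
    pred_spec=[[0.0, 1.0]],
    target_spec=[[0.0, 0.0]],
    phoneme_probs=[[0.5, 0.5]],
    phoneme_ids=[0],
)
print(loss)
```

The auxiliary term penalizes an encoder whose representation loses track of which phonemes were spoken, which is one way to read the “latent representation” remark above.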
Case Studies To demonstrate a proof of concept, we worked with our fellow Google research scientist and mathematician Dimitri Kanevsky, who was born in Russia to Russian speaking, normal-hearing parents but has been profoundly deaf from a very young age. He learned to speak English as a teenager, by using Russian phonetic representations of English words, learning to pronounce English using transliteration into Russian (e.g., The quick brown fox jumps over the lazy dog => ЗИ КВИК БРАУН ФОКС ЖАМПС ОУВЕР ЛАЙЗИ ДОГ). As a result, Dimitri’s speech is substantially distinct from that of native English speakers, and can be challenging to comprehend for systems or listeners who are not accustomed to it.
Dimitri recorded a corpus of 15 hours of speech, which was used to adapt the base model to the nuances specific to his speech. The resulting Parrotron system helped him be better understood by both people and Google’s ASR system alike. Running Google’s ASR engine on the output of Parrotron significantly reduced the word error rate from 89% to 32%, on a held out test set from Dimitri. Below is an example of Parrotron’s successful conversion of input speech from Dimitri:
We also worked with Aubrie Lee, a Googler and advocate for disability inclusion, who has muscular dystrophy, a condition that causes progressive muscle weakness, and sometimes impacts speech production. Aubrie contributed 1.5 hours of speech, which has been instrumental in showing promising outcomes of the applicability of this speech-to-speech technology. Below is an example of Parrotron’s successful conversion of input speech from Aubrie:
We also tested Parrotron’s performance on speech from speakers with ALS by adapting the pretrained model on multiple speakers who share similar speech characteristics grouped together, rather than on a single speaker. We conducted a preliminary listening study and observed an increase in intelligibility when comparing natural ALS speech to the corresponding speech obtained from running the Parrotron model, for the majority of our test speakers.
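The word error rates quoted in these case studies follow the standard definition: the minimum number of word substitutions, deletions and insertions needed to turn the recognizer output into the reference transcript, divided by the number of reference words. A minimal sketch of that computation is shown below; the whitespace tokenization and the example sentences are assumptions, and the scoring details of Google’s evaluation are not stated in the post.

```python
def word_error_rate(reference, hypothesis):
    """Word-level edit distance divided by the reference length.
    The reference must contain at least one word."""
    ref = reference.split()
    hyp = hypothesis.split()
    # previous[j] holds the edit distance between the reference words seen
    # so far and the first j hypothesis words.
    previous = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        current = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = previous[j - 1] + (0 if ref_word == hyp_word else 1)
            deletion = previous[j] + 1       # reference word missing from hypothesis
            insertion = current[j - 1] + 1   # extra word in hypothesis
            current.append(min(substitution, deletion, insertion))
        previous = current
    return previous[-1] / len(ref)


one_substitution = word_error_rate("the quick brown fox jumps",
                                   "the quick brown dog jumps")
deletion_and_substitution = word_error_rate("turn on the lights",
                                            "turn the light")
print(one_substitution, deletion_and_substitution)
```

Because insertions count as errors, the rate can exceed 100% when the hypothesis is much longer than the reference.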
Cascaded Approach Project Euphonia has built a personalized speech-to-text model that has reduced the word error rate for a deaf speaker from 89% to 25%, and ongoing research is also likely to improve upon these results. One could use such a speech-to-text model to achieve a similar goal as Parrotron by simply passing its output into a TTS system to synthesize speech from the result. In such a cascaded approach, however, the recognizer may choose an incorrect word (roughly 1 out of 4 times, in this case). That is, it may yield words or sentences with unintended meaning and, as a result, the synthesized audio of these words would be far from the speaker’s intention. Given the end-to-end speech-to-speech training objective function of Parrotron, even when errors are made, the generated output speech is likely to sound acoustically similar to the input speech, and thus the speaker’s original intention is less likely to be significantly altered and it is often still possible to understand what is intended:
Furthermore, since Parrotron is not strongly biased to producing words from a predefined vocabulary set, input to the model may contain completely new invented words, foreign words/names, and even nonsense words. We observe that feeding Arabic and Spanish utterances into the US-English Parrotron model often results in output which echoes the original speech content with an American accent, in the target voice. Such behavior is qualitatively different from what one would obtain by simply running an ASR followed by a TTS. Finally, by going from a combination of independently tuned neural networks to a single one, we also believe there are improvements and simplifications that could be substantial.
Conclusion Parrotron makes it easier for users with atypical speech to talk to and be understood by other people and by speech interfaces, with its end-to-end speech conversion approach more likely to reproduce the user’s intended speech. More exciting applications of Parrotron are discussed in our paper and additional audio samples can be found on our GitHub repository. If you would like to participate in this ongoing research, please fill out this short form and volunteer to record a set of phrases. We look forward to working with you!
Acknowledgements This project was joint work between the Speech and Google Brain teams. Contributors include Fadi Biadsy, Ron Weiss, Pedro Moreno, Dimitri Kanevsky, Ye Jia, Suzan Schwartz, Landis Baker, Zelin Wu, Johan Schalkwyk, Yonghui Wu, Zhifeng Chen, Patrick Nguyen, Aubrie Lee, Andrew Rosenberg, Bhuvana Ramabhadran, Jason Pelecanos, Julie Cattiau, Michael Brenner, Dotan Emanuel and Joel Shor. Our data collection efforts have been vastly accelerated by our collaborations with ALS-TDI.
Posted by Yinfei Yang and Amin Ahmad, Software Engineers, Google Research
Since it was introduced last year, “Universal Sentence Encoder (USE) for English” has become one of the most downloaded pre-trained text modules in TensorFlow Hub, providing versatile sentence embedding models that convert sentences into vector representations. These vectors capture rich semantic information that can be used to train classifiers for a broad range of downstream tasks. For example, a strong sentiment classifier can be trained from as few as one hundred labeled examples, and still be used to measure semantic similarity and for meaning-based clustering.
Today, we are pleased to announce the release of three new USE multilingual modules with additional features and potential applications. The first two modules provide multilingual models for retrieving semantically similar text, one optimized for retrieval performance and the other for speed and less memory usage. The third model is specialized for question-answer retrieval in sixteen languages (USE-QA), and represents an entirely new application of USE. All three multilingual modules are trained using a multi-task dual-encoder framework, similar to the original USE model for English, while using techniques we developed for improving the dual-encoder with additive margin softmax approach. They are designed not only to maintain good transfer learning performance, but to perform well on semantic retrieval tasks.
Multi-task training structure of the Universal Sentence Encoder. A variety of tasks and task structures are joined by shared encoder layers/parameters (pink boxes).
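The post names additive margin softmax without giving its form, so the sketch below is an assumption about how such a loss is commonly written for a dual encoder: within a batch, each source sentence should score its true pair higher than every other candidate, and a margin is subtracted from the true pair’s score to make the task harder. The function name, the margin value and the score matrices are illustrative.

```python
import math


def additive_margin_softmax_loss(scores, margin=0.3):
    """In-batch softmax loss with a margin on the true pairs.
    scores[i][j] is the similarity between source i and candidate j,
    and scores[i][i] is the score of the true pair."""
    total = 0.0
    for i, row in enumerate(scores):
        # Subtract the margin from the true pair only.
        logits = [s - margin if j == i else s for j, s in enumerate(row)]
        log_normalizer = math.log(sum(math.exp(value) for value in logits))
        total += log_normalizer - logits[i]
    return total / len(scores)


well_separated = [[1.0, 0.0], [0.0, 1.0]]
without_margin = additive_margin_softmax_loss(well_separated, margin=0.0)
with_margin = additive_margin_softmax_loss(well_separated, margin=0.3)
print(without_margin, with_margin)
```

With the margin, true pairs must beat the alternatives by a gap before the loss becomes small, which pushes matching sentences closer together relative to non-matching ones.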
Semantic Retrieval Applications The three new modules are all built on semantic retrieval architectures, which typically split the encoding of questions and answers into separate neural networks so that it is possible to search among billions of potential answers within milliseconds. The key to using dual encoders for efficient semantic retrieval is to pre-encode all candidate answers to expected input queries and store them in a vector database that is optimized for solving the nearest neighbor problem, which allows a large number of candidates to be searched quickly with good precision and recall. For all three modules, the input query is then encoded into a vector on which we can perform an approximate nearest neighbor search. Together, this enables good results to be found quickly without needing to do a direct query/candidate comparison for every candidate. The prototypical pipeline is illustrated below:
A prototypical semantic retrieval pipeline, used for textual similarity.
Semantic Similarity Modules For semantic similarity tasks, the query and candidates are encoded using the same neural network. Two common semantic retrieval tasks made possible by the new modules include Multilingual Semantic Textual Similarity Retrieval and Multilingual Translation Pair Retrieval.
Multilingual Semantic Textual Similarity Retrieval Most existing approaches for finding semantically similar text require being given a pair of texts to compare. However, using the Universal Sentence Encoder, semantically similar text can be extracted directly from a very large database. For example, in an application like FAQ search, a system can first index all possible questions with associated answers. Then, given a user’s question, the system can search for known questions that are semantically similar enough to provide an answer. A similar approach was used to find comparable sentences from 50 million sentences in Wikipedia. With the new multilingual USE models, this can be done in any of the supported non-English languages.
Multilingual Translation Pair Retrieval The newly released modules can also be used to mine translation pairs to train neural machine translation systems. Given a source sentence in one language (“How do I get to the restroom?”), they can find the potential translation target in any other supported language (“¿Cómo llego al baño?”).
Both new semantic similarity modules are cross-lingual. Given an input in Chinese, for example, the modules can find the best candidates, regardless of which language they are expressed in. This versatility can be particularly useful for languages that are underrepresented on the internet. For example, an early version of these modules has been used by Chidambaram et al. (2018) to provide classifications in circumstances where the training data is only available in a single language, e.g. English, but the end system must function in a range of other languages.
USE for Question-Answer Retrieval The USE-QA module extends the USE architecture to question-answer retrieval applications, which generally take an input query and find relevant answers from a large set of documents that may be indexed at the document, paragraph, or even sentence level. The input query is encoded with the question encoding network, while the candidates are encoded with the answer encoding network.
Visualizing the action of a neural answer retrieval system. The blue point at the north pole represents the question vector. The other points represent the embeddings of various answers. The correct answer, highlighted here in red, is “closest” to the question, in that it minimizes the angular distance. The points in this diagram are produced by an actual USE-QA model; however, they have been projected downwards from ℝ^500 to ℝ^3 to assist the reader’s visualization.
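The “closest” criterion in the figure can be written out directly: the angular distance between two vectors is the arccosine of their cosine similarity, and retrieval picks the answer that minimizes it. The sketch below uses made-up three-dimensional vectors in place of real USE-QA embeddings, and the function names are illustrative.

```python
import math


def angular_distance(u, v):
    """Angle in radians between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    # Clamp to [-1, 1] so that rounding error cannot push acos out of its domain.
    cosine = max(-1.0, min(1.0, dot / (norm_u * norm_v)))
    return math.acos(cosine)


def best_answer(question_vector, answer_vectors):
    """Return the key of the answer with the smallest angle to the question."""
    return min(
        answer_vectors,
        key=lambda name: angular_distance(question_vector, answer_vectors[name]),
    )


question = [0.0, 0.0, 1.0]  # the "north pole" of the sphere
answers = {
    "near": [0.1, 0.0, 0.9],
    "equator": [1.0, 0.0, 0.0],
    "far": [0.0, -1.0, -1.0],
}
closest = best_answer(question, answers)
print(closest, angular_distance(question, answers["equator"]))
```

Angular distance depends only on direction, so vectors of different lengths can be compared without normalizing them first.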
Question-answer retrieval systems also rely on the ability to understand semantics. For example, consider a possible query to one such system, Google Talk to Books, which was launched in early 2018 and backed by a sentence-level index of over 100,000 books. A query, “What fragrance brings back memories?”, yields the result, “And for me, the smell of jasmine along with the pan bagnat, it brings back my entire carefree childhood.” Without specifying any explicit rules or substitutions, the vector encoding captures the semantic similarity between the terms fragrance and smell. The advantage provided by the USE-QA module is that it can extend question-answer retrieval tasks such as this to multilingual applications.
For Researchers and Developers We're pleased to share the latest additions to the Universal Sentence Encoder family with the research community, and are excited to see what other applications will be found. These modules can be used as-is, or fine tuned using domain-specific data. Lastly, we will also host the semantic similarity for natural language page on Cloud AI Workshop to further encourage research in this area.
Acknowledgements Mandy Guo, Daniel Cer, Noah Constant, Jax Law, Muthuraman Chidambaram for core modeling, Gustavo Hernandez Abrego, Chen Chen, Mario Guajardo-Cespedes for infrastructure and colabs, Steve Yuan, Chris Tar, Yunhsuan Sung, Brian Strope, Ray Kurzweil for discussion of the model architecture.
Posted by Qizhe Xie, Student Researcher and Thang Luong, Senior Research Scientist, Google Research, Brain Team
Success in deep learning has largely been enabled by key factors such as algorithmic advancements, parallel processing hardware (GPU / TPU), and the availability of large-scale labeled datasets, like ImageNet. However, when labeled data is scarce, it can be difficult to train neural networks to perform well. In this case, one can apply data augmentation methods, e.g., paraphrasing a sentence or rotating an image, to effectively increase the amount of labeled training data. Recently, there has been significant progress in the design of data augmentation approaches for a variety of areas such as natural language processing (NLP), vision, and speech. Unfortunately, data augmentation is often limited to supervised learning only, in which labels are required to transfer from original examples to augmented ones.
Example augmentation operations for text-based (top) or image-based (bottom) training data.
Unsupervised Data Augmentation Explained Unsupervised Data Augmentation (UDA) makes use of both labeled data and unlabeled data. To use labeled data, it computes the loss function using standard methods for supervised learning to train the model, as shown in the left part of the graph below. For unlabeled data, consistency training is applied to enforce the predictions to be similar for an unlabeled example and the augmented unlabeled example, as shown in the right part of the graph. Here, the same model is applied to both the unlabeled example and its augmented counterpart to produce two model predictions, from which a consistency loss is computed (i.e., the distance between the two prediction distributions). UDA then computes the final loss by jointly optimizing both the supervised loss from the labeled data and the unsupervised consistency loss from the unlabeled data.
An overview of Unsupervised Data Augmentation (UDA). Left: Standard supervised loss is computed when labeled data is available. Right: With unlabeled data, a consistency loss is computed between an example and its augmented version.
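The combined objective described above can be sketched numerically. This is a simplified illustration, not the released UDA code: model predictions are passed in as plain probability lists, the distance between distributions is assumed to be the KL divergence, and `consistency_weight` is a hypothetical parameter for balancing the two terms.

```python
import math


def cross_entropy(probs, label):
    """Supervised loss for one labeled example."""
    return -math.log(probs[label])


def kl_divergence(p, q):
    """KL(p || q) between two discrete distributions (assumed distance)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def uda_loss(labeled_preds, labels, unlabeled_preds, augmented_preds,
             consistency_weight=1.0):
    """Supervised loss on labeled data plus consistency loss between the
    predictions for each unlabeled example and its augmented version."""
    supervised = sum(
        cross_entropy(p, y) for p, y in zip(labeled_preds, labels)
    ) / len(labels)
    consistency = sum(
        kl_divergence(p, q) for p, q in zip(unlabeled_preds, augmented_preds)
    ) / len(unlabeled_preds)
    return supervised + consistency_weight * consistency


# Predictions agree on the augmented example: only the supervised term remains.
consistent = uda_loss([[0.5, 0.5]], [0], [[0.9, 0.1]], [[0.9, 0.1]])

# Predictions change under augmentation: the consistency term adds a penalty.
inconsistent = uda_loss([[0.5, 0.5]], [0], [[0.9, 0.1]], [[0.5, 0.5]])
print(consistent, inconsistent)
```

The sketch only evaluates the loss value. In training, both terms would be minimized jointly with respect to the model parameters.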
By minimizing the consistency loss, UDA allows for label information to propagate smoothly from labeled examples to unlabeled ones. Intuitively, one can think of UDA as an implicit iterative process. First, the model relies on a small amount of labeled examples to make correct predictions for some unlabeled examples, from which the label information is propagated to augmented counterparts through the consistency loss. Over time, more and more unlabeled examples will be predicted correctly which reflects the improved generalization of the model. Various other types of noise have been tested for consistency training (e.g., Gaussian noise, adversarial noise, and others), yet we found that data augmentation outperforms all of them, leading to state-of-the-art performance on a wide variety of tasks from language to vision. UDA applies different existing augmentation methods depending on the task at hand, including back translation, AutoAugment, and TF-IDF word replacement.
Benchmarks in NLP and Computer Vision UDA is surprisingly effective in the low-data regime. With only 20 labeled examples, UDA achieves an error rate of 4.20 on the IMDb sentiment analysis task by leveraging 50,000 unlabeled examples. This result outperforms the previous state-of-the-art model trained on 25,000 labeled examples with an error rate of 4.32. In the large-data regime, with the full training set, UDA also provides robust gains.
Benchmark on IMDb, a sentiment analysis task. UDA surpasses state-of-the-art results in supervised learning across different training sizes.
On the CIFAR-10 semi-supervised learning benchmark, UDA outperforms all existing SSL methods, such as VAT, ICT, and MixMatch, by significant margins. With 4k examples, UDA achieves an error rate of 5.27, matching the performance of the fully supervised model that uses 50k examples. Furthermore, with a more advanced architecture, PyramidNet+ShakeDrop, UDA achieves a new state-of-the-art error rate of 2.7, a more than 45% reduction in error rate compared to the previous best semi-supervised result. On SVHN, UDA achieves an error rate of 2.85 with only 250 labeled examples, matching the performance of the fully supervised model trained with ~70k labeled examples.
SSL benchmark on CIFAR-10, an image classification task. UDA surpasses all existing semi-supervised learning methods, all of which use the Wide-ResNet-28-2 architecture. At 4000 examples, UDA matches the performance of the fully supervised setting with 50,000 examples.
On ImageNet with 10% labeled examples, UDA improves the top-1 accuracy from 55.1% to 68.7%. In the high-data regime with the fully labeled set and 1.3M extra unlabeled examples, UDA continues to provide gains from 78.3% to 79.0% for top-1 accuracy.
Release We have released the codebase of UDA, together with all data augmentation methods, e.g., back-translation with pre-trained translation models, to replicate our results. We hope that this release will further advance the progress in semi-supervised learning.
Acknowledgements Special thanks to the co-authors of the paper Zihang Dai, Eduard Hovy, and Quoc V. Le. We’d also like to thank Hieu Pham, Adams Wei Yu, Zhilin Yang, Colin Raffel, Olga Wichrowska, Ekin Dogus Cubuk, Guokun Lai, Jiateng Xie, Yulun Du, Trieu Trinh, Ran Zhao, Ola Spyra, Brandon Yang, Daiyi Peng, Andrew Dai, Samy Bengio and Jeff Dean for their help with this project. A preprint is available online.
Deep neural networks (DNN) are the cornerstone of recent progress in machine learning, and are responsible for recent breakthroughs in a variety of tasks such as image recognition, image segmentation, machine translation and more. However, despite their ubiquity, researchers are still attempting to fully understand the underlying principles that govern them. In particular, classical theories (e.g., VC-dimension and Rademacher complexity) suggest that over-parameterized functions should generalize poorly to unseen data, yet recent work has found that massively over-parameterized functions (orders of magnitude more parameters than the number of data points) generalize well. In order to improve models, a better understanding of generalization, which can lead to more theoretically grounded and therefore more principled approaches to DNN design, is required.
An important concept for understanding generalization is the generalization gap, i.e., the difference between a model’s performance on training data and its performance on unseen data drawn from the same distribution. Significant strides have been made towards deriving better DNN generalization bounds—the upper limit to the generalization gap—but they still tend to greatly overestimate the actual generalization gap, rendering them uninformative as to why some models generalize so well. On the other hand, the notion of margin—the distance between a data point and the decision boundary—has been extensively studied in the context of shallow models such as support-vector machines, and is found to be closely related to how well these models generalize to unseen data. Because of this, the use of margin to study generalization performance has been extended to DNNs, resulting in highly refined theoretical upper bounds on the generalization gap, but this has not significantly improved the ability to predict how well a model generalizes.
An example of a support-vector machine decision boundary. The hyperplane defined by w∙x-b=0 is the "decision boundary" of this linear classifier, i.e., every point x lying on the hyperplane is equally likely to be in either class under this classifier.
In our ICLR 2019 paper, “Predicting the Generalization Gap in Deep Networks with Margin Distributions”, we propose the use of a normalized margin distribution across network layers as a predictor of the generalization gap. We empirically study the relationship between the margin distribution and generalization and show that, after proper normalization of the distances, some basic statistics of the margin distributions can accurately predict the generalization gap. We also make available all the models used as a dataset for studying generalization through the Github repository.
Each plot corresponds to a convolutional neural network trained on CIFAR-10 with different classification accuracies. The probability density (y-axis) of normalized margin distributions (x-axis) at 4 layers of a network is shown for three different models with increasingly better generalization (left to right). The normalized margin distributions are strongly correlated with test accuracy, which suggests they can be used as a proxy for predicting a network's generalization gap. Please see our paper for more details on these networks.
Margin Distributions as a Predictor of Generalization Intuitively, if the statistics of the margin distribution are truly predictive of the generalization performance, a simple prediction scheme should be able to establish the relationship. As such, we chose linear regression to be the predictor. We found that the relationship between the generalization gap and the log-transformed statistics of the margin distributions is almost perfectly linear (see figure below). In fact, the proposed scheme produces better prediction relative to other existing measures of generalization. This indicates that the margin distributions may contain important information about how deep models generalize.
Predicted generalization gap (x-axis) vs. true generalization gap (y-axis) on CIFAR-100 + ResNet-32. The points lie close to the diagonal line, which indicates that the predicted values of the log linear model fit the true generalization gap very well.
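The prediction scheme described above amounts to ordinary least squares on log-transformed statistics. The sketch below shows that step only; the choice of statistics, the normalization of the margins, and the array shapes are assumptions for illustration, and the numbers in the usage example are synthetic rather than taken from our experiments.

```python
import numpy as np

def _design(margin_stats):
    # Log-transform the (positive) margin statistics and append a bias column.
    stats = np.asarray(margin_stats, dtype=float)
    return np.column_stack([np.log(stats), np.ones(len(stats))])

def fit_log_linear(margin_stats, gaps):
    # margin_stats: (n_models, n_stats); gaps: (n_models,).
    coef, *_ = np.linalg.lstsq(_design(margin_stats), np.asarray(gaps, float), rcond=None)
    return coef

def predict_gap(coef, margin_stats):
    # Predicted generalization gap for each model.
    return _design(margin_stats) @ coef

# Synthetic usage example (made-up data, not real results).
rng = np.random.default_rng(0)
stats = rng.uniform(0.1, 2.0, size=(20, 3))
true_coef = np.array([0.05, -0.02, 0.01, 0.1])
gaps = _design(stats) @ true_coef
coef = fit_log_linear(stats, gaps)
```

Because the synthetic gaps are generated from an exactly log-linear rule, the fit recovers `true_coef`; on real models the quality of the fit is the empirical question the paper studies.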
The Deep Model Generalization Dataset In addition to our paper, we are introducing the Deep Model Generalization (DEMOGEN) dataset, which consists of 756 trained deep models, along with their training and test performance on the CIFAR-10 and CIFAR-100 datasets. The models are variants of CNNs (with architectures that resemble Network-in-Network) and ResNet-32 with different popular regularization techniques and hyperparameter settings, inducing a wide spectrum of generalization behaviors. For example, the models of CNNs trained on CIFAR-10 have the test accuracies ranging from 60% to 90.5% with generalization gaps ranging from 1% to 35%. For details of the dataset, please see our paper or the Github repository. As part of the dataset release, we also include utilities to easily load the models and reproduce the results presented in our paper.
We hope that this research and the DEMOGEN dataset will provide the community with an accessible tool for studying generalization in deep learning without having to retrain a large number of models. We also hope that our findings will motivate further research in generalization gap predictors and margin distributions in the hidden layers.
Posted by Joonseok Lee and Joe Yue-Hei Ng, Software Engineers, Google Research
Over the last two years, the First and Second YouTube-8M Large-Scale Video Understanding Challenge and Workshop have collectively drawn 1000+ teams from 60+ countries to further advance large-scale video understanding research. While these events have enabled great progress in video classification, the YouTube dataset on which they were based only used machine-generated video-level labels, and lacked fine-grained temporally localized information, which limited the ability of machine learning models to predict video content.
YouTube-8M Segments Video segment labels provide a valuable resource for temporal localization not possible with video-level labels, and enable novel applications, such as capturing special video moments. Instead of exhaustively labeling all segments in a video, to create the YouTube-8M Segments extension, we manually labeled 5 segments (on average) per randomly selected video on the YouTube-8M validation dataset, totalling ~237k segments covering 1000 categories.
This dataset, combined with the previous YouTube-8M release containing a very large number of machine generated video-level labels, should allow learning temporal localization models in novel ways. Evaluating such classifiers is of course very challenging if only noisy video-level labels are available. We hope that the newly added human-labeled annotations will help ensure that researchers can more accurately evaluate their algorithms.
The 3rd YouTube-8M Video Understanding Challenge This year the YouTube-8M Video Understanding Challenge focuses on temporal localization. Participants are encouraged to leverage noisy video-level labels together with a small segment-level validation set in order to better annotate and temporally localize concepts of interest. Unlike last year, there is no model size restriction. Each of the top 10 teams will be awarded $2,500 to support their travel to Seoul to attend ICCV’19. For details, please visit the Kaggle competition page.
The 3rd Workshop on YouTube-8M Large-Scale Video Understanding Continuing in the tradition of the previous two years, the 3rd workshop will feature four invited talks by distinguished researchers as well as presentations by top-performing challenge participants. We encourage those who wish to attend to submit papers describing their research, experiments, or applications based on the YouTube-8M dataset, including papers summarizing their participation in the challenge above. Please refer to the workshop page for more details.
It is our hope that this newest extension will serve as a unique playground for temporal localization that mimics real world scenarios. We also look forward to the new challenge and workshop, which we believe will continue to advance research in large-scale video understanding. We hope you will join us again!
Acknowledgements This post reflects the work of many machine perception researchers including Ke Chen, Nisarg Kothari, Joonseok Lee, Hanhan Li, Paul Natsev, Joe Yue-Hei Ng, Naderi Parizi, David Ross, Cordelia Schmid, Javier Snaider, Rahul Sukthankar, George Toderici, Balakrishnan Varadarajan, Sudheendra Vijayanarasimhan, Yexin Wang, Zheng Xu, as well as Julia Elliott and Walter Reade from Kaggle. We are also grateful for the support and advice from our partners at YouTube.
Posted by Alex Fabrikant, Research Scientist, Google Research
Hundreds of millions of people across the world rely on public transit for their daily commute, and over half of the world's transit trips involve buses. As the world's cities continue growing, commuters want to know when to expect delays, especially for bus rides, which are prone to getting held up by traffic. While public transit directions provided by Google Maps are informed by many transit agencies that provide real-time data, there are many agencies that can’t provide them due to technical and resource constraints.
Today, Google Maps introduced live traffic delays for buses, forecasting bus delays in hundreds of cities world-wide, ranging from Atlanta to Zagreb to Istanbul to Manila and more. This improves the accuracy of transit timing for over sixty million people. This system, first launched in India three weeks ago, is driven by a machine learning model that combines real-time car traffic forecasts with data on bus routes and stops to better predict how long a bus trip will take.
The Beginnings of a Model In the many cities without real-time forecasts from the transit agency, we heard from surveyed users that they employed a clever workaround to roughly estimate bus delays: using Google Maps driving directions. But buses are not just large cars. They stop at bus stops; take longer to accelerate, slow down, and turn; and sometimes even have special road privileges, like bus-only lanes.
As an example, let’s examine a Wednesday afternoon bus ride in Sydney. The actual motion of the bus (blue) is running a few minutes behind the published schedule (black). Car traffic speeds (red) do affect the bus, such as the slowdown at 2000 meters, but a long stop at the 800 meter mark slows the bus down significantly compared to a car.
To develop our model, we extracted training data from sequences of bus positions over time, as received from transit agencies’ real time feeds, and aligned them to car traffic speeds on the bus's path during the trip. The model is split into a sequence of timeline units—visits to street blocks and stops—each corresponding to a piece of the bus's timeline, with each unit forecasting a duration. A pair of adjacent observations usually spans many units, due to infrequent reporting, fast-moving buses, and short blocks and stops.
This structure is well suited for neural sequence models like those that have recently been successfully applied to speech processing, machine translation, etc. Our model is simpler. Each unit predicts its duration independently, and the final output is the sum of the per-unit forecasts. Unlike many sequence models, our model does not need to learn to combine unit outputs, nor to pass state through the unit sequence. Instead, the sequence structure lets us jointly (1) train models of individual units' durations and (2) optimize the "linear system" where each observed trajectory assigns a total duration to the sum of the many units it spans.
To model a bus trip (a) starting at the blue stop, the model (b) adds up the delay predictions from timeline units for the blue stop, the three road segments, the white stop, etc.
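The "linear system" view can be made concrete with a deliberately simplified sketch. It assumes each timeline unit has a single constant duration (in the actual model each unit's duration is predicted from traffic, location and time features), and that each observed trajectory covers a contiguous range of units. All names here are hypothetical.

```python
import numpy as np

def fit_unit_durations(spans, totals, n_units):
    # Each observation says: units[start:stop] together took `total` seconds.
    # Build one row per observation with ones over the units it spans.
    A = np.zeros((len(spans), n_units))
    for row, (start, stop) in enumerate(spans):
        A[row, start:stop] = 1.0
    durations, *_ = np.linalg.lstsq(A, np.asarray(totals, float), rcond=None)
    return durations

def forecast_trip(durations, start, stop):
    # The trip forecast is just the sum of the per-unit forecasts.
    return float(durations[start:stop].sum())

# Toy usage: four units (stop, block, block, stop) with made-up durations.
true_durations = np.array([30.0, 45.0, 20.0, 60.0])
spans = [(0, 2), (1, 3), (2, 4), (0, 1)]
totals = [true_durations[a:b].sum() for a, b in spans]
fitted = fit_unit_durations(spans, totals, n_units=4)
```

No single observation isolates every unit, yet the overlapping spans jointly determine all four durations, which is the point of optimizing over the whole system at once.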
Modeling the "Where" In addition to road traffic delays, in training our model we also take into account details about the bus route, as well as signals about the trip's location and timing. Even within a small neighborhood, the model needs to translate car speed predictions into bus speeds differently on different streets. In the left panel below, we color-code our model's predicted ratio between car speeds and bus speeds for a bus trip. Redder, slower parts may correspond to bus deceleration near stops. As for the fast green stretch in the highlighted box, we learn from looking at it in StreetView (right) that our model discovered a bus-only turn lane. By the way, this route is in Australia, where right turns are slower than left, another aspect that would be lost on a model that doesn’t consider peculiarities of location.
To capture unique properties of specific streets, neighborhoods, and cities, we let the model learn a hierarchy of representations for areas of different size, with a timeline unit's geography (the precise location of a road or a stop) represented in the model by the sum of the embeddings of its location at various scales. We first train the model with progressively heavier penalties for finer-grain locations with special cases, and use the results for feature selection. This ensures that fine-grained features are taken into account in areas complex enough that a hundred meters affects bus behavior, as opposed to open countryside where such fine-grained features seldom matter.
At training time, we also simulate the possibility of later queries about areas that were not in the training data. In each training batch, we take a random slice of examples and discard geographic features below a scale randomly selected for each. Some examples are kept with the exact bus route and street, others keep only neighborhood- or city-level locations, and others yet have no geographical context at all. This better prepares the model for later queries about areas where we were short on training data. We expand the coverage of our training corpus by using anonymized inferences about user bus trips from the same dataset that Google Maps uses for popular times at businesses, parking difficulty, and other features. However, even this data does not include the majority of the world's bus routes, so our models must generalize robustly to new areas.
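The random coarsening of geography described above might look roughly like the following sketch. The four scale names, the dictionary layout of an example, and the `fraction` parameter are assumptions made for illustration only.

```python
import random

# Geographic scales from coarsest to finest (hypothetical names).
SCALES = ("city", "neighborhood", "street", "route")

def drop_fine_geography(example, rng):
    # Keep the `keep` coarsest scales and blank out everything finer.
    # keep == 0 removes all geographic context; keep == len(SCALES) keeps it all.
    keep = rng.randint(0, len(SCALES))
    masked = dict(example)
    for scale in SCALES[keep:]:
        masked[scale] = None
    return masked

def augment_batch(batch, fraction, rng):
    # Coarsen a random slice of the batch; leave the other examples untouched.
    chosen = set(rng.sample(range(len(batch)), int(fraction * len(batch))))
    return [drop_fine_geography(ex, rng) if i in chosen else ex
            for i, ex in enumerate(batch)]
```

Non-geographic features, such as car speed, pass through unchanged, so the model still sees the traffic signal even when it has to fall back to city-level or no location at all.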
Learning the Local Rhythms Different cities and neighborhoods also run to a different beat, so we allow the model to combine its representation of location with time signals. Buses have a complex dependence on time — the difference between 6:30pm and 6:45pm on a Tuesday might be the wind-down of rush hour in some neighborhoods, a busy dining time in others, and entirely quiet in a sleepy town elsewhere. Our model learns an embedding of the local time of day and day of week signals, which, when combined with the location representation, captures salient local variations, like rush hour bus stop crowds, that aren't observed via car traffic.
This embedding assigns 4-dimensional vectors to times of the day. Unlike most neural net internals, four dimensions is almost few enough to visualize, so let's peek at how the model arranges times of day in three of those dimensions, via the artistic rendering below. The model indeed learns that time is cyclical, placing time in a "loop". But this loop is not just the flat circle of a clock's face. The model learns wide bends that let other neurons compose simple rules to easily separate away concepts like "middle of the night" or "late morning" that don't feature much bus behavior variation. On the other hand, evening commute patterns differ much more among neighborhoods and cities, and the model appears to create more complex "crumpled" patterns between 4pm-9pm that enable more intricate inferences about the timings of each city's rush hour.
The model's time representation (3 out of 4 dimensions) forms a loop, reimagined here as the circumference of a watch. The more location-dependent time windows like 4pm-9pm and 7am-9am get more complex "crumpling", while big featureless windows like 2am-5am get bent away with flat bends for simpler rules. (Artist's conception by Will Cassella, using textures from textures.com and HDRIs from hdrihaven.)
Together with other signals, this time representation lets us predict complex patterns even if we hold car speeds constant. On a 10km bus ride through New Jersey, for example, our model picks up on lunchtime crowds and weekday rush hours:
Putting it All Together With the model fully trained, let's take a look at what it learned about the Sydney bus ride above. If we run the model on that day's car traffic data, it gives us the green predictions below. It doesn't catch everything. For instance, it has the stop at 800 meters lasting only 10 seconds, though the bus stopped for at least 31 sec. But we stay within 1.5 minutes of the real bus motion, catching a lot more of the trip's nuances than the schedule or car driving times alone would give us.
The Trip Ahead One thing not in our model for now? The bus schedule itself. So far, in experiments with official agency bus schedules, they haven't improved our forecasts significantly. In some cities, severe traffic fluctuations might overwhelm attempts to plan a schedule. In others, the bus schedules might be precise, but perhaps because transit agencies carefully account for traffic patterns. And we infer those from the data.
We continue to experiment with making better use of schedule constraints and many other signals to drive more precise forecasting and make it easier for our users to plan their trips. We hope we'll be of use to you on your way, too. Happy travels!
Acknowledgements This work was the joint effort of James Cook, Alex Fabrikant, Ivan Kuznetsov, and Fangzhou Xu, on Google Research, and Anthony Bertuca, Julian Gibbons, Thierry Le Boulengé, Cayden Meyer, Anatoli Plotnikov, and Ivan Volosyuk on Google Maps. We thank Senaka Buthpitiya, Da-Cheng Juan, Reuben Kan, Ramesh Nagarajan, Andrew Tomkins, and the greater Transit team for support and helpful discussions; as well as Will Cassella for the inspired reimagining of the model's time embedding. We are also indebted to our partner agencies for providing the transit data feeds the system is trained on.
Posted by Alessandro Epasto, Senior Research Scientist and Bryan Perozzi, Senior Research Scientist, Graph Mining Team
Relational data representing relationships between entities is ubiquitous on the Web (e.g., online social networks) and in the physical world (e.g., in protein interaction networks). Such data can be represented as a graph with nodes (e.g., users, proteins), and edges connecting them (e.g., friendship relations, protein interactions). Given the widespread prevalence of graphs, graph analysis plays a fundamental role in machine learning, with applications in clustering, link prediction, privacy, and others. To apply machine learning methods to graphs (e.g., predicting new friendships, or discovering unknown protein interactions) one needs to learn a representation of the graph that is amenable to be used in ML algorithms.
However, graphs are inherently combinatorial structures made of discrete parts like nodes and edges, while many common ML methods, like neural networks, favor continuous structures, in particular vector representations. Vector representations are particularly important in neural networks, as they can be directly used as input layers. To get around the difficulties in using discrete graph representations in ML, graph embedding methods learn a continuous vector space for the graph, assigning each node (and/or edge) in the graph to a specific position in a vector space. A popular approach in this area is that of random-walk-based representation learning, as introduced in DeepWalk.
Left: The well-known Karate graph representing a social network. Right: A continuous space embedding of the nodes in the graph using DeepWalk.
Learning Node Representations that Capture Multiple Social Contexts In virtually all cases, the crucial assumption of standard graph embedding methods is that a single embedding has to be learned for each node. Thus, the embedding method can be said to seek to identify the single role or position that characterizes each node in the geometry of the graph. Recent work observed, however, that nodes in real networks belong to multiple overlapping communities and play multiple roles—think about your social network where you participate in both your family and in your work community. This observation motivates the following research question: is it possible to develop methods where nodes are embedded in multiple vectors, representing their participation in overlapping communities?
In our WWW’19 paper, we developed Splitter, an unsupervised embedding method that allows the nodes in a graph to have multiple embeddings to better encode their participation in multiple communities. Our method is based on recent innovations in overlapping clustering based on ego-network analysis, using the persona graph concept, in particular. This method takes a graph G, and creates a new graph P (called the persona graph), where each node in G is represented by a series of replicas called the persona nodes. Each persona of a node represents an instantiation of the node in a local community to which it belongs. For each node U in the graph, we analyze the ego-network of the node (i.e., the graph connecting the node to its neighbors, in this example A, B, C, D) to discover local communities to which the node belongs. For instance, in the figure below, node U belongs to two communities: Cluster 1 (with the friends A and B, say U’s family members) and Cluster 2 (with C and D, say U’s colleagues).
Ego-net of node U
Then, we use this information to “split” node U into its two personas U1 (the family persona) and U2 (the work persona). This disentangles the two communities, so that they no longer overlap.
The ego-splitting method separating node U into 2 personas.
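The splitting step can be sketched in a few lines of Python. This version assumes that local communities are simply the connected components of a node's ego-network once the node itself is removed; that is one simple clustering choice for illustration, not necessarily the one used in the paper. Node labels are assumed to be sortable (e.g., strings) and the graph undirected without self-loops.

```python
def ego_net_communities(adj, u):
    # Connected components among u's neighbors, ignoring u itself.
    neighbors = set(adj[u])
    seen, communities = set(), []
    for start in sorted(neighbors):
        if start in seen:
            continue
        component, stack = set(), [start]
        while stack:
            v = stack.pop()
            if v in component:
                continue
            component.add(v)
            stack.extend(w for w in adj[v] if w in neighbors and w not in component)
        seen |= component
        communities.append(component)
    return communities

def persona_graph(adj):
    # persona_of[(u, v)] is the persona of u that "faces" its neighbor v.
    persona_of = {}
    for u in adj:
        for idx, community in enumerate(ego_net_communities(adj, u)):
            for v in community:
                persona_of[(u, v)] = (u, idx)
    # Every original edge u-v becomes an edge between the matching personas.
    edges = set()
    for u in adj:
        for v in adj[u]:
            edges.add(frozenset((persona_of[(u, v)], persona_of[(v, u)])))
    return edges

# The example from the text: U with family (A, B) and colleagues (C, D).
adj = {
    "U": {"A", "B", "C", "D"},
    "A": {"U", "B"},
    "B": {"U", "A"},
    "C": {"U", "D"},
    "D": {"U", "C"},
}
communities = ego_net_communities(adj, "U")
edges = persona_graph(adj)
```

Here `("U", 0)` plays the role of U1 (the family persona) and `("U", 1)` the role of U2 (the work persona); a standard embedding method can then be run on the persona graph.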
This technique has been used to improve the state-of-the-art results in graph embedding methods, showing up to 90% reduction in link prediction (i.e., predicting which link will form in the future) error on a variety of graphs. The key reason for this improvement is the ability of the method to disambiguate highly overlapping communities found in social networks and other real-world graphs. We further validate this result with an in-depth analysis of co-authorship graphs where authors belong to overlapping research communities (e.g., machine learning and data mining).
Top Left: A typical graph with highly overlapping communities. Top Right: A traditional embedding of the graph on the left using node2vec. Bottom Left: A persona graph of the graph above. Bottom Right: The Splitter embedding of the persona graph. Notice how the persona graph clearly disentangles the overlapping communities of the original graph and Splitter outputs well-separated embeddings.
Automatic hyper-parameter tuning via graph attention. Graph embedding methods have shown outstanding performance on various ML-based applications, such as link prediction and node classification, but they have a number of hyper-parameters that must be manually set. For example, are nearby nodes more important to capture when learning embeddings than nodes that are further away? Even though experts may be able to fine tune these hyper-parameters, one must do so independently for each graph. To obviate such manual work, in our second paper, we proposed a method to learn the optimal hyper-parameters automatically.
Specifically, many graph embedding methods, like DeepWalk, employ random walks to explore the context around a given node (i.e. the direct neighbors, the neighbors of the neighbors, etc). Such random walks can have many hyper-parameters that allow tuning of the local exploration of the graph, thus regulating the attention given by the embeddings to nearby nodes. Different graphs may present different optimal attention patterns and hence different optimal hyperparameters (see the picture below, where we show two different attention distributions). Watch Your Step formulates a model for the performance of the embedding methods based on the above mentioned hyper-parameters. Then we optimize the hyper-parameters to maximize the performance predicted by the model, using standard backpropagation. We found that the values learned by backpropagation agree with the optimal hyper-parameters obtained by grid search.
Our new method for automatic hyper-parameter tuning, Watch Your Step, uses an attention model to learn different graph context distributions. Shown above are two example local neighborhoods about a center node (in yellow) and the context distributions (red gradient) that were learned by the model. The left-side graph shows a more diffused attention model, while the distribution on the right shows one concentrated on direct neighbors.
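One way to picture a context distribution is as a set of weights over random-walk step counts. The simplified sketch below assumes a softmax over per-step logits and mixes powers of the random-walk transition matrix accordingly; it only illustrates the idea and is not the paper's training objective. The function names are hypothetical, and the graph is assumed to have no isolated nodes.

```python
import numpy as np

def transition_matrix(adj):
    # Row-normalize the adjacency matrix to get one-step walk probabilities.
    adj = np.asarray(adj, dtype=float)
    return adj / adj.sum(axis=1, keepdims=True)

def context_distribution(logits):
    # Softmax turns unconstrained logits into weights over walk distances.
    logits = np.asarray(logits, dtype=float)
    z = np.exp(logits - logits.max())
    return z / z.sum()

def expected_context(adj, logits):
    # Weighted mix of T, T^2, ..., T^K: how much attention each node pays
    # to every other node, given the weights over distances.
    T = transition_matrix(adj)
    q = context_distribution(logits)
    power = np.eye(len(T))
    total = np.zeros_like(T)
    for weight in q:
        power = power @ T
        total += weight * power
    return total

# Toy path graph 0 - 1 - 2.
path = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
diffuse = expected_context(path, [0.0, 0.0, 0.0])   # equal weight on 1-3 steps
local = expected_context(path, [50.0, 0.0, 0.0])    # almost all weight on 1 step
```

Since the softmax is differentiable, the logits can in principle be learned by backpropagation together with the embeddings rather than set by grid search.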
This work falls under the growing family of AutoML, where we want to alleviate the burden of optimizing the hyperparameters—a common problem in practical machine learning. Many AutoML methods use neural architecture search. This paper instead shows a variant, where we use the mathematical connection between the hyperparameters in the embeddings and graph-theoretic matrix formulations. The “Auto” portion corresponds to learning the graph hyperparameters by backpropagation.
We believe that our contributions will further advance the state of the research in graph embedding in various directions. Our method for learning multiple node embeddings draws a connection between the rich and well-studied field of overlapping community detection, and the more recent one of graph embedding which we believe may result in fruitful future research. An open problem in this area is the use of multiple-embedding methods for classification. Furthermore, our contribution on learning hyperparameters will foster graph embedding adoption by reducing the need for expensive manual tuning. We hope the release of these papers and code will help the research community pursue these directions.
Acknowledgements We thank Sami Abu-el-Haija who contributed to this work and is now a Ph.D. student at USC.
Posted by Alex Irpan, Software Engineer, Robotics at Google
Reinforcement learning (RL) is a framework that lets agents learn decision making from experience. One of the many variants of RL is off-policy RL, where an agent is trained using a combination of data collected by other agents (off-policy data) and data it collects itself to learn generalizable skills like robotic walking and grasping. In contrast, fully off-policy RL is a variant in which an agent learns entirely from older data, which is appealing because it enables model iteration without requiring a physical robot. With fully off-policy RL, one can train several models on the same fixed dataset collected by previous agents, then select the best one. However, fully off-policy RL comes with a catch: while training can occur without a real robot, evaluation of the models cannot. Furthermore, ground-truth evaluation with a physical robot is too inefficient to test promising approaches that require evaluating a large number of models, such as automated architecture search with AutoML.
This challenge motivates off-policy evaluation (OPE), techniques for studying the quality of new agents using data from other agents. With rankings from OPE, we can selectively test only the most promising models on real-world robots, significantly scaling experimentation with the same fixed real robot budget.
A diagram for real-world model development. Assuming we can evaluate 10 models per day, without off-policy evaluation, we would need 100x as many days to evaluate our models.
Though the OPE framework shows promise, it assumes one has an off-policy evaluation method that accurately ranks performance from old data. However, agents that collected past experience may act very differently from newer learned agents, which makes it hard to get good estimates of performance.
In “Off-Policy Evaluation via Off-Policy Classification”, we propose a new off-policy evaluation method, called off-policy classification (OPC), that evaluates the performance of agents from past data by treating evaluation as a classification problem, in which actions are labeled as either potentially leading to success or guaranteed to result in failure. Our method works for image (camera) inputs, and doesn’t require reweighting data with importance sampling or using accurate models of the target environment, two approaches commonly used in prior work. We show that OPC scales to larger tasks, including a vision-based robotic grasping task in the real world.
How OPC Works OPC relies on two assumptions: 1) that the final task has deterministic dynamics, i.e. no randomness is involved in how states change, and 2) that the agent either succeeds or fails at the end of each trial. This second “success or failure” assumption is natural for many tasks, such as picking up an object, solving a maze, winning a game, and so on. Because each trial will either succeed or fail in a deterministic way, we can assign binary classification labels to each action. We say an action is effective if it could lead to success, and is catastrophic if it is guaranteed to lead to failure.
OPC utilizes a Q-function, learned with a Q-learning algorithm, that estimates the future total reward if the agent chooses to take some action from its current state. The agent will then choose the action with the largest total reward estimate. In our paper, we prove that the performance of an agent is measured by how often its chosen action is an effective action, which depends on how well the Q-function correctly classifies actions as effective vs. catastrophic. This classification accuracy acts as an off-policy evaluation score.
However, the labeling of data from previous trials is only partial. For example, if a previous trial was a failure, we do not get negative labels because we do not know which action was the catastrophic one. To overcome this, we leverage techniques from semi-supervised learning, positive-unlabeled learning in particular, to get an estimate of classification accuracy from partially labeled data. This accuracy is the OPC score.
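To make the partial labeling concrete, here is a rough sketch of a positive-unlabeled style score. It assumes that every action in a successful trial is labeled effective, that actions in failed trials stay unlabeled, and that the fraction of effective actions (`positive_prior`) is supplied as a hyperparameter. The exact formulas are not given in this post, so the two expressions below are an assumption for illustration; the paper has the precise definitions. The function name and arguments are hypothetical.

```python
from typing import List, Sequence, Tuple


def pu_scores(q_values_by_trial: Sequence[Sequence[float]],
              trial_succeeded: Sequence[bool],
              positive_prior: float,
              threshold: float) -> Tuple[float, float]:
    """Return (hard_score, soft_score) computed from partially labeled trials.

    q_values_by_trial[i] holds Q(s, a) for each logged step of trial i.
    Steps from successful trials are treated as positives (effective actions);
    steps from failed trials are unlabeled, never negative.
    """
    if len(q_values_by_trial) != len(trial_succeeded):
        raise ValueError("one success flag is required per trial")

    all_q: List[float] = []
    positive_q: List[float] = []
    for q_values, succeeded in zip(q_values_by_trial, trial_succeeded):
        all_q.extend(q_values)
        if succeeded:
            positive_q.extend(q_values)
    if not positive_q:
        raise ValueError("need at least one successful trial")

    def mean(values: Sequence[float]) -> float:
        return sum(values) / len(values)

    # Assumed hard variant: threshold the Q-values into a binary classifier.
    hard = (positive_prior * mean([float(q > threshold) for q in positive_q])
            - mean([float(q > threshold) for q in all_q]))
    # Assumed soft variant: use the Q-values directly instead of thresholding.
    soft = positive_prior * mean(positive_q) - mean(all_q)
    return hard, soft


hard_score, soft_score = pu_scores(
    q_values_by_trial=[[0.9, 0.8], [0.2, 0.4]],  # toy Q-values, not real data
    trial_succeeded=[True, False],
    positive_prior=0.5,
    threshold=0.5,
)
```

Both scores are only meaningful relative to each other: a higher score for one Q-function than another suggests it ranks effective actions higher, but the absolute value has no direct interpretation.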
Off-Policy Evaluation for Sim-to-Real Learning
In robotics, it’s common to use simulated data and transfer learning techniques to reduce the sample complexity of learning robotics skills. This can be very useful, but tuning these sim-to-real techniques for real-world robotics is challenging. Much like off-policy RL, training doesn’t use the real robot, because the policy is trained in simulation, but evaluating that policy still requires a real robot. Here, off-policy evaluation can come to the rescue again: we can take a policy trained only in simulation, then evaluate it using previous real-world data to measure its transfer to the real robot. We examine OPC across both fully off-policy RL and sim-to-real RL.
An example of how simulated experience can differ from real-world experience. Here, simulated images (left) have much less visual complexity than real-world images (right).
Results
First, we set up a simulated version of our robot grasping task, where we could easily train and evaluate several models to benchmark off-policy evaluation. These models were trained with fully off-policy RL, then evaluated with off-policy evaluation. We found that in our robotics tasks, a variant of the OPC called the SoftOPC performed best at predicting final success rate.
An experiment in the simulated grasping task. The red curve is the dimensionless SoftOPC score over the course of training, evaluated from old data. The blue curve is the grasp success rate in simulation. We see the SoftOPC on old data correlates well with grasp success of the model within our simulator.
After success in simulation, we then tried SoftOPC in the real-world task. We took 15 models, trained to have varying degrees of robustness to the gap between simulation and reality. Seven of these models were trained purely in simulation, and the rest were trained on mixes of simulated and real-world data. For each model, we computed the SoftOPC on off-policy real-world data, then measured the real-world grasp success, to see how well SoftOPC predicted the performance of that model. We found that on real data, the SoftOPC does produce scores that correlate with true grasp success, letting us rank sim-to-real techniques using past real experience.
SoftOPC score and true performance for 3 different sim-to-real methods: a baseline simulation, a simulation with random textures and lighting, and a model trained with RCAN. All three models are trained with no real data, then evaluated with off-policy evaluation on a validation set of real data. The ordering of the SoftOPC score matches the order of real grasp success.
Below is a scatterplot of the full results from all 15 models. Each point represents the off-policy evaluation score and real-world grasp success of one model. We compare different scoring functions by their correlation to final grasp success. The SoftOPC does not correlate perfectly with true grasp success, but its scores are significantly more reliable than baseline approaches like the temporal-difference error (the standard Q-learning loss).
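As a rough illustration of how a scoring function can be compared against true grasp success, the sketch below computes a Pearson correlation between per-model scores and per-model success rates. The post does not say which correlation measure was used, so Pearson is an assumption here, and the function name is hypothetical.

```python
import math
from typing import Sequence


def pearson_correlation(scores: Sequence[float],
                        success_rates: Sequence[float]) -> float:
    """Pearson correlation between evaluation scores and true success rates."""
    if len(scores) != len(success_rates):
        raise ValueError("need one success rate per score")
    if len(scores) < 2:
        raise ValueError("need at least two models")

    mean_score = sum(scores) / len(scores)
    mean_success = sum(success_rates) / len(success_rates)

    covariance = sum((s - mean_score) * (r - mean_success)
                     for s, r in zip(scores, success_rates))
    score_spread = sum((s - mean_score) ** 2 for s in scores)
    success_spread = sum((r - mean_success) ** 2 for r in success_rates)
    if score_spread == 0 or success_spread == 0:
        raise ValueError("correlation is undefined for constant inputs")

    return covariance / math.sqrt(score_spread * success_spread)
```

A scoring function whose correlation is closer to 1 ranks models more faithfully. For a metric where lower is better, such as a loss, a strongly negative correlation would be the useful outcome. If only the ordering of models matters, a rank correlation may be a better fit than Pearson.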
Results from our sim-to-real evaluation experiment. On the left is a baseline, the temporal difference error of the model. On the right is one of our proposed methods, the SoftOPC. The shaded region is a 95% confidence interval. The correlation is significantly better with SoftOPC.
Future Work
One promising direction for future work is to see if we can relax our assumptions about the task, to support tasks where dynamics are noisier, or where we get partial credit for almost succeeding. However, even with our current assumptions, we think the results are promising enough to be applied to many real-world RL problems.
Acknowledgements
This research was conducted by Alex Irpan, Kanishka Rao, Konstantinos Bousmalis, Chris Harris, Julian Ibarz and Sergey Levine. We’d like to thank Razvan Pascanu, Dale Schuurmans, George Tucker and Paul Wohlhart for valuable discussions. A preprint is available on arXiv.