Tag Archives: On-device Learning

On-Device Conversational Modeling with TensorFlow Lite



Earlier this year, we launched Android Wear 2.0 which featured the first "on-device" machine learning technology for smart messaging. This enabled cloud-based technologies like Smart Reply, previously available in Gmail, Inbox and Allo, to be used directly within any application for the first time, including third-party messaging apps, without ever having to connect to the cloud. So you can respond to incoming chat messages on the go, directly from your smartwatch.

Today, we announce TensorFlow Lite, TensorFlow’s lightweight solution for mobile and embedded devices. This framework is optimized for low-latency inference of machine learning models, with a focus on small memory footprint and fast performance. As part of the library, we have also released an on-device conversational model and a demo app that provides an example of a natural language application powered by TensorFlow Lite, to make it easier for developers and researchers to build new machine intelligence features powered by on-device inference. The model generates reply suggestions for incoming chat messages, and its efficient inference can be easily plugged into your chat application to power on-device conversational intelligence.

The on-device conversational model we have released uses a new ML architecture for training compact neural networks (as well as other machine learning models) based on a joint optimization framework, originally presented in ProjectionNet: Learning Efficient On-Device Deep Networks Using Neural Projections. This architecture can run efficiently on mobile devices with limited computing power and memory by using efficient “projection” operations that transform any input to a compact bit vector representation: similar inputs are projected to nearby vectors that are dense or sparse depending on the type of projection. For example, the messages “hey, how's it going?” and “How's it going buddy?” might be projected to the same vector representation.

Using this idea, the conversational model combines these efficient operations while keeping computation and memory footprint low. We trained this on-device model end-to-end using an ML framework that jointly trains two types of models: a compact projection model (as described above) combined with a trainer model. The two models are trained jointly, with the projection model learning from the trainer model. The trainer plays the role of an expert and is modeled using larger and more complex ML architectures, whereas the projection model resembles a student that learns from the expert. During training, we can also stack other techniques such as quantization or distillation to achieve further compression or selectively optimize certain portions of the objective function. Once trained, the smaller projection model can be used directly for inference on device.
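
The post does not spell out the joint objective, so the following is only a minimal sketch of how a trainer/student objective of this kind could be written. It assumes a weighted sum of three cross-entropy terms (trainer against labels, projection model against the trainer's predictions, projection model against labels). The function names and the weights w_trainer, w_mimic and w_projection are hypothetical.

```python
import math


def cross_entropy(target, predicted, eps=1e-12):
    """Cross-entropy H(target, predicted) between two probability vectors."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, predicted))


def joint_loss(label, trainer_probs, projection_probs,
               w_trainer=1.0, w_mimic=0.1, w_projection=1.0):
    """Assumed joint objective: a weighted sum of three cross-entropy terms."""
    # The trainer (expert) is fit to the ground-truth label.
    trainer_term = cross_entropy(label, trainer_probs)
    # The projection model (student) is pulled toward the trainer's predictions.
    mimic_term = cross_entropy(trainer_probs, projection_probs)
    # The projection model is also fit to the ground-truth label directly.
    projection_term = cross_entropy(label, projection_probs)
    return (w_trainer * trainer_term
            + w_mimic * mimic_term
            + w_projection * projection_term)


label = [0.0, 1.0, 0.0]              # one-hot target: the second reply is correct
trainer_probs = [0.05, 0.90, 0.05]   # confident expert
good_student = [0.10, 0.80, 0.10]    # agrees with both the expert and the label
poor_student = [0.60, 0.20, 0.20]    # disagrees with both

print(joint_loss(label, trainer_probs, good_student))
print(joint_loss(label, trainer_probs, poor_student))
```

With the trainer held fixed, a student that agrees with both the expert and the label gets a lower loss than one that disagrees, which is the signal that gradient-based joint training would follow.
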
For inference, the trained projection model is compiled into a set of TensorFlow Lite operations that have been optimized for fast execution on mobile platforms and executed directly on device. The TensorFlow Lite inference graph for the on-device conversational model is shown here.
TensorFlow Lite execution for the On-Device Conversational Model.
The open-source conversational model released today (along with code) was trained end-to-end using the joint ML architecture described above. Today’s release also includes a demo app, so you can easily download and try out one-touch smart replies on your mobile device. The architecture enables easy configuration for model size and prediction quality based on application needs. You can find a list of sample messages where this model does well here. The system can also fall back to suggesting replies from a fixed set that was learned and compiled from popular response intents observed in chat conversations. The underlying model is different from the ones Google uses for Smart Reply responses in its apps [1].

Beyond Conversational Models
Interestingly, the ML architecture described above permits flexible choices for the underlying model. We also designed the architecture to be compatible with different machine learning approaches — for example, when used with TensorFlow deep learning, we learn a lightweight neural network (ProjectionNet) for the underlying model, whereas a different architecture (ProjectionGraph) represents the model using a graph framework instead of a neural network.

The joint framework can also be used to train lightweight on-device models for other tasks using different ML modeling architectures. As an example, we derived a ProjectionNet architecture that uses a complex feed-forward or recurrent architecture (like LSTM) for the trainer model coupled with a simple projection architecture composed of dynamic projection operations and a few narrow fully connected layers. The whole architecture is trained end-to-end using backpropagation in TensorFlow and, once trained, the compact ProjectionNet is used directly for inference. Using this method, we have successfully trained tiny ProjectionNet models that achieve significant reductions in model size (up to several orders of magnitude) and high accuracy on multiple visual and language classification tasks (a few examples here). Similarly, we trained other lightweight models using our graph learning framework, even in semi-supervised settings.
ML architecture for training on-device models: ProjectionNet trained using deep learning (left), and ProjectionGraph trained using graph learning (right).
We will continue to improve and release updated TensorFlow Lite models as open source. We think that the released model (as well as future models) learned using these ML architectures may be reused for many natural language and computer vision applications or plugged into existing apps to enable machine intelligence. We hope that the machine learning and natural language processing communities will be able to build on these to address new problems and use cases we have not yet conceived.

Acknowledgments
Yicheng Fan and Gaurav Nemade contributed immensely to this effort. Special thanks to Rajat Monga, Andre Hentz, Andrew Selle, Sarah Sirajuddin, and Anitha Vijayakumar from the TensorFlow team; Robin Dua, Patrick McGregor, Andrei Broder, Andrew Tomkins and the Google Expander team.



[1] The released on-device model was trained to optimize for small size and low latency in applications on mobile phones and wearables. Smart Reply predictions in Google apps, however, are generated using larger, more complex models. In production systems, we also use multiple classifiers that are trained to detect inappropriate content and apply further filtering and tuning to optimize user experience and quality levels. We recommend that developers using the open-source TensorFlow Lite version also follow such practices for their end applications.

Federated Learning: Collaborative Machine Learning without Centralized Training Data



Standard machine learning approaches require centralizing the training data on one machine or in a datacenter. And Google has built one of the most secure and robust cloud infrastructures for processing this data to make our services better. Now for models trained from user interaction with mobile devices, we're introducing an additional approach: Federated Learning.

Federated Learning enables mobile phones to collaboratively learn a shared prediction model while keeping all the training data on device, decoupling the ability to do machine learning from the need to store the data in the cloud. This goes beyond the use of local models that make predictions on mobile devices (like the Mobile Vision API and On-Device Smart Reply) by bringing model training to the device as well.

It works like this: your device downloads the current model, improves it by learning from data on your phone, and then summarizes the changes as a small focused update. Only this update to the model is sent to the cloud, using encrypted communication, where it is immediately averaged with other user updates to improve the shared model. All the training data remains on your device, and no individual updates are stored in the cloud.
Your phone personalizes the model locally, based on your usage (A). Many users' updates are aggregated (B) to form a consensus change (C) to the shared model, after which the procedure is repeated.
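
The round described above (download, local improvement, averaging) can be sketched in a few lines. This is a toy illustration, not the production system. It assumes a one-dimensional linear model, plain SGD on each device, and an average weighted by each device's example count. The function names, learning rate and data are hypothetical, and averaging the locally improved weights is equivalent to averaging the per-device changes.

```python
def local_update(weights, examples, lr=0.1, epochs=5):
    """Improve the model y = w*x + b with a few epochs of SGD on one device's data."""
    w, b = weights
    for _ in range(epochs):
        for x, y in examples:
            err = (w * x + b) - y
            w -= lr * err * x
            b -= lr * err
    return (w, b)


def federated_round(global_weights, devices):
    """One round: each device trains locally, then the server averages the results."""
    total = sum(len(examples) for examples in devices)
    new_w, new_b = 0.0, 0.0
    for examples in devices:
        w, b = local_update(global_weights, examples)
        share = len(examples) / total  # weight each device by its data size
        new_w += share * w
        new_b += share * b
    return (new_w, new_b)


# Each inner list is one device's private data, drawn from y = 2x + 1.
# The raw examples are only ever read inside local_update.
devices = [
    [(0.0, 1.0), (1.0, 3.0)],
    [(2.0, 5.0)],
    [(-1.0, -1.0), (0.5, 2.0), (1.5, 4.0)],
]

weights = (0.0, 0.0)
for _ in range(50):
    weights = federated_round(weights, devices)
print(weights)  # approaches (2.0, 1.0)
```

Only the output of local_update crosses the device boundary in this sketch; the server-side function never touches individual examples.
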
Federated Learning allows for smarter models, lower latency, and less power consumption, all while ensuring privacy. And this approach has another immediate benefit: in addition to providing an update to the shared model, the improved model on your phone can also be used immediately, powering experiences personalized by the way you use your phone.

We're currently testing Federated Learning in Gboard on Android, the Google Keyboard. When Gboard shows a suggested query, your phone locally stores information about the current context and whether you clicked the suggestion. Federated Learning processes that history on-device to suggest improvements to the next iteration of Gboard’s query suggestion model.
To make Federated Learning possible, we had to overcome many algorithmic and technical challenges. In a typical machine learning system, an optimization algorithm like Stochastic Gradient Descent (SGD) runs on a large dataset partitioned homogeneously across servers in the cloud. Such highly iterative algorithms require low-latency, high-throughput connections to the training data. But in the Federated Learning setting, the data is distributed across millions of devices in a highly uneven fashion. In addition, these devices have significantly higher-latency, lower-throughput connections and are only intermittently available for training.

These bandwidth and latency limitations motivate our Federated Averaging algorithm, which can train deep networks using 10-100x less communication compared to a naively federated version of SGD. The key idea is to use the powerful processors in modern mobile devices to compute higher quality updates than simple gradient steps. Since it takes fewer iterations of high-quality updates to produce a good model, training can use much less communication. As upload speeds are typically much slower than download speeds, we also developed a novel way to reduce upload communication costs up to another 100x by compressing updates using random rotations and quantization. While these approaches are focused on training deep networks, we've also designed algorithms for high-dimensional sparse convex models which excel on problems like click-through-rate prediction.
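
As a rough illustration of the quantization half of that idea, here is a sketch of unbiased stochastic 1-bit quantization. Each coordinate of an update is rounded at random to the vector's minimum or maximum, with probabilities chosen so that the expected value equals the original. The random rotation step is omitted, and the function names are hypothetical. This is not the scheme used in production.

```python
import random


def quantize_1bit(update, rng):
    """Stochastically round each coordinate to the vector's min or max.

    Returns (lo, hi, bits). The rounding probabilities make the dequantized
    value equal to the original coordinate in expectation.
    """
    lo, hi = min(update), max(update)
    span = hi - lo
    if span == 0:
        return lo, hi, [0] * len(update)
    bits = [1 if rng.random() < (v - lo) / span else 0 for v in update]
    return lo, hi, bits


def dequantize(lo, hi, bits):
    """Rebuild an approximate update from two floats and one bit per coordinate."""
    return [hi if b else lo for b in bits]


rng = random.Random(0)
update = [0.2, -0.5, 0.9, 0.0]

# Each device uploads two floats plus one bit per coordinate.
lo, hi, bits = quantize_1bit(update, rng)
print(dequantize(lo, hi, bits))

# Averaging many independently quantized copies recovers the original update.
rounds = 2000
mean = [0.0] * len(update)
for _ in range(rounds):
    restored = dequantize(*quantize_1bit(update, rng))
    mean = [m + r / rounds for m, r in zip(mean, restored)]
print(mean)
```

A single quantized update is very coarse, but because the rounding is unbiased, the error shrinks when the server averages over many devices, which is the setting Federated Averaging operates in.
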

Deploying this technology to millions of heterogeneous phones running Gboard requires a sophisticated technology stack. On-device training uses a miniature version of TensorFlow. Careful scheduling ensures training happens only when the device is idle, plugged in, and on a free wireless connection, so there is no impact on the phone's performance.
Your phone participates in Federated Learning only when it won't negatively impact your experience.
The system then needs to communicate and aggregate the model updates in a secure, efficient, scalable, and fault-tolerant way. It's only the combination of research with this infrastructure that makes the benefits of Federated Learning possible.

Federated Learning works without the need to store user data in the cloud, but we're not stopping there. We've developed a Secure Aggregation protocol that uses cryptographic techniques so a coordinating server can only decrypt the average update if hundreds or thousands of users have participated; no individual phone's update can be inspected before averaging. It's the first protocol of its kind that is practical for deep-network-sized problems and real-world connectivity constraints. We designed Federated Averaging so the coordinating server only needs the average update, which allows Secure Aggregation to be used; however, the protocol is general and can be applied to other problems as well. We're working hard on a production implementation of this protocol and expect to deploy it for Federated Learning applications in the near future.
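
One way to build intuition for how a server can learn a sum without seeing its parts is pairwise masking: every pair of devices shares a random mask that one adds and the other subtracts, so the masks cancel only when all contributions are combined. The sketch below is a simplified illustration under strong assumptions. Updates are taken to be non-negative integers (for example, a fixed-point encoding), no device drops out, and Python's random module stands in for a cryptographically secure generator and key agreement. It is not the protocol described above, and all names are hypothetical.

```python
import random

MODULUS = 2 ** 32  # all arithmetic is done modulo this value


def mask_updates(updates, rng):
    """Return masked copies of the updates whose pairwise masks cancel in the sum."""
    n = len(updates)
    dim = len(updates[0])
    masked = [list(u) for u in updates]
    for i in range(n):
        for j in range(i + 1, n):
            # Devices i and j agree on a shared random mask.
            pair_mask = [rng.randrange(MODULUS) for _ in range(dim)]
            for k in range(dim):
                masked[i][k] = (masked[i][k] + pair_mask[k]) % MODULUS
                masked[j][k] = (masked[j][k] - pair_mask[k]) % MODULUS
    return masked


def aggregate(masked):
    """Server side: sum the masked vectors; the masks cancel modulo MODULUS."""
    dim = len(masked[0])
    return [sum(m[k] for m in masked) % MODULUS for k in range(dim)]


updates = [[3, 10], [5, 20], [7, 30]]  # three devices, two coordinates each
masked = mask_updates(updates, random.Random(1))
print(aggregate(masked))  # [15, 60], the sum of the original updates
```

The server in this sketch only ever receives the masked vectors, each of which looks random on its own; dividing the recovered sum by the number of devices gives the average update.
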

Our work has only scratched the surface of what is possible. Federated Learning can't solve all machine learning problems (for example, learning to recognize different dog breeds by training on carefully labeled examples), and for many other models the necessary training data is already stored in the cloud (like training spam filters for Gmail). So Google will continue to advance the state of the art for cloud-based ML, but we are also committed to ongoing research to expand the range of problems we can solve with Federated Learning. Beyond Gboard query suggestions, for example, we hope to improve the language models that power your keyboard based on what you actually type on your phone (which can have a style all its own) and photo rankings based on what kinds of photos people look at, share, or delete.

Applying Federated Learning requires machine learning practitioners to adopt new tools and a new way of thinking: model development, training, and evaluation with no direct access to or labeling of raw data, with communication cost as a limiting factor. We believe the user benefits of Federated Learning make tackling the technical challenges worthwhile, and are publishing our work with hopes of a widespread conversation within the machine learning community.

Acknowledgements
This post reflects the work of many people in Google Research, including Blaise Agüera y Arcas, Galen Andrew, Dave Bacon, Keith Bonawitz, Chris Brumme, Arlie Davis, Jac de Haan, Hubert Eichner, Wolfgang Grieskamp, Wei Huang, Vladimir Ivanov, Chloé Kiddon, Jakub Konečný, Nicholas Kong, Ben Kreuter, Alison Lentz, Stefano Mazzocchi, Sarvar Patel, Martin Pelikan, Aaron Segal, Karn Seth, Ananda Theertha Suresh, Iulia Turc, Felix Yu, and our partners in the Gboard team.

On-Device Machine Intelligence



To build the cutting-edge technologies that enable conversational understanding and image recognition, we often apply combinations of machine learning technologies such as deep neural networks and graph-based machine learning. However, the machine learning systems that power most of these applications run in the cloud, are computationally intensive, and have significant memory requirements. What if you want machine intelligence to run on your personal phone or smartwatch, or on IoT devices, regardless of whether they are connected to the cloud?

Yesterday, we announced the launch of Android Wear 2.0, along with brand new wearable devices, that will run Google's first entirely “on-device” ML technology for powering smart messaging. This on-device ML system, developed by the Expander research team, enables technologies like Smart Reply to be used for any application, including third-party messaging apps, without ever having to connect to the cloud. So now you can respond to incoming chat messages directly from your watch, with a tap.
The research behind this began last year while our team was developing the machine learning systems that enable conversational understanding capability in Allo and Inbox. The Android Wear team reached out to us and wanted to know whether it would be possible to deploy this Smart Reply technology directly onto a smart device. Because of the limited computing power and memory on smart devices, we quickly realized that it was not possible to do so. Our product manager, Patrick McGregor, saw that this presented a unique challenge and an opportunity for the Expander team to return to the drawing board and design a completely new, lightweight machine learning architecture, not only to enable Smart Reply on Android Wear but also to power a wealth of other on-device mobile applications. Together with Tom Rudick, Nathan Beach, and other colleagues from the Android Wear team, we set out to build the new system.

Learning with Projections
A simple strategy to build lightweight conversational models might be to create a small dictionary of common rules (input → reply mappings) on the device and use a naive look-up strategy at inference time. This can work for simple prediction tasks involving a small set of classes using a handful of features (such as binary sentiment classification from text, e.g. “I love this movie” conveys a positive sentiment whereas the sentence “The acting was horrible” is negative). But it does not scale to complex natural language tasks involving rich vocabularies and the wide language variability observed in chat messages. On the other hand, machine learning models like recurrent neural networks (such as LSTMs), in conjunction with graph learning, have proven to be extremely powerful tools for complex sequence learning in natural language understanding tasks, including Smart Reply. However, compressing such rich models to fit in device memory and produce robust predictions at low computation cost (rapidly on-demand) is extremely challenging. Early experiments with restricting the model to predict only a small handful of replies or using other techniques like quantization or character-level models did not produce useful results.
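
To make the limitation of the look-up strategy concrete, here is a minimal sketch with a hypothetical two-entry rule table. It answers exact matches but returns nothing for a message that is only slightly reworded.

```python
# Hypothetical rule table: normalized input message -> canned reply.
RULES = {
    "how are you?": "Good, you?",
    "thanks": "You're welcome!",
}


def lookup_reply(message, rules=RULES):
    """Naive look-up: normalize case and whitespace, then require an exact match."""
    return rules.get(message.strip().lower())


print(lookup_reply("How are you?"))        # Good, you?
print(lookup_reply("How are you doing?"))  # None: no rule covers this variant
```

Covering every such variant would require a rule per phrasing, which is why the table approach breaks down on open-ended chat.
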

Instead, we built a different solution for the on-device ML system. We first use a fast, efficient mechanism to group similar incoming messages and project them to similar (“nearby”) bit vector representations. While there are several ways to perform this projection step, such as using word embeddings or encoder networks, we employ a modified version of locality sensitive hashing (LSH) to reduce dimension from millions of unique words to a short, fixed-length sequence of bits. This allows us to compute a projection for an incoming message very fast, on-the-fly, with a small memory footprint on the device since we do not need to store the incoming messages, word embeddings, or even the full model used for training.
Projection step: Similar messages are grouped together and projected to nearby vectors. For example, the messages "hey, how's it going?" and "How's it going buddy?" share similar content and might be projected to the same vector 11100011. Another related message “Howdy, everything going well?” is mapped to a nearby vector 11100110 that differs only in 2 bits.
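
The post does not describe how its LSH variant is modified, so the following is only a sketch of the plain idea, assuming signed random projections over word features. Each output bit is the sign of a sum of pseudo-random +1/-1 values, one per word, and those values are derived on the fly by hashing, so no vocabulary, embedding table or projection matrix needs to be stored. The feature choice and function names are hypothetical.

```python
import hashlib


def features(message):
    """Lowercased words with surrounding punctuation stripped (assumed features)."""
    words = [w.strip(".,!?'\"") for w in message.lower().split()]
    return [w for w in words if w]


def sign(feature, bit_index):
    """Deterministic pseudo-random +1 or -1 for a (feature, bit) pair."""
    digest = hashlib.md5(f"{bit_index}:{feature}".encode("utf-8")).digest()
    return 1 if digest[0] & 1 else -1


def project(message, num_bits=8):
    """Project a message to a fixed-length bit string."""
    feats = features(message)
    bits = []
    for i in range(num_bits):
        total = sum(sign(f, i) for f in feats)
        bits.append("1" if total >= 0 else "0")
    return "".join(bits)


def hamming(a, b):
    """Number of positions at which two equal-length bit strings differ."""
    return sum(x != y for x, y in zip(a, b))


print(project("hey, how's it going?"))
print(project("How's it going buddy?"))
print(hamming("11100011", "11100110"))  # 2, matching the caption's example
```

Messages that share most of their words agree on most of the per-bit sums, so they tend to land on nearby bit strings. The actual bits depend on the hash and will not match the illustrative vectors in the caption.
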
Next, our system takes the incoming message along with its projections and jointly trains a “message projection model” that learns to predict likely replies using our semi-supervised graph learning framework. The graph learning framework enables training a robust model by combining semantic relationships from multiple sources — message/reply interactions, word/phrase similarity, semantic cluster information — learning useful projection operations that can be mapped to good reply predictions.
Learning step: (Top) Messages along with projections and corresponding replies, if available, are used in a machine learning framework to jointly learn a “message projection model”. (Bottom) The message projection model learns to associate replies with the projections of the corresponding incoming messages. For example, the model projects two different messages “Howdy, everything going well?” and “How’s it going buddy?” (bottom center) to nearby bit vectors and learns to map these to relevant replies (bottom right).
It’s worth noting that while the message projection model can be trained using complex machine learning architectures and the power of the cloud, as described above, the model itself resides and performs inference completely on device. Apps running on the device can pass a user’s incoming messages and receive reply predictions from the on-device model without data leaving the device. The model can also be adapted to cater to the user’s writing style and individual preferences to provide a personalized experience.
Inference step: The model applies the learned projections to an incoming message (or sequence of messages) and suggests relevant and diverse replies. Inference is performed on the device, allowing the model to adapt to user data and personal writing styles.
To get the on-device system to work out of the box, we had to make a few additional improvements, such as speeding up computations on device and generating rich, diverse replies from the model. A forthcoming scientific publication will describe the on-device machine learning work in more detail.

Converse from Your Wrist
When we embarked on our journey to build this technology from scratch, we weren’t sure if the predictions would be useful or of sufficient quality. We’re quite surprised and excited about how well it works even on Android wearable devices with very limited computation and memory resources. We look forward to continuing to improve the models to provide users with more delightful conversational experiences, and we will be leveraging this on-device ML platform to enable completely new applications in the months to come.

You can now use this feature to respond to your messages directly from your Google watches or any watch that runs Android Wear 2.0. It is already enabled on Google Hangouts, Google Messenger, and many third-party messaging apps. We also provide an API for developers of third-party Wear apps.

Acknowledgements
On behalf of the Google Expander team, I would also like to thank the following people who helped make this technology a success: Andrei Broder, Andrew Tomkins, David Singleton, Mirko Ranieri, Robin Dua and Yicheng Fan.