Tag Archives: On-device Learning

On-Device, Real-Time Hand Tracking with MediaPipe



The ability to perceive the shape and motion of hands can be a vital component in improving the user experience across a variety of technological domains and platforms. For example, it can form the basis for sign language understanding and hand gesture control, and can also enable the overlay of digital content and information on top of the physical world in augmented reality. While coming naturally to people, robust real-time hand perception is a decidedly challenging computer vision task, as hands often occlude themselves or each other (e.g., finger/palm occlusions and handshakes) and lack high contrast patterns.

Today we are announcing the release of a new approach to hand perception, which we previewed at CVPR 2019 in June, implemented in MediaPipe, an open-source, cross-platform framework for building pipelines to process perceptual data of different modalities, such as video and audio. This approach provides high-fidelity hand and finger tracking by employing machine learning (ML) to infer 21 3D keypoints of a hand from just a single frame. Whereas current state-of-the-art approaches rely primarily on powerful desktop environments for inference, our method achieves real-time performance on a mobile phone, and even scales to multiple hands. We hope that providing this hand perception functionality to the wider research and development community will result in an emergence of creative use cases, stimulating new applications and new research avenues.
3D hand perception in real-time on a mobile phone via MediaPipe. Our solution uses machine learning to compute 21 3D keypoints of a hand from a video frame. Depth is indicated in grayscale.
An ML Pipeline for Hand Tracking and Gesture Recognition
Our hand tracking solution utilizes an ML pipeline consisting of several models working together:
  • A palm detector model (called BlazePalm) that operates on the full image and returns an oriented hand bounding box.
  • A hand landmark model that operates on the cropped image region defined by the palm detector and returns high fidelity 3D hand keypoints.
  • A gesture recognizer that classifies the previously computed keypoint configuration into a discrete set of gestures.
This architecture is similar to that employed by our recently published face mesh ML pipeline and that others have used for pose estimation. Providing the accurately cropped palm image to the hand landmark model drastically reduces the need for data augmentation (e.g. rotations, translation and scale) and instead allows the network to dedicate most of its capacity towards coordinate prediction accuracy.
Hand perception pipeline overview.
BlazePalm: Realtime Hand/Palm Detection
To detect initial hand locations, we employ a single-shot detector model called BlazePalm, optimized for mobile real-time uses in a manner similar to BlazeFace, which is also available in MediaPipe. Detecting hands is a decidedly complex task: our model has to work across a variety of hand sizes with a large scale span (~20x) relative to the image frame and be able to detect occluded and self-occluded hands. Whereas faces have high contrast patterns, e.g., in the eye and mouth region, the lack of such features in hands makes it comparatively difficult to detect them reliably from their visual features alone. Instead, providing additional context, like arm, body, or person features, aids accurate hand localization.

Our solution addresses the above challenges using several strategies. First, we train a palm detector instead of a hand detector, since estimating bounding boxes of rigid objects like palms and fists is significantly simpler than detecting hands with articulated fingers. In addition, because palms are smaller objects, the non-maximum suppression algorithm works well even in two-hand self-occlusion cases, such as handshakes. Moreover, palms can be modelled using only square bounding boxes (anchors in ML terminology), ignoring other aspect ratios, which reduces the number of anchors by a factor of 3-5. Second, an encoder-decoder feature extractor is used for greater scene-context awareness, even for small objects (similar to the RetinaNet approach). Lastly, we minimize the focal loss during training to support the large number of anchors resulting from the high scale variance.
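To make the role of the focal loss concrete, here is a minimal sketch of the binary focal loss as defined in the RetinaNet paper, applied per anchor and normalized by the number of positive anchors. The defaults alpha=0.25 and gamma=2.0 are the values commonly used with RetinaNet; the post does not say which values BlazePalm uses, so treat them as assumptions.

```python
import math
from typing import Sequence


def focal_loss(p: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Binary focal loss for a single anchor.

    p: predicted probability that the anchor contains a palm.
    y: 1 if the anchor is matched to a ground-truth palm, else 0.
    """
    eps = 1e-7
    p = min(max(p, eps), 1.0 - eps)  # clamp so log() never sees 0
    p_t = p if y == 1 else 1.0 - p  # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha
    # (1 - p_t) ** gamma shrinks the loss of well-classified (easy) anchors.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def total_focal_loss(probs: Sequence[float], labels: Sequence[int], **kwargs) -> float:
    """Sum over all anchors, normalized by the number of positive anchors."""
    num_pos = max(1, sum(labels))
    return sum(focal_loss(p, y, **kwargs) for p, y in zip(probs, labels)) / num_pos


def binary_cross_entropy(p: float, y: int) -> float:
    """Plain cross entropy, for comparison with the focal loss."""
    eps = 1e-7
    p = min(max(p, eps), 1.0 - eps)
    return -math.log(p if y == 1 else 1.0 - p)
```

For an easy background anchor (p = 0.01, y = 0) the focal loss is roughly four orders of magnitude smaller than the cross entropy, so the thousands of easy negative anchors produced by a dense, multi-scale anchor grid no longer dominate the gradient.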

With the above techniques, we achieve an average precision of 95.7% in palm detection. Using a regular cross entropy loss and no decoder gives a baseline of just 86.22%.

Hand Landmark Model
After palm detection over the whole image, our subsequent hand landmark model performs precise keypoint localization of 21 3D hand-knuckle coordinates inside the detected hand regions via regression, that is, direct coordinate prediction. The model learns a consistent internal hand pose representation and is robust even to partially visible hands and self-occlusions.

To obtain ground truth data, we manually annotated ~30K real-world images with 21 3D coordinates, as shown below (we take the Z-value from the image depth map, if one exists for the corresponding coordinate). To better cover the possible hand poses and provide additional supervision on the nature of hand geometry, we also render a high-quality synthetic hand model over various backgrounds and map it to the corresponding 3D coordinates.
Top: Aligned hand crops passed to the tracking network with ground truth annotation. Bottom: Rendered synthetic hand images with ground truth annotation
However, purely synthetic data poorly generalizes to the in-the-wild domain. To overcome this problem, we utilize a mixed training schema. A high-level model training diagram is presented in the following figure.
Mixed training schema for hand tracking network. Cropped real-world photos and rendered synthetic images are used as input to predict 21 3D keypoints.
The table below summarizes regression accuracy depending on the nature of the training data. Using both synthetic and real world data results in a significant performance boost.

Dataset                         Mean regression error (normalized by palm size)
Only real-world                 16.1%
Only rendered synthetic         25.7%
Mixed real-world + synthetic    13.4%
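The table reports keypoint error normalized by palm size, which makes the numbers comparable across hands of different apparent sizes. A minimal sketch of such a metric follows. The exact palm-size definition is not given in the post; `palm_size_from_landmarks` is a hypothetical choice (wrist to middle-finger base distance, assuming landmark 0 is the wrist and landmark 9 the middle-finger base).

```python
import math
from typing import Sequence, Tuple

Point3D = Tuple[float, float, float]


def normalized_mean_error(pred: Sequence[Point3D], gt: Sequence[Point3D], palm_size: float) -> float:
    """Mean Euclidean keypoint error divided by a per-hand palm size."""
    if len(pred) != len(gt) or not gt:
        raise ValueError("pred and gt must be non-empty and the same length")
    if palm_size <= 0:
        raise ValueError("palm_size must be positive")
    mean_err = sum(math.dist(p, g) for p, g in zip(pred, gt)) / len(gt)
    return mean_err / palm_size


def palm_size_from_landmarks(gt: Sequence[Point3D], wrist: int = 0, middle_base: int = 9) -> float:
    """Hypothetical palm-size definition: wrist to middle-finger base distance."""
    return math.dist(gt[wrist], gt[middle_base])
```

Averaging this value over a test set and multiplying by 100 would give a percentage in the style of the table above.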

Gesture Recognition
On top of the predicted hand skeleton, we apply a simple algorithm to derive the gestures. First, the state of each finger, e.g. bent or straight, is determined by the accumulated angles of joints. Then we map the set of finger states to a set of pre-defined gestures. This straightforward yet effective technique allows us to estimate basic static gestures with reasonable quality. The existing pipeline supports counting gestures from multiple cultures, e.g. American, European, and Chinese, and various hand signs including “Thumb up”, closed fist, “OK”, “Rock”, and “Spiderman”.
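The two-step procedure above (finger states from accumulated joint angles, then a lookup from finger states to gestures) can be sketched as follows. The landmark layout (0 = wrist, then four landmarks per finger from base to tip), the 60° straightness threshold, and the gesture table are all illustrative assumptions; the table ignores hand orientation and the cultural counting variants the real pipeline supports.

```python
import math
from typing import Dict, List, Sequence, Tuple

Point3D = Tuple[float, float, float]

# Assumed 21-landmark layout: index 0 is the wrist, then four landmarks per
# finger ordered from the base joint to the fingertip.
ORDER = ("thumb", "index", "middle", "ring", "pinky")
FINGERS: Dict[str, List[int]] = {
    name: [0] + [1 + 4 * k + j for j in range(4)] for k, name in enumerate(ORDER)
}


def _bend_deg(a: Point3D, b: Point3D, c: Point3D) -> float:
    """Bend at joint b: angle between segments a->b and b->c (0 = straight)."""
    u = [b[i] - a[i] for i in range(3)]
    v = [c[i] - b[i] for i in range(3)]
    norm = math.hypot(*u) * math.hypot(*v)
    if norm == 0.0:
        return 0.0
    cos = sum(x * y for x, y in zip(u, v)) / norm
    return math.degrees(math.acos(max(-1.0, min(1.0, cos))))


def finger_states(landmarks: Sequence[Point3D], max_bend_deg: float = 60.0) -> Dict[str, bool]:
    """True if a finger is straight, i.e. its accumulated joint bend is small."""
    states = {}
    for name, chain in FINGERS.items():
        pts = [landmarks[i] for i in chain]
        total = sum(_bend_deg(pts[j - 1], pts[j], pts[j + 1]) for j in range(1, len(pts) - 1))
        states[name] = total < max_bend_deg
    return states


# Hypothetical lookup table keyed by (thumb, index, middle, ring, pinky) straight flags.
GESTURES: Dict[Tuple[bool, ...], str] = {
    (False, False, False, False, False): "fist",
    (True, False, False, False, False): "thumb up",
    (False, True, False, False, True): "rock",
    (True, True, False, False, True): "spiderman",
}


def classify(landmarks: Sequence[Point3D]) -> str:
    states = finger_states(landmarks)
    key = tuple(states[name] for name in ORDER)
    # Fall back to counting straight fingers when no named gesture matches.
    return GESTURES.get(key, f"count {sum(key)}")
```

Keeping the classifier as a table over discrete finger states makes adding a new static gesture a one-line change, which matches the "simple algorithm" framing in the paragraph above.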

Implementation via MediaPipe
With MediaPipe, this perception pipeline can be built as a directed graph of modular components, called Calculators. MediaPipe comes with an extendable set of Calculators to solve tasks like model inference, media processing algorithms, and data transformations across a wide variety of devices and platforms. Individual calculators, such as cropping, rendering and neural network computations, can be performed exclusively on the GPU. For example, we employ TFLite GPU inference on most modern phones.

Our MediaPipe graph for hand tracking is shown below. The graph consists of two subgraphs: one for hand detection and one for hand keypoint (i.e., landmark) computation. One key optimization MediaPipe provides is that the palm detector is only run as necessary (fairly infrequently), saving significant computation time. We achieve this by inferring the hand location in subsequent video frames from the hand keypoints computed in the current frame, eliminating the need to run the palm detector on every frame. For robustness, the hand tracker model outputs an additional scalar capturing the confidence that a hand is present and reasonably aligned in the input crop. Only when the confidence falls below a certain threshold is the hand detection model reapplied to the whole frame.
The hand landmark model’s output (REJECT_HAND_FLAG) controls when the hand detection model is triggered. This behavior is achieved by MediaPipe’s powerful synchronization building blocks, resulting in high performance and optimal throughput of the ML pipeline.
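The detect-once, then-track control flow described above can be sketched in plain Python. This is not MediaPipe code: in MediaPipe the logic is expressed as a calculator graph with synchronization, and the real region of interest is rotated. Here `detect_palm` and `predict_landmarks` are hypothetical callables standing in for the two models, the 0.5 threshold is an assumption, and the next crop is a simple axis-aligned square around the predicted keypoints.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

Box = Tuple[float, float, float, float]  # (x_min, y_min, x_max, y_max)
Landmarks = List[Tuple[float, float, float]]


def box_from_landmarks(landmarks: Landmarks, scale: float = 1.5) -> Box:
    """Square crop around the keypoints, enlarged so the hand stays inside next frame."""
    xs = [p[0] for p in landmarks]
    ys = [p[1] for p in landmarks]
    cx, cy = (min(xs) + max(xs)) / 2, (min(ys) + max(ys)) / 2
    half = max(max(xs) - min(xs), max(ys) - min(ys)) * scale / 2
    return (cx - half, cy - half, cx + half, cy + half)


@dataclass
class HandTracker:
    detect_palm: Callable[[object], Optional[Box]]
    predict_landmarks: Callable[[object, Box], Tuple[Landmarks, float]]
    threshold: float = 0.5
    roi: Optional[Box] = None
    detector_calls: int = 0

    def process(self, frame: object) -> Optional[Landmarks]:
        if self.roi is None:
            # No valid crop from the previous frame: run the palm detector.
            self.detector_calls += 1
            self.roi = self.detect_palm(frame)
            if self.roi is None:
                return None
        landmarks, confidence = self.predict_landmarks(frame, self.roi)
        if confidence < self.threshold:
            self.roi = None  # tracking lost; re-detect on the next frame
            return None
        # Derive the next frame's crop from this frame's keypoints.
        self.roi = box_from_landmarks(landmarks)
        return landmarks
```

With a tracker that stays confident, the detector runs only on the first frame and again after each loss of tracking, which is the source of the savings the post describes.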
A highly efficient ML solution that runs in real time across a variety of platforms and form factors involves significantly more complexity than the above simplified description captures. To this end, we are open sourcing the above hand tracking and gesture recognition pipeline in the MediaPipe framework, accompanied by the relevant end-to-end usage scenario and source code, here. This provides researchers and developers with a complete stack for experimentation and prototyping of novel ideas based on our model.

Future Directions
We plan to extend this technology with more robust and stable tracking, increase the number of gestures we can reliably detect, and support dynamic gestures unfolding over time. We believe that publishing this technology can give an impulse to new creative ideas and applications by the members of the research and developer community at large. We are excited to see what you can build with it!
Acknowledgements
Special thanks to all our team members who worked on the tech with us: Andrey Vakunov, Andrei Tkachenka, Yury Kartynnik, Artsiom Ablavatski, Ivan Grishchenko, Kanstantsin Sokal‎, Mogan Shieh, Ming Guang Yong, Anastasia Tkach, Jonathan Taylor, Sean Fanello, Sofien Bouaziz, Juhyun Lee‎, Chris McClanahan, Jiuqiang Tang‎, Esha Uboweja‎, Hadon Nash‎, Camillo Lugaresi, Michael Hays, Chuo-Ling Chang, Matsvei Zhdanovich and Matthias Grundmann.

Source: Google AI Blog


Custom On-Device ML Models with Learn2Compress



Successful deep learning models often require significant amounts of computational resources, memory and power to train and run, which presents an obstacle if you want them to perform well on mobile and IoT devices. On-device machine learning allows you to run inference directly on the devices, with the benefits of data privacy and access everywhere, regardless of connectivity. On-device ML systems, such as MobileNets and ProjectionNets, address the resource bottlenecks on mobile devices by optimizing for model efficiency. But what if you wanted to train your own customized, on-device models for your personal mobile application?

Yesterday at Google I/O, we announced ML Kit to make machine learning accessible for all mobile developers. One of the core ML Kit capabilities that will be available soon is an automatic model compression service powered by “Learn2Compress” technology developed by our research team. Learn2Compress enables custom on-device deep learning models in TensorFlow Lite that run efficiently on mobile devices, without developers having to worry about optimizing for memory and speed. We are pleased to make Learn2Compress for image classification available soon through ML Kit. Learn2Compress will be initially available to a small number of developers, and will be offered more broadly in the coming months. You can sign up here if you are interested in using this feature for building your own models.

How it Works
Learn2Compress generalizes the learning framework introduced in previous works like ProjectionNet and incorporates several state-of-the-art techniques for compressing neural network models. It takes as input a large pre-trained TensorFlow model provided by the user, performs training and optimization and automatically generates ready-to-use on-device models that are smaller in size, more memory-efficient, more power-efficient and faster at inference with minimal loss in accuracy.
Learn2Compress for automatically generating on-device ML models.
To do this, Learn2Compress uses multiple neural network optimization and compression techniques including:
  • Pruning reduces model size by removing weights or operations that are least useful for predictions (e.g., low-scoring weights). This can be very effective, especially for on-device models involving sparse inputs or outputs, which can be reduced up to 2x in size while retaining 97% of the original prediction quality.
  • Quantization techniques are particularly effective when applied during training and can improve inference speed by reducing the number of bits used for model weights and activations. For example, using 8-bit fixed point representation instead of floats can speed up the model inference, reduce power and further reduce size by 4x.
  • Joint training and distillation approaches follow a teacher-student learning strategy — we use a larger teacher network (in this case, user-provided TensorFlow model) to train a compact student network (on-device model) with minimal loss in accuracy.
    Joint training and distillation approach to learn compact student models.
    The teacher network can be fixed (as in distillation) or jointly optimized, and multiple student models of different sizes can even be trained simultaneously. So instead of a single model, Learn2Compress generates multiple on-device models in a single shot, at different sizes and inference speeds, and lets the developer pick the one best suited to their application needs.
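The teacher-student strategy in the last bullet is commonly implemented with a distillation loss that mixes the usual hard-label loss with a term pulling the student's temperature-softened predictions toward the teacher's. The sketch below follows the widely used Hinton et al. formulation; Learn2Compress's actual objective is not described in the post, and the temperature and mixing weight are assumptions. Joint training would additionally update the teacher, which is not shown.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax with an optional temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits: Sequence[float], teacher_logits: Sequence[float],
                      label: int, temperature: float = 4.0, alpha: float = 0.5) -> float:
    """alpha * hard-label cross entropy + (1 - alpha) * T^2 * KL(teacher_T || student_T)."""
    hard = -math.log(softmax(student_logits)[label])
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(t * (math.log(t) - math.log(s)) for t, s in zip(p_teacher, p_student) if t > 0)
    # The T^2 factor keeps the soft-target gradients on a comparable scale.
    return alpha * hard + (1.0 - alpha) * temperature ** 2 * kl
```

When the student already matches the teacher, the soft term vanishes and only the hard-label loss remains.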
These and other techniques like transfer learning also make the compression process more efficient and scalable to large-scale datasets.
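Returning to the first two techniques in the list above, pruning and quantization can be illustrated on a flat list of weights. The sketch uses magnitude pruning (zeroing the smallest-magnitude weights) and the affine 8-bit mapping w ≈ scale * (q - zero_point) used by common 8-bit schemes. These are generic stand-ins chosen for illustration; the post does not specify which pruning criterion or quantization scheme Learn2Compress applies.

```python
from typing import List, Sequence, Tuple


def magnitude_prune(weights: Sequence[float], sparsity: float) -> List[float]:
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    k = int(len(weights) * sparsity)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    drop = set(order[:k])
    return [0.0 if i in drop else w for i, w in enumerate(weights)]


def quantize_uint8(weights: Sequence[float]) -> Tuple[List[int], float, int]:
    """Affine 8-bit quantization: w is approximated by scale * (q - zero_point)."""
    lo = min(min(weights), 0.0)  # the range must contain 0 so it is exactly representable
    hi = max(max(weights), 0.0)
    scale = (hi - lo) / 255.0 or 1.0  # avoid a zero scale for all-zero weights
    zero_point = round(-lo / scale)
    q = [max(0, min(255, round(w / scale) + zero_point)) for w in weights]
    return q, scale, zero_point


def dequantize(q: Sequence[int], scale: float, zero_point: int) -> List[float]:
    return [scale * (v - zero_point) for v in q]
```

Each quantized weight occupies one byte instead of four for a 32-bit float, which is where the 4x size reduction mentioned above comes from; the rounding error per weight is at most half a quantization step.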

How well does it work?
To demonstrate the effectiveness of Learn2Compress, we used it to build compact on-device models of several state-of-the-art deep networks used in image and natural language tasks such as MobileNets, NASNet, Inception, ProjectionNet, among others. For a given task and dataset, we can generate multiple on-device models at different inference speeds and model sizes.
Accuracy at various sizes for Learn2Compress models and full-sized baseline networks on CIFAR-10 (left) and ImageNet (right) image classification tasks. Student networks used to produce the compressed variants for CIFAR-10 and ImageNet are modeled using NASNet and MobileNet-inspired architectures, respectively.
For image classification, Learn2Compress can generate small and fast models with good prediction accuracy suited for mobile applications. For example, on the ImageNet task, Learn2Compress achieves a model 22x smaller than the Inception v3 baseline and 4x smaller than the MobileNet v1 baseline with just a 4.6-7% drop in accuracy. On CIFAR-10, jointly training multiple Learn2Compress models with shared parameters takes only 10% more time than training a single large Learn2Compress model, but yields 3 compressed models that are up to 94x smaller in size and up to 27x faster, with up to 36x lower cost and good prediction quality (90-95% top-1 accuracy).
Computation cost and average prediction latency (on Pixel phone) for baseline and Learn2Compress models on CIFAR-10 image classification task. Learn2Compress-optimized models use NASNet-style network architecture.
We are also excited to see how well this performs on developer use-cases. For example, Fishbrain, a social platform for fishing enthusiasts, used Learn2Compress to compress their existing image classification cloud model (80MB+ in size and 91.8% top-3 accuracy) to a much smaller on-device model, less than 5MB in size, with similar accuracy. In some cases, we observe that it is possible for the compressed models to even slightly outperform the original large model’s accuracy due to better regularization effects.

We will continue to improve Learn2Compress with future advances in ML and deep learning, and extend it to more use cases beyond image classification. We are excited and look forward to making this available soon through ML Kit’s compression service on the Cloud. We hope this will make it easy for developers to automatically build and optimize their own on-device ML models so that they can focus on building great apps and cool user experiences involving computer vision, natural language and other machine learning applications.

Acknowledgments
I would like to acknowledge our core contributors Gaurav Menghani, Prabhu Kaliamoorthi and Yicheng Fan along with Wei Chai, Kang Lee, Sheng Xu and Pannag Sanketi. Special thanks to Dave Burke, Brahim Elbouchikhi, Hrishikesh Aradhye, Hugues Vincent, and Arun Venkatesan from the Android team; Sachin Kotwani, Wesley Tarle, Pavel Jbanov and from the Firebase team; Andrei Broder, Andrew Tomkins, Robin Dua, Patrick McGregor, Gaurav Nemade, the Google Expander team and TensorFlow team.


Source: Google AI Blog


Introducing the CVPR 2018 On-Device Visual Intelligence Challenge



Over the past year, there have been exciting innovations in the design of deep networks for vision applications on mobile devices, such as the MobileNet model family and integer quantization. Many of these innovations have been driven by performance metrics that focus on meaningful user experiences in real-world mobile applications, requiring inference to be both low-latency and accurate. While the accuracy of a deep network model can be conveniently estimated with well established benchmarks in the computer vision community, latency is surprisingly difficult to measure and no uniform metric has been established. This lack of measurement platforms and uniform metrics has hampered the development of performant mobile applications.

Today, we are happy to announce the On-device Visual Intelligence Challenge (OVIC), part of the Low-Power Image Recognition Challenge Workshop at the 2018 Computer Vision and Pattern Recognition conference (CVPR2018). A collaboration with Purdue University, the University of North Carolina and IEEE, OVIC is a public competition for real-time image classification that uses state-of-the-art Google technology to significantly lower the barrier to entry for mobile development. OVIC provides two key features to catalyze innovation: a unified latency metric and an evaluation platform.

A Unified Metric
OVIC focuses on the establishment of a unified metric aligned directly with accurate and performant operation on mobile devices. The metric is defined as the number of correct classifications within a specified per-image average time limit of 33ms. This latency limit allows every frame in a live 30 frames-per-second video to be processed, thus providing a seamless user experience (see footnote 1). Prior to OVIC, it was tricky to enforce such a limit due to the difficulty in accurately and uniformly measuring latency as would be experienced in real-world applications on real-world devices. Without a repeatable mobile development platform, researchers have relied primarily on approximate metrics for latency that are convenient to compute, such as the number of multiply-accumulate operations (MACs). The intuition is that multiply-accumulate operations constitute the most time-consuming part of a deep neural network, so their count should be indicative of overall latency. However, these metrics are often poor predictors of on-device latency, because many aspects of a model can affect the average latency of each MAC in typical implementations.
Even though the number of multiply-accumulate operations (# MACs) is the most commonly used metric to approximate on-device latency, it is a poor predictor of latency. Using data from various quantized and floating point MobileNet V1 and V2 based models, this graph plots on-device latency on a common reference device versus the number of MACs. It is clear that models with similar latency can have very different MACs, and vice versa.
The graph above shows that while the number of MACs is correlated with inference latency, there is significant variation in the mapping. Thus, the number of MACs is a poor proxy for latency, and since latency directly affects users’ experiences, we believe it is paramount to optimize latency directly rather than limiting the number of MACs as a proxy.
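One plausible reading of the unified metric can be written down directly: a budget of 1000 ms / 30 frames ≈ 33.3 ms per image, and the score is the number of correct classifications, provided the average per-image latency stays within the 33 ms limit. How the official evaluation treats submissions that exceed the budget is not stated here; scoring them as zero is our assumption, as are the function and parameter names.

```python
from typing import Sequence

FRAME_BUDGET_MS = 1000.0 / 30.0  # time available per frame in a 30 fps stream


def ovic_style_score(correct: Sequence[bool], latencies_ms: Sequence[float],
                     limit_ms: float = 33.0) -> int:
    """Number of correct classifications, counted only if the mean latency is within the limit.

    Assumption: a model whose average per-image latency exceeds the limit scores 0.
    """
    if len(correct) != len(latencies_ms) or not correct:
        raise ValueError("need one latency per prediction")
    mean_latency = sum(latencies_ms) / len(latencies_ms)
    if mean_latency > limit_ms:
        return 0
    return sum(correct)
```

Because the limit applies to the average, a model can afford an occasional slow image as long as the overall mean stays under budget.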

An Evaluation Platform
As mentioned above, a primary issue with latency is that it has previously been challenging to measure reliably and repeatably, due to variations in implementation, running environment and hardware architectures. Recent successes in mobile development have overcome these challenges with the help of a convenient mobile development platform, including optimized kernels for mobile CPUs, light-weight portable model formats, increasingly capable mobile devices, and more. However, these various platforms have traditionally required resources and development capabilities that are only available to larger universities and industry.

With that in mind, we are releasing OVIC’s evaluation platform, which includes a number of components designed to make mobile development, along with replicable and comparable evaluation, accessible to the broader research community:
  • TOCO compiler for optimizing TensorFlow models for efficient inference
  • TensorFlow Lite inference engine for mobile deployment
  • A benchmarking SDK that can be run locally on any Android phone
  • Sample models to showcase successful mobile architectures that run inference in floating-point and quantized modes
  • Google’s benchmarking tool for reliable latency measurements on specific Pixel phones (available to registered contestants).
Using these tools available in OVIC, a participant can conveniently incorporate measurement of on-device latency into their design loop without having to worry about optimizing kernels, purchasing latency/power measurement devices, or designing the framework to drive them. The only requirement for entry is experience with training computer vision models in TensorFlow, which can be gained from this tutorial.
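Measuring latency inside a design loop usually comes down to a small harness: a few warm-up runs (so caches and lazy initialization do not skew the numbers), then many timed runs summarized by mean and median. The sketch below is a generic host-side harness, not the OVIC SDK or Google’s Pixel benchmarking tool; `run_inference` is a hypothetical zero-argument callable that performs one model invocation, and on-device numbers will also depend on factors this cannot control, such as thermal state and CPU core assignment.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(run_inference: Callable[[], object], warmup: int = 5, runs: int = 50) -> Dict[str, float]:
    """Time repeated calls to `run_inference` and summarize the latencies in milliseconds."""
    if runs < 1:
        raise ValueError("runs must be at least 1")
    for _ in range(warmup):
        run_inference()  # warm-up results are discarded
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        run_inference()
        samples.append((time.perf_counter() - start) * 1000.0)
    return {
        "mean_ms": statistics.mean(samples),
        "median_ms": statistics.median(samples),
        "max_ms": max(samples),
    }
```

The mean is what a per-image average time limit is stated in terms of, while the median and max help spot outliers caused by background activity.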

With OVIC, we encourage the entire research community to improve the classification performance of low-latency high-accuracy models towards new frontiers, as shown in the following graphic.
Sampling of current MobileNet mobile models illustrating the tradeoff between increased accuracy and reduced latency.
We cordially invite you to participate here before the deadline on June 15th, and help us discover new mobile vision architectures that will propel development into the future.

Acknowledgements
We would like to acknowledge our core contributors Achille Brighton, Alec Go, Andrew Howard, Hartwig Adam, Mark Sandler and Xiao Zhang. We would also like to acknowledge our external collaborators Alex Berg and Yung-Hsiang Lu. We give special thanks to Andre Hentz, Andrew Selle, Benoit Jacob, Brad Krueger, Dmitry Kalenichenko, Megan Cummins, Pete Warden, Rajat Monga, Shiyu Hu and Yicheng Fan.


1 Alternatively the same metric could encourage even lower power operation by only processing a subset of the images in the input stream.



Source: Google AI Blog



On-Device Conversational Modeling with TensorFlow Lite



Earlier this year, we launched Android Wear 2.0, which featured the first "on-device" machine learning technology for smart messaging. This enabled cloud-based technologies like Smart Reply, previously available in Gmail, Inbox and Allo, to be used directly within any application for the first time, including third-party messaging apps, without ever having to connect to the cloud. This means you can respond to incoming chat messages on the go, directly from your smartwatch.

Today, we announce TensorFlow Lite, TensorFlow’s lightweight solution for mobile and embedded devices. This framework is optimized for low-latency inference of machine learning models, with a focus on small memory footprint and fast performance. As part of the library, we have also released an on-device conversational model and a demo app that provides an example of a natural language application powered by TensorFlow Lite, in order to make it easier for developers and researchers to build new machine intelligence features powered by on-device inference. This model generates reply suggestions to input conversational chat messages, with efficient inference that can be easily plugged into your chat application to power on-device conversational intelligence.

The on-device conversational model we have released uses a new ML architecture for training compact neural networks (as well as other machine learning models) based on a joint optimization framework, originally presented in ProjectionNet: Learning Efficient On-Device Deep Networks Using Neural Projections. This architecture can run efficiently on mobile devices with limited computing power and memory by using efficient “projection” operations that transform any input into a compact bit vector representation: similar inputs are projected to nearby vectors that are dense or sparse depending on the type of projection. For example, the messages “hey, how's it going?” and “How's it going buddy?” might be projected to the same vector representation.
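The paragraph above does not spell out the projection function. One classic family with the "similar inputs land on nearby bit vectors" property is locality-sensitive hashing; the sketch below uses random-hyperplane LSH over character trigram counts, with each ±1 hyperplane weight recomputed from a hash so no projection matrix has to be stored. The trigram features, the hash-derived weights, and all function names are our assumptions for illustration, not the released model’s implementation.

```python
import hashlib
from typing import Dict, List


def char_ngrams(text: str, n: int = 3) -> Dict[str, int]:
    """Counts of lowercase character n-grams, padded with a space at both ends."""
    padded = f" {text.lower()} "
    counts: Dict[str, int] = {}
    for i in range(len(padded) - n + 1):
        gram = padded[i : i + n]
        counts[gram] = counts.get(gram, 0) + 1
    return counts


def _plane_sign(feature: str, bit: int, seed: int) -> int:
    """Pseudo-random +1/-1 hyperplane weight, recomputed on the fly from a hash."""
    digest = hashlib.blake2b(f"{seed}:{bit}:{feature}".encode("utf-8"), digest_size=1).digest()
    return 1 if digest[0] & 1 else -1


def project(text: str, num_bits: int = 64, seed: int = 0) -> List[int]:
    """Random-hyperplane LSH: each bit is the sign of the features' dot product with one hyperplane."""
    feats = char_ngrams(text)
    bits = []
    for b in range(num_bits):
        dot = sum(count * _plane_sign(f, b, seed) for f, count in feats.items())
        bits.append(1 if dot > 0 else 0)
    return bits


def hamming(a: List[int], b: List[int]) -> int:
    """Number of differing bits; small distances indicate similar inputs."""
    return sum(x != y for x, y in zip(a, b))
```

With this kind of projection, the expected fraction of differing bits grows with the angle between the two inputs’ feature vectors, so messages sharing many trigrams tend to end up only a few bits apart.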

Using this idea, the conversational model combines these efficient operations at low computation and memory footprint. We trained this on-device model end-to-end using an ML framework that jointly trains two types of models — a compact projection model (as described above) combined with a trainer model. The two models are trained in a joint fashion, where the projection model learns from the trainer model — the trainer is characteristic of an expert and modeled using larger and more complex ML architectures, whereas the projection model resembles a student that learns from the expert. During training, we can also stack other techniques such as quantization or distillation to achieve further compression or selectively optimize certain portions of the objective function. Once trained, the smaller projection model is able to be used directly for inference on device.
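The joint training described above can be summarized as a single objective. To our understanding, the ProjectionNet paper writes it roughly as a weighted sum of three terms, where $h_\theta$ is the trainer model, $h^{p}$ the projection model, $D$ a per-example loss such as cross entropy, $\hat{y}_i$ the ground-truth label, and $\lambda_1, \lambda_2, \lambda_3$ weighting hyperparameters. Treat the notation as a sketch of the structure and consult the paper for the exact definitions.

```latex
\mathcal{L}(\theta, p) = \lambda_1 \, \mathcal{L}_{\theta} + \lambda_2 \, \mathcal{L}^{p} + \lambda_3 \, \hat{\mathcal{L}}^{p},
\quad\text{where}\quad
\begin{aligned}
\mathcal{L}_{\theta}       &= \textstyle\sum_i D\big(h_{\theta}(x_i), \hat{y}_i\big) && \text{(trainer vs. ground truth)}\\
\mathcal{L}^{p}            &= \textstyle\sum_i D\big(h^{p}(x_i), h_{\theta}(x_i)\big) && \text{(projection model learns from trainer)}\\
\hat{\mathcal{L}}^{p}      &= \textstyle\sum_i D\big(h^{p}(x_i), \hat{y}_i\big) && \text{(projection model vs. ground truth)}
\end{aligned}
```

The middle term is what makes the projection model a "student" of the trainer; setting its weight to zero reduces the setup to training the two models independently.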
For inference, the trained projection model is compiled into a set of TensorFlow Lite operations that have been optimized for fast execution on mobile platforms and executed directly on device. The TensorFlow Lite inference graph for the on-device conversational model is shown here.
TensorFlow Lite execution for the On-Device Conversational Model.
The open-source conversational model released today (along with code) was trained end-to-end using the joint ML architecture described above. Today’s release also includes a demo app, so you can easily download and try out one-touch smart replies on your mobile device. The architecture enables easy configuration of model size and prediction quality based on application needs. You can find a list of sample messages where this model does well here. The system can also fall back to suggesting replies from a fixed set that was learned and compiled from popular response intents observed in chat conversations. The underlying model is different from the ones Google uses for Smart Reply responses in its apps (see footnote 1).

Beyond Conversational Models
Interestingly, the ML architecture described above permits flexible choices for the underlying model. We also designed the architecture to be compatible with different machine learning approaches — for example, when used with TensorFlow deep learning, we learn a lightweight neural network (ProjectionNet) for the underlying model, whereas a different architecture (ProjectionGraph) represents the model using a graph framework instead of a neural network.

The joint framework can also be used to train lightweight on-device models for other tasks using different ML modeling architectures. As an example, we derived a ProjectionNet architecture that uses a complex feed-forward or recurrent architecture (such as an LSTM) for the trainer model, coupled with a simple projection architecture composed of dynamic projection operations and a few narrow fully-connected layers. The whole architecture is trained end-to-end using backpropagation in TensorFlow, and once trained, the compact ProjectionNet is used directly for inference. Using this method, we have successfully trained tiny ProjectionNet models that achieve significant reductions in model size (up to several orders of magnitude) while achieving high accuracy on multiple visual and language classification tasks (a few examples here). Similarly, we trained other lightweight models using our graph learning framework, even in semi-supervised settings.
ML architecture for training on-device models: ProjectionNet trained using deep learning (left), and ProjectionGraph trained using graph learning (right).
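To illustrate the ProjectionNet side of this setup, the sketch below runs a forward pass of a tiny projection model. A bit vector, such as the output of a projection step, passes through one narrow hidden layer and then an output layer. The layer sizes, the random initialization, and the name TinyProjectionNet are assumptions made for illustration, and training is omitted. Counting parameters shows why such models are small. With 16 input bits, 8 hidden units and 4 output classes, there are 16 x 8 + 8 + 8 x 4 + 4 = 172 weights.

```python
import random


def dense(inputs, weights, biases):
    """Fully connected layer: one weight row per output unit."""
    return [sum(w * x for w, x in zip(row, inputs)) + b for row, b in zip(weights, biases)]


def relu(values):
    return [max(0.0, v) for v in values]


def init_layer(n_in, n_out, rng):
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


class TinyProjectionNet:
    """Hypothetical model: projection bits -> narrow hidden layer -> class scores."""

    def __init__(self, num_bits=16, hidden=8, num_classes=4, seed=0):
        rng = random.Random(seed)
        self.hidden_layer = init_layer(num_bits, hidden, rng)
        self.output_layer = init_layer(hidden, num_classes, rng)

    def logits(self, bits):
        x = [1.0 if c == "1" else 0.0 for c in bits]
        h = relu(dense(x, *self.hidden_layer))
        return dense(h, *self.output_layer)

    def num_params(self):
        return sum(len(w) * len(w[0]) + len(b) for w, b in (self.hidden_layer, self.output_layer))
```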
We will continue to improve and release updated TensorFlow Lite models in open source. We think that the released model (as well as future models) learned using these ML architectures may be reused for many natural language and computer vision applications or plugged into existing apps to enable machine intelligence. We hope that the machine learning and natural language processing communities will be able to build on these to address new problems and use cases we have not yet conceived.

Acknowledgments
Yicheng Fan and Gaurav Nemade contributed immensely to this effort. Special thanks to Rajat Monga, Andre Hentz, Andrew Selle, Sarah Sirajuddin, and Anitha Vijayakumar from the TensorFlow team; Robin Dua, Patrick McGregor, Andrei Broder, Andrew Tomkins and the Google Expander team.



1 The released on-device model was trained to optimize for small size and low-latency applications on mobile phones and wearables. Smart Reply predictions in Google apps, however, are generated using larger, more complex models. In production systems, we also use multiple classifiers that are trained to detect inappropriate content and apply further filtering and tuning to optimize user experience and quality levels. We recommend that developers using the open-source TensorFlow Lite version also follow such practices for their end applications.

Federated Learning: Collaborative Machine Learning without Centralized Training Data



Standard machine learning approaches require centralizing the training data on one machine or in a datacenter. And Google has built one of the most secure and robust cloud infrastructures for processing this data to make our services better. Now for models trained from user interaction with mobile devices, we're introducing an additional approach: Federated Learning.

Federated Learning enables mobile phones to collaboratively learn a shared prediction model while keeping all the training data on device, decoupling the ability to do machine learning from the need to store the data in the cloud. This goes beyond the use of local models that make predictions on mobile devices (like the Mobile Vision API and On-Device Smart Reply) by bringing model training to the device as well.

It works like this: your device downloads the current model, improves it by learning from data on your phone, and then summarizes the changes as a small focused update. Only this update to the model is sent to the cloud, using encrypted communication, where it is immediately averaged with other user updates to improve the shared model. All the training data remains on your device, and no individual updates are stored in the cloud.
Your phone personalizes the model locally, based on your usage (A). Many users' updates are aggregated (B) to form a consensus change (C) to the shared model, after which the procedure is repeated.
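The round described above can be sketched in a few lines of Python, in the spirit of Federated Averaging. The sketch makes several simplifying assumptions. The model is a toy linear regression, each device trains with plain local SGD, and the server averages the updates weighted by each device's number of examples. The function names, learning rate, epoch count and synthetic data are hypothetical. Encryption and compression are left out.

```python
import random


def local_update(weights, data, lr=0.1, epochs=5):
    """Train a linear model y ~ w.x on one device's data; return (delta, num_examples).

    Only the delta leaves the device; the (x, y) pairs never do.
    """
    w = list(weights)
    for _ in range(epochs):
        for x, y in data:
            err = sum(wi * xi for wi, xi in zip(w, x)) - y
            w = [wi - lr * err * xi for wi, xi in zip(w, x)]
    delta = [new - old for new, old in zip(w, weights)]
    return delta, len(data)


def federated_round(global_weights, client_datasets):
    """One round: every device trains locally, then the server averages the deltas,
    weighting each device by how many examples it trained on."""
    updates = [local_update(global_weights, data) for data in client_datasets]
    total = sum(n for _, n in updates)
    avg_delta = [
        sum(delta[i] * n for delta, n in updates) / total
        for i in range(len(global_weights))
    ]
    return [w + d for w, d in zip(global_weights, avg_delta)]


def make_clients(sizes, seed=0):
    """Synthetic devices whose data follow y = 2*x0 - x1 (illustrative only)."""
    rng = random.Random(seed)
    clients = []
    for n in sizes:
        xs = [[rng.uniform(-1, 1), rng.uniform(-1, 1)] for _ in range(n)]
        clients.append([(x, 2 * x[0] - x[1]) for x in xs])
    return clients


if __name__ == "__main__":
    clients = make_clients([10, 30, 60])  # devices hold different amounts of data
    weights = [0.0, 0.0]
    for _ in range(20):
        weights = federated_round(weights, clients)
    print([round(v, 3) for v in weights])
```

Each device takes several local steps before it reports an update. As a result, far fewer rounds of communication are needed than if every gradient step were sent to the server.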
Federated Learning allows for smarter models, lower latency, and less power consumption, all while ensuring privacy. And this approach has another immediate benefit: in addition to providing an update to the shared model, the improved model on your phone can also be used immediately, powering experiences personalized by the way you use your phone.

We're currently testing Federated Learning in Gboard on Android, the Google Keyboard. When Gboard shows a suggested query, your phone locally stores information about the current context and whether you clicked the suggestion. Federated Learning processes that history on-device to suggest improvements to the next iteration of Gboard’s query suggestion model.
To make Federated Learning possible, we had to overcome many algorithmic and technical challenges. In a typical machine learning system, an optimization algorithm like Stochastic Gradient Descent (SGD) runs on a large dataset partitioned homogeneously across servers in the cloud. Such highly iterative algorithms require low-latency, high-throughput connections to the training data. But in the Federated Learning setting, the data is distributed across millions of devices in a highly uneven fashion. In addition, these devices have significantly higher-latency, lower-throughput connections and are only intermittently available for training.

These bandwidth and latency limitations motivate our Federated Averaging algorithm, which can train deep networks using 10-100x less communication compared to a naively federated version of SGD. The key idea is to use the powerful processors in modern mobile devices to compute higher-quality updates than simple gradient steps. Since it takes fewer iterations of high-quality updates to produce a good model, training can use much less communication. As upload speeds are typically much slower than download speeds, we also developed a novel way to cut upload communication costs by up to another 100x, compressing updates using random rotations and quantization. While these approaches are focused on training deep networks, we've also designed algorithms for high-dimensional sparse convex models that excel on problems like click-through-rate prediction.
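One published way to combine random rotations with quantization has three steps. First, multiply the update by random signs. Second, apply a normalized Walsh-Hadamard transform, which spreads the update's values evenly across coordinates. Third, stochastically round each coordinate to one of two levels so that the result is unbiased. The sketch below assumes this scheme and a power-of-two update length. It also assumes a shared seed, which lets the server regenerate the random signs. It is an illustration, not Gboard's production encoder. Each vector uploads one bit per coordinate plus two floats, roughly 32x smaller than sending 32-bit floats.

```python
import math
import random


def fwht(vec):
    """Normalized fast Walsh-Hadamard transform (orthonormal and its own inverse)."""
    a = list(vec)
    n = len(a)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for j in range(start, start + h):
                x, y = a[j], a[j + h]
                a[j], a[j + h] = x + y, x - y
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in a]


def random_signs(n, seed):
    """Random +/-1 diagonal; the same seed reproduces the same signs on the server."""
    rng = random.Random(seed)
    return [rng.choice((-1.0, 1.0)) for _ in range(n)]


def encode(update, seed, rng):
    """Rotate (random signs, then Hadamard) and quantize each coordinate to 1 bit."""
    signs = random_signs(len(update), seed)
    rotated = fwht([s * u for s, u in zip(signs, update)])
    lo, hi = min(rotated), max(rotated)
    if hi == lo:
        return lo, hi, [0] * len(rotated)
    # Pick hi with probability (v - lo) / (hi - lo), so the dequantized value equals v in expectation.
    bits = [1 if rng.random() < (v - lo) / (hi - lo) else 0 for v in rotated]
    return lo, hi, bits


def decode(lo, hi, bits, seed):
    """Dequantize, then undo the rotation using the same seed."""
    signs = random_signs(len(bits), seed)
    rotated = [hi if b else lo for b in bits]
    return [s * v for s, v in zip(signs, fwht(rotated))]
```

A single decoded update is noisy. The noise is unbiased, however, so it tends to cancel when the server averages many updates.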

Deploying this technology to millions of heterogeneous phones running Gboard requires a sophisticated technology stack. On-device training uses a miniature version of TensorFlow. Careful scheduling ensures training happens only when the device is idle, plugged in, and on a free wireless connection, so there is no impact on the phone's performance.
Your phone participates in Federated Learning only when it won't negatively impact your experience.
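The scheduling rule above can be written as an eligibility check that runs before and during local training. In this hypothetical sketch, DeviceState, eligible_for_training and run_local_training are illustrative names. Stopping at the first failed check is an assumed policy, not documented behavior of the Gboard implementation.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class DeviceState:
    idle: bool
    charging: bool
    unmetered_wifi: bool


def eligible_for_training(state: DeviceState) -> bool:
    """Train only when the device is idle, plugged in, and on a free wireless connection."""
    return state.idle and state.charging and state.unmetered_wifi


def run_local_training(max_steps: int,
                       get_state: Callable[[], DeviceState],
                       train_step: Callable[[], None]) -> int:
    """Run up to max_steps training steps, re-checking eligibility before each one.

    Returns the number of steps completed; training stops as soon as the user
    picks up the phone, unplugs it, or leaves the unmetered network.
    """
    completed = 0
    for _ in range(max_steps):
        if not eligible_for_training(get_state()):
            break
        train_step()
        completed += 1
    return completed
```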
The system then needs to communicate and aggregate the model updates in a secure, efficient, scalable, and fault-tolerant way. It's only the combination of research with this infrastructure that makes the benefits of Federated Learning possible.

Federated Learning works without the need to store user data in the cloud, but we're not stopping there. We've developed a Secure Aggregation protocol that uses cryptographic techniques so a coordinating server can only decrypt the average update if hundreds or thousands of users have participated; no individual phone's update can be inspected before averaging. It's the first protocol of its kind that is practical for deep-network-sized problems and real-world connectivity constraints. We designed Federated Averaging so the coordinating server only needs the average update, which allows Secure Aggregation to be used; however, the protocol is general and can be applied to other problems as well. We're working hard on a production implementation of this protocol and expect to deploy it for Federated Learning applications in the near future.
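Pairwise masking is one way to let a server learn only the sum of the updates. Every pair of users shares a random mask that one of them adds and the other subtracts. Each masked update looks random on its own, but the masks cancel in the total. The sketch below is a simplified illustration under several assumptions. Updates are already encoded as integers modulo 2^32, and a shared seed stands in for the key agreement each pair would actually perform. No user drops out; a full protocol needs extra machinery, such as secret sharing, to tolerate dropouts. All names are hypothetical.

```python
import random

MODULUS = 2 ** 32  # assumption: updates are fixed-point encoded as integers mod 2^32


def pairwise_mask(u, v, dim, session_seed):
    """Mask shared by users u and v. A shared seed stands in for a real key agreement."""
    rng = random.Random(f"{session_seed}:{min(u, v)}:{max(u, v)}")
    return [rng.randrange(MODULUS) for _ in range(dim)]


def mask_update(user, all_users, update, session_seed):
    """Add masks for higher-numbered peers and subtract masks for lower-numbered ones."""
    masked = [x % MODULUS for x in update]
    for other in all_users:
        if other == user:
            continue
        mask = pairwise_mask(user, other, len(update), session_seed)
        sign = 1 if user < other else -1
        masked = [(m + sign * r) % MODULUS for m, r in zip(masked, mask)]
    return masked


def aggregate(masked_updates):
    """The server sums masked vectors; every pairwise mask cancels in the total."""
    return [sum(column) % MODULUS for column in zip(*masked_updates)]
```

The server divides the recovered sum by the number of participants to get the average update that Federated Averaging needs.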

Our work has only scratched the surface of what is possible. Federated Learning can't solve all machine learning problems (for example, learning to recognize different dog breeds by training on carefully labeled examples), and for many other models the necessary training data is already stored in the cloud (like training spam filters for Gmail). So Google will continue to advance the state-of-the-art for cloud-based ML, but we are also committed to ongoing research to expand the range of problems we can solve with Federated Learning. Beyond Gboard query suggestions, for example, we hope to improve the language models that power your keyboard based on what you actually type on your phone (which can have a style all its own) and photo rankings based on what kinds of photos people look at, share, or delete.

Applying Federated Learning requires machine learning practitioners to adopt new tools and a new way of thinking: model development, training, and evaluation with no direct access to or labeling of raw data, with communication cost as a limiting factor. We believe the user benefits of Federated Learning make tackling the technical challenges worthwhile, and we are publishing our work in the hope of starting a widespread conversation within the machine learning community.

Acknowledgements
This post reflects the work of many people in Google Research, including Blaise Agüera y Arcas, Galen Andrew, Dave Bacon, Keith Bonawitz, Chris Brumme, Arlie Davis, Jac de Haan, Hubert Eichner, Wolfgang Grieskamp, Wei Huang, Vladimir Ivanov, Chloé Kiddon, Jakub Konečný, Nicholas Kong, Ben Kreuter, Alison Lentz, Stefano Mazzocchi, Sarvar Patel, Martin Pelikan, Aaron Segal, Karn Seth, Ananda Theertha Suresh, Iulia Turc, Felix Yu, and our partners in the Gboard team.

On-Device Machine Intelligence



To build the cutting-edge technologies that enable conversational understanding and image recognition, we often apply combinations of machine learning technologies such as deep neural networks and graph-based machine learning. However, the machine learning systems that power most of these applications run in the cloud, are computationally intensive, and have significant memory requirements. What if you want machine intelligence to run on your personal phone or smartwatch, or on IoT devices, regardless of whether they are connected to the cloud?

Yesterday, we announced the launch of Android Wear 2.0, along with brand new wearable devices, that will run Google's first entirely “on-device” ML technology for powering smart messaging. This on-device ML system, developed by the Expander research team, enables technologies like Smart Reply to be used for any application, including third-party messaging apps, without ever having to connect to the cloud, so you can now respond to incoming chat messages directly from your watch with a tap.
The research behind this began last year while our team was developing the machine learning systems that enable conversational understanding capability in Allo and Inbox. The Android Wear team reached out to us and was interested to know whether it would be possible to deploy this Smart Reply technology directly onto a smart device. Because of the limited computing power and memory on smart devices, we quickly realized that it was not possible to do so. Our product manager, Patrick McGregor, realized that this presented a unique challenge and an opportunity for the Expander team to return to the drawing board to design a completely new, lightweight, machine learning architecture — not only to enable Smart Reply on Android Wear, but also to power a wealth of other on-device mobile applications. Together with Tom Rudick, Nathan Beach, and other colleagues from the Android Wear team, we set out to build the new system.

Learning with Projections
A simple strategy to build lightweight conversational models might be to create a small dictionary of common rules (input → reply mappings) on the device and use a naive look-up strategy at inference time. This can work for simple prediction tasks involving a small set of classes using a handful of features (such as binary sentiment classification from text, e.g. “I love this movie” conveys a positive sentiment whereas the sentence “The acting was horrible” is negative). But it does not scale to complex natural language tasks involving rich vocabularies and the wide language variability observed in chat messages. On the other hand, machine learning models like recurrent neural networks (such as LSTMs), in conjunction with graph learning, have proven to be extremely powerful tools for complex sequence learning in natural language understanding tasks, including Smart Reply. However, compressing such rich models to fit in device memory and produce robust predictions at low computation cost (rapidly and on demand) is extremely challenging. Early experiments with restricting the model to predict only a small handful of replies or using other techniques like quantization or character-level models did not produce useful results.

Instead, we built a different solution for the on-device ML system. We first use a fast, efficient mechanism to group similar incoming messages and project them to similar (“nearby”) bit vector representations. While there are several ways to perform this projection step, such as using word embeddings or encoder networks, we employ a modified version of locality sensitive hashing (LSH) to reduce dimension from millions of unique words to a short, fixed-length sequence of bits. This allows us to compute a projection for an incoming message very fast, on-the-fly, with a small memory footprint on the device since we do not need to store the incoming messages, word embeddings, or even the full model used for training.
Projection step: Similar messages are grouped together and projected to nearby vectors. For example, the messages "hey, how's it going?" and "How's it going buddy?" share similar content and might be projected to the same vector 11100011. Another related message “Howdy, everything going well?” is mapped to a nearby vector 11100110 that differs only in 2 bits.
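Projections are short bit vectors, so they are cheap to compare. The distance between two of them is the number of bits that differ, which a device can compute with an XOR followed by a bit count. The snippet below checks the example from the caption above; the function name is illustrative.

```python
def hamming_distance(a: str, b: str) -> int:
    """Count differing bits between two equal-length bit strings via XOR and popcount."""
    if len(a) != len(b):
        raise ValueError("bit vectors must have the same length")
    return bin(int(a, 2) ^ int(b, 2)).count("1")


print(hamming_distance("11100011", "11100110"))  # 2
```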
Next, our system takes the incoming message along with its projections and jointly trains a “message projection model” that learns to predict likely replies using our semi-supervised graph learning framework. The graph learning framework enables training a robust model by combining semantic relationships from multiple sources — message/reply interactions, word/phrase similarity, semantic cluster information — learning useful projection operations that can be mapped to good reply predictions.
Learning step: (Top) Messages along with projections and corresponding replies, if available, are used in a machine learning framework to jointly learn a “message projection model”. (Bottom) The message projection model learns to associate replies with the projections of the corresponding incoming messages. For example, the model projects two different messages “Howdy, everything going well?” and “How’s it going buddy?” (bottom center) to nearby bit vectors and learns to map these to relevant replies (bottom right).
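The sketch below is a simplified illustration of how a graph can carry reply information from labeled messages to related unlabeled ones. It runs plain label propagation. Seed messages keep their known reply distributions, and every other message repeatedly takes the weighted average of its neighbors' distributions. This generic textbook technique is swapped in for illustration; it is not the Expander framework's actual algorithm. The example graph, edge weights and replies are made up.

```python
def propagate(edges, seeds, num_iters=30):
    """Label propagation over an undirected weighted graph.

    edges: node -> list of (neighbor, weight); must list every node and be symmetric.
    seeds: node -> {reply: probability}; seed nodes keep these values.
    """
    labels = {node: dict(seeds.get(node, {})) for node in edges}
    for _ in range(num_iters):
        updated = {}
        for node, neighbors in edges.items():
            if node in seeds:
                updated[node] = dict(seeds[node])
                continue
            scores = {}
            for neighbor, weight in neighbors:
                for reply, prob in labels[neighbor].items():
                    scores[reply] = scores.get(reply, 0.0) + weight * prob
            total = sum(scores.values())
            updated[node] = {r: s / total for r, s in scores.items()} if total else {}
        labels = updated
    return labels


def undirected(weighted_pairs):
    """Build a symmetric adjacency list from (a, b, weight) triples."""
    edges = {}
    for a, b, w in weighted_pairs:
        edges.setdefault(a, []).append((b, w))
        edges.setdefault(b, []).append((a, w))
    return edges


if __name__ == "__main__":
    graph = undirected([
        ("How's it going buddy?", "hey, how's it going?", 0.9),
        ("hey, how's it going?", "Howdy, everything going well?", 0.8),
        ("Howdy, everything going well?", "See you at 5?", 0.1),
    ])
    seeds = {
        "How's it going buddy?": {"Good, you?": 1.0},
        "See you at 5?": {"Sounds good!": 1.0},
    }
    result = propagate(graph, seeds)
    howdy = result["Howdy, everything going well?"]
    print(max(howdy, key=howdy.get))  # Good, you?
```

“Howdy, everything going well?” has no reply of its own. It inherits “Good, you?” through its strong link to a similar greeting, and the weak link to an unrelated message carries much less weight.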
It’s worth noting that while the message projection model can be trained using complex machine learning architectures and the power of the cloud, as described above, the model itself resides and performs inference completely on device. Apps running on the device can pass a user’s incoming messages and receive reply predictions from the on-device model without data leaving the device. The model can also be adapted to cater to the user’s writing style and individual preferences to provide a personalized experience.
Inference step: The model applies the learned projections to an incoming message (or sequence of messages) and suggests relevant and diverse replies. Inference is performed on the device, allowing the model to adapt to user data and personal writing styles.
To get the on-device system to work out of the box, we had to make a few additional improvements, such as speeding up on-device computation and generating rich, diverse replies from the model. A forthcoming scientific publication will describe the on-device machine learning work in more detail.

Converse from Your Wrist
When we embarked on our journey to build this technology from scratch, we weren’t sure if the predictions would be useful or of sufficient quality. We’re quite surprised and excited about how well it works even on Android wearable devices with very limited computation and memory resources. We look forward to continuing to improve the models to provide users with more delightful conversational experiences, and we will be leveraging this on-device ML platform to enable completely new applications in the months to come.

You can now use this feature to respond to your messages directly from your Google watches or any watch that runs Android Wear 2.0. It is already enabled on Google Hangouts, Google Messenger, and many third-party messaging apps. We also provide an API for developers of third-party Wear apps.

Acknowledgements
On behalf of the Google Expander team, I would also like to thank the following people who helped make this technology a success: Andrei Broder, Andrew Tomkins, David Singleton, Mirko Ranieri, Robin Dua and Yicheng Fan.