Tag Archives: On-device Learning

Advances in private training for production on-device language models

Language models (LMs) trained to predict the next word given input text are the key technology for many applications [1, 2]. In Gboard, LMs are used to improve users’ typing experience by supporting features like next word prediction (NWP), Smart Compose, smart completion and suggestion, slide to type, and proofread. Deploying models on users’ devices rather than enterprise servers has advantages like lower latency and better privacy for model usage. While training on-device models directly from user data effectively improves the utility performance for applications such as NWP and smart text selection, protecting the privacy of user data for model training is important.

Gboard features powered by on-device language models.

In this blog we discuss how years of research advances now power the private training of Gboard LMs, since the proof-of-concept development of federated learning (FL) in 2017 and formal differential privacy (DP) guarantees in 2022. FL enables mobile phones to collaboratively learn a model while keeping all the training data on device, and DP provides a quantifiable measure of data anonymization. Formally, DP is often characterized by (ε, δ) with smaller values representing stronger guarantees. Machine learning (ML) models are considered to have reasonable DP guarantees for ε=10 and strong DP guarantees for ε=1 when δ is small.
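For reference, the (ε, δ) notation above corresponds to the standard definition of differential privacy, written out below. The symbols are notation introduced here rather than taken from the post: M is the randomized training mechanism, D and D′ are any two datasets that differ in the data of a single user (the user-level setting), and S is any set of possible outputs.

```latex
\Pr\left[ M(D) \in S \right] \;\le\; e^{\varepsilon} \, \Pr\left[ M(D') \in S \right] + \delta
```

Smaller ε and δ force the two output distributions closer together, so less can be inferred about whether any one user's data was included in training.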

As of today, all NWP neural network LMs in Gboard are trained with FL with formal DP guarantees, and all future launches of Gboard LMs trained on user data require DP. These 30+ Gboard on-device LMs are launched in 7+ languages and 15+ countries, and satisfy (ε, δ)-DP guarantees with a small δ of 10⁻¹⁰ and ε between 0.994 and 13.69. To the best of our knowledge, this is the largest known deployment of user-level DP in production at Google or anywhere, and the first time a strong DP guarantee of ε < 1 is announced for models trained directly on user data.

Privacy principles and practices in Gboard

In “Private Federated Learning in Gboard”, we discussed how different privacy principles are currently reflected in production models, including:

  • Transparency and user control: We provide disclosure of what data is used, what purpose it is used for, how it is processed in various channels, and how Gboard users can easily configure the data usage in learning models.
  • Data minimization: FL immediately aggregates only focused updates that improve a specific model. Secure aggregation (SecAgg) is an encryption method to further guarantee that only aggregated results of the ephemeral updates can be accessed.
  • Data anonymization: DP is applied by the server to prevent models from memorizing the unique information in individual user’s training data.
  • Auditability and verifiability: We have made public the key algorithmic approaches and privacy accounting in open-sourced code (TFF aggregator, TFP DPQuery, DP accounting, and FL system).

A brief history

In recent years, FL has become the default method for training Gboard on-device LMs from user data. In 2020, a DP mechanism that clips and adds noise to model updates was used to prevent memorization for training the Spanish LM in Spain, which satisfies finite DP guarantees (Tier 3 described in the “How to DP-fy ML” guide). In 2022, with the help of the DP-Follow-The-Regularized-Leader (DP-FTRL) algorithm, the Spanish LM became the first production neural network trained directly on user data announced with a formal DP guarantee of (ε=8.9, δ=10⁻¹⁰)-DP (equivalent to the reported ρ=0.81 zero-Concentrated-Differential-Privacy), and therefore satisfies reasonable privacy guarantees (Tier 2).
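To give a feel for how a zCDP parameter ρ relates to an (ε, δ) pair, here is a minimal sketch using the classic textbook conversion ε = ρ + 2·sqrt(ρ·ln(1/δ)). This is an assumption about one possible conversion, not the accounting used for the launch. It yields a looser value than the reported ε=8.9, which suggests the production accounting relies on a tighter conversion.

```python
import math

def zcdp_to_eps(rho: float, delta: float) -> float:
    """Classic (loose) upper bound on eps for a rho-zCDP mechanism at a given delta."""
    return rho + 2.0 * math.sqrt(rho * math.log(1.0 / delta))

# Values reported in the post: rho = 0.81, delta = 1e-10.
eps_upper = zcdp_to_eps(rho=0.81, delta=1e-10)
print(f"eps <= {eps_upper:.2f}")  # looser than the reported 8.9
```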

Differential privacy by default in federated learning

In “Federated Learning of Gboard Language Models with Differential Privacy”, we announced that all the NWP neural network LMs in Gboard have DP guarantees, and all future launches of Gboard LMs trained on user data require DP guarantees. DP is enabled in FL by applying the following practices:

  • Pre-train the model with the multilingual C4 dataset.
  • Via simulation experiments on public datasets, find a large DP-noise-to-signal ratio that allows for high utility. Increasing the number of clients contributing to one round of model update improves privacy while keeping the noise ratio fixed for good utility, up to the point the DP target is met, or the maximum allowed by the system and the size of the population.
  • Configure the parameter to restrict the frequency each client can contribute (e.g., once every few days) based on computation budget and estimated population in the FL system.
  • Run DP-FTRL training with limits on the magnitude of per-device updates chosen either via adaptive clipping, or fixed based on experience.
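The clipping and noising steps in the list above can be sketched as follows. This is a simplified illustration that adds independent Gaussian noise in a single round (DP-FedAvg style); it is not DP-FTRL, which adds correlated noise across rounds. The function names, the toy updates, and the parameter values are all hypothetical.

```python
import math
import random

def clip_update(update, clip_norm):
    """Scale an update down so that its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(x * x for x in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [x * scale for x in update]

def dp_average(updates, clip_norm, noise_multiplier, rng):
    """Clip each device update, add Gaussian noise to the sum, then average."""
    clipped = [clip_update(u, clip_norm) for u in updates]
    dim = len(updates[0])
    std = noise_multiplier * clip_norm  # noise std applied to the SUM
    noisy_sum = [
        sum(u[i] for u in clipped) + rng.gauss(0.0, std)
        for i in range(dim)
    ]
    return [s / len(updates) for s in noisy_sum]

rng = random.Random(0)
updates = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(1000)]
avg = dp_average(updates, clip_norm=0.5, noise_multiplier=1.0, rng=rng)
print(avg)

# With a fixed noise multiplier, the noise std on the AVERAGE shrinks as 1/n,
# which is why more devices per round help utility at the same noise ratio.
for n in (100, 1000, 10000):
    print(n, 1.0 * 0.5 / n)
```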

SecAgg can be additionally applied by adopting the advances in improving computation and communication for scales and sensitivity.

Federated learning with differential privacy and secure aggregation (SecAgg).

Reporting DP guarantees

The DP guarantees of launched Gboard NWP LMs are visualized in the barplot below. The x-axis shows LMs labeled by language-locale and trained on corresponding populations; the y-axis shows the ε value when δ is fixed to a small value of 10⁻¹⁰ for (ε, δ)-DP (lower is better). The utility of these models is either significantly better than that of previous non-neural models in production, or comparable with that of previous LMs without DP, as measured by user-interaction metrics during A/B testing. For example, by applying the best practices, the DP guarantee of the Spanish model in Spain is improved from ε=8.9 to ε=5.37. SecAgg is additionally used for training the Spanish model in Spain and the English model in the US. More details of the DP guarantees are reported in the appendix following the guidelines outlined in “How to DP-fy ML”.

Towards stronger DP guarantees

The ε~10 DP guarantees of many launched LMs are already considered reasonable for ML models in practice, but the journey of DP FL in Gboard continues toward improving the user typing experience while protecting data privacy. We are excited to announce that, for the first time, production LMs of Portuguese in Brazil and Spanish in Latin America are trained and launched with a DP guarantee of ε ≤ 1, which satisfies Tier 1 strong privacy guarantees. Specifically, the (ε=0.994, δ=10⁻¹⁰)-DP guarantee is achieved by running the advanced Matrix Factorization DP-FTRL (MF-DP-FTRL) algorithm with 12,000+ devices participating in every training round of server model update (larger than the common setting of 6,500+ devices), together with a carefully configured policy that restricts each client to participating at most twice in the total 2,000 rounds of training over 14 days in the large Portuguese user population of Brazil. Using a similar setting, the es-US Spanish LM was trained in a large population combining multiple countries in Latin America to achieve (ε=0.994, δ=10⁻¹⁰)-DP. The ε ≤ 1 es-US model significantly improved the utility in many countries, and launched in Colombia, Ecuador, Guatemala, Mexico, and Venezuela. For the smaller population in Spain, the DP guarantee of the es-ES LM is improved from ε=5.37 to ε=3.42 by only replacing DP-FTRL with MF-DP-FTRL, without increasing the number of devices participating in every round. More technical details are disclosed in the colab for privacy accounting.
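The participation numbers above imply a lower bound on the population size, which the following back-of-the-envelope sketch works out. It treats "12,000+" as exactly 12,000 devices per round and assumes every round is filled, so the results are rough lower bounds rather than reported figures.

```python
import math

devices_per_round = 12_000          # "12,000+" in the post, taken as a floor
total_rounds = 2_000
max_participations_per_device = 2   # each client participates at most twice
training_days = 14

# Every round needs devices_per_round participation "slots".
total_slots = devices_per_round * total_rounds
# Each device can fill at most two slots, so this many distinct devices are needed.
min_distinct_devices = math.ceil(total_slots / max_participations_per_device)
rounds_per_day = total_rounds / training_days

print(f"participation slots: {total_slots:,}")
print(f"minimum distinct devices: {min_distinct_devices:,}")
print(f"rounds per day: {rounds_per_day:.0f}")
```

Under these assumptions, at least 12 million distinct devices must check in during the two weeks, which illustrates why such a guarantee needs a very large population.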

DP guarantees for Gboard NWP LMs (the purple bar represents the first es-ES launch of ε=8.9; cyan bars represent privacy improvements for models trained with MF-DP-FTRL; tiers are from the “How to DP-fy ML” guide; en-US* and es-ES* are additionally trained with SecAgg).

Discussion and next steps

Our experience suggests that DP can be achieved in practice through system-algorithm co-design on client participation, and that both privacy and utility can be strong when populations are large and the contributions of a large number of devices are aggregated. Privacy-utility-computation trade-offs can be improved by using public data, the new MF-DP-FTRL algorithm, and tighter accounting. With these techniques, a strong DP guarantee of ε ≤ 1 is possible but still challenging. Active research on empirical privacy auditing [1, 2] suggests that DP models are potentially more private than the worst-case DP guarantees imply. While we keep pushing the frontier of algorithms, which dimension of privacy-utility-computation should be prioritized?

We are actively working on all privacy aspects of ML, including extending DP-FTRL to distributed DP and improving auditability and verifiability. Trusted Execution Environment opens the opportunity for substantially increasing the model size with verifiable privacy. The recent breakthrough in large LMs (LLMs) motivates us to rethink the usage of public information in private training and more future interactions between LLMs, on-device LMs, and Gboard production.


The authors would like to thank Peter Kairouz, Brendan McMahan, and Daniel Ramage for their early feedback on the blog post itself, Shaofeng Li and Tom Small for helping with the animated figures, and the teams at Google that helped with algorithm design, infrastructure implementation, and production maintenance. The collaborators below contributed directly to the presented results:

Research and algorithm development: Galen Andrew, Stanislav Chiknavaryan, Christopher A. Choquette-Choo, Arun Ganesh, Peter Kairouz, Ryan McKenna, H. Brendan McMahan, Jesse Rosenstock, Timon Van Overveldt, Keith Rush, Shuang Song, Thomas Steinke, Abhradeep Guha Thakurta, Om Thakkar, and Yuanbo Zhang.

Infrastructure, production and leadership support: Mingqing Chen, Stefan Dierauf, Billy Dou, Hubert Eichner, Zachary Garrett, Jeremy Gillula, Jianpeng Hou, Hui Li, Xu Liu, Wenzhi Mao, Brett McLarnon, Mengchen Pei, Daniel Ramage, Swaroop Ramaswamy, Haicheng Sun, Andreas Terzis, Yun Wang, Shanshan Wu, Yu Xiao, and Shumin Zhai.

Source: Google AI Blog

MobileDiffusion: Rapid text-to-image generation on-device

Text-to-image diffusion models have shown exceptional capabilities in generating high-quality images from text prompts. However, leading models feature billions of parameters and are consequently expensive to run, requiring powerful desktops or servers (e.g., Stable Diffusion, DALL·E, and Imagen). Although inference solutions on Android via MediaPipe and on iOS via Core ML have advanced in the past year, rapid (sub-second) text-to-image generation on mobile devices has remained out of reach.

To that end, in “MobileDiffusion: Subsecond Text-to-Image Generation on Mobile Devices”, we introduce a novel approach with the potential for rapid text-to-image generation on-device. MobileDiffusion is an efficient latent diffusion model specifically designed for mobile devices. We also adopt DiffusionGAN to achieve one-step sampling during inference, which fine-tunes a pre-trained diffusion model while leveraging a GAN to model the denoising step. We have tested MobileDiffusion on iOS and Android premium devices, and it can run in half a second to generate a 512x512 high-quality image. Its comparatively small model size of just 520M parameters makes it uniquely suited for mobile deployment.

Rapid text-to-image generation on-device.


The relative inefficiency of text-to-image diffusion models arises from two primary challenges. First, the inherent design of diffusion models requires iterative denoising to generate images, necessitating multiple evaluations of the model. Second, the complexity of the network architecture in text-to-image diffusion models involves a substantial number of parameters, regularly reaching into the billions and resulting in computationally expensive evaluations. As a result, despite the potential benefits of deploying generative models on mobile devices, such as enhancing user experience and addressing emerging privacy concerns, it remains relatively unexplored within the current literature.

The optimization of inference efficiency in text-to-image diffusion models has been an active research area. Previous studies predominantly concentrate on addressing the first challenge, seeking to reduce the number of function evaluations (NFEs). By leveraging advanced numerical solvers (e.g., DPM) or distillation techniques (e.g., progressive distillation, consistency distillation), the number of necessary sampling steps has been reduced significantly, from several hundred to single digits. Some recent techniques, like DiffusionGAN and Adversarial Diffusion Distillation, even reduce it to a single step.

However, on mobile devices, even a small number of evaluation steps can be slow due to the complexity of the model architecture. Thus far, the architectural efficiency of text-to-image diffusion models has received comparatively less attention. A handful of earlier works briefly touch upon this matter, for example by removing redundant neural network blocks (e.g., SnapFusion). However, these efforts lack a comprehensive analysis of each component within the model architecture, thereby falling short of providing a holistic guide for designing highly efficient architectures.


Effectively overcoming the challenges imposed by the limited computational power of mobile devices requires an in-depth and holistic exploration of the model's architectural efficiency. In pursuit of this objective, our research undertakes a detailed examination of each constituent and computational operation within Stable Diffusion’s UNet architecture. We present a comprehensive guide for crafting highly efficient text-to-image diffusion models, culminating in MobileDiffusion.

The design of MobileDiffusion follows that of latent diffusion models. It contains three components: a text encoder, a diffusion UNet, and an image decoder. For the text encoder, we use CLIP-ViT/L14, which is a small model (125M parameters) suitable for mobile. We then turn our focus to the diffusion UNet and image decoder.

Diffusion UNet

As illustrated in the figure below, diffusion UNets commonly interleave transformer blocks and convolution blocks. We conduct a comprehensive investigation of these two fundamental building blocks. Throughout the study, we control the training pipeline (e.g., data, optimizer) to study the effects of different architectures.

In classic text-to-image diffusion models, a transformer block consists of a self-attention layer (SA) for modeling long-range dependencies among visual features, a cross-attention layer (CA) to capture interactions between text conditioning and visual features, and a feed-forward layer (FF) to post-process the output of attention layers. These transformer blocks hold a pivotal role in text-to-image diffusion models, serving as the primary components responsible for text comprehension. However, they also pose a significant efficiency challenge, given the computational expense of the attention operation, which is quadratic in the sequence length. We follow the idea of the UViT architecture, which places more transformer blocks at the bottleneck of the UNet. This design choice is motivated by the fact that the attention computation is less resource-intensive at the bottleneck due to its lower dimensionality.
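To illustrate why self-attention is cheaper at the bottleneck, the sketch below estimates the cost of the attention-matrix computation at several feature-map resolutions for a 512x512 image with an 8× smaller latent. The cost formula is a rough multiply-add count, and the channel width of 320 is a hypothetical value held fixed to isolate the effect of sequence length; real UNets usually widen channels at deeper levels.

```python
def attention_cost(height: int, width: int, channels: int) -> int:
    """Rough multiply-adds for Q·K^T plus attention-weighted V in one SA layer."""
    tokens = height * width              # sequence length of the flattened feature map
    return 2 * tokens * tokens * channels  # quadratic in the sequence length

latent_side = 512 // 8  # 64x64 latent for a 512x512 image
for downsample in (1, 2, 4):
    side = latent_side // downsample
    cost = attention_cost(side, side, channels=320)
    print(f"{side}x{side}: {side * side} tokens, ~{cost / 1e9:.2f} G multiply-adds")
```

Halving the resolution cuts the token count by 4× and the attention cost by 16×, so moving transformer blocks toward the bottleneck and skipping self-attention at high resolutions removes the most expensive attention layers.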

Our UNet architecture incorporates more transformers in the middle, and skips self-attention (SA) layers at higher resolutions.

Convolution blocks, in particular ResNet blocks, are deployed at each level of the UNet. While these blocks are instrumental for feature extraction and information flow, the associated computational costs, especially at high-resolution levels, can be substantial. One proven approach in this context is separable convolution. We observed that replacing regular convolution layers with lightweight separable convolution layers in the deeper segments of the UNet yields similar performance.
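The saving from separable convolutions can be seen from a parameter count. The sketch below compares a regular 3x3 convolution with a depthwise-separable one (a depthwise 3x3 followed by a pointwise 1x1); biases are ignored, and the channel width of 320 is a hypothetical example rather than a MobileDiffusion setting.

```python
def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights in a regular k x k convolution."""
    return k * k * c_in * c_out

def separable_conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights in a depthwise k x k convolution followed by a pointwise 1x1."""
    depthwise = k * k * c_in   # one k x k filter per input channel
    pointwise = c_in * c_out   # 1x1 convolution that mixes channels
    return depthwise + pointwise

channels = 320  # hypothetical layer width
regular = conv_params(channels, channels)
separable = separable_conv_params(channels, channels)
print(regular, separable, round(regular / separable, 1))
```

For wide layers the ratio approaches k² (9 for a 3x3 kernel), which is why the substitution pays off most in the deeper, wider segments of the UNet.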

In the figure below, we compare the UNets of several diffusion models. Our MobileDiffusion exhibits superior efficiency in terms of FLOPs (floating-point operations) and number of parameters.

Comparison of some diffusion UNets.

Image decoder

In addition to the UNet, we also optimized the image decoder. We trained a variational autoencoder (VAE) to encode an RGB image into an 8-channel latent variable whose spatial size is 8× smaller than that of the image. Conversely, a latent variable is decoded into an image with 8× larger spatial size. To further enhance efficiency, we design a lightweight decoder architecture by pruning the original’s width and depth. The resulting lightweight decoder leads to a significant performance boost, with nearly 50% latency improvement and better quality. For more details, please refer to our paper.
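As a small worked example of the latent sizes involved, the sketch below computes the latent shape for a 512x512 RGB image. It assumes that the 8× reduction applies to each spatial side, as is usual for latent diffusion models; the helper name is hypothetical.

```python
def latent_shape(height: int, width: int, downsample: int = 8, latent_channels: int = 8):
    """Shape (H, W, C) of the VAE latent for an image of the given size."""
    return (height // downsample, width // downsample, latent_channels)

h, w, c = latent_shape(512, 512)
rgb_values = 512 * 512 * 3
latent_values = h * w * c
print((h, w, c), rgb_values / latent_values)
```

Under this assumption the UNet operates on a 64x64x8 tensor that holds 24× fewer values than the RGB image.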

VAE reconstruction. Our VAE decoders have better visual quality than SD (Stable Diffusion).

Decoder       #Params (M)     PSNR↑     SSIM↑     LPIPS↓
SD            49.5            26.7      0.76      0.037
Ours          39.3            30.0      0.83      0.032
Ours-Lite     9.8             30.2      0.84      0.032

Quality evaluation of VAE decoders. Our lite decoder is much smaller than SD, with better quality metrics, including peak signal-to-noise ratio (PSNR), structural similarity index measure (SSIM), and Learned Perceptual Image Patch Similarity (LPIPS).

One-step sampling

In addition to optimizing the model architecture, we adopt a DiffusionGAN hybrid to achieve one-step sampling. Training DiffusionGAN hybrid models for text-to-image generation encounters several intricacies. Notably, the discriminator, a classifier distinguishing real data and generated data, must make judgments based on both texture and semantics. Moreover, the cost of training text-to-image models can be extremely high, particularly in the case of GAN-based models, where the discriminator introduces additional parameters. Purely GAN-based text-to-image models (e.g., StyleGAN-T, GigaGAN) confront similar complexities, resulting in highly intricate and expensive training.

To overcome these challenges, we use a pre-trained diffusion UNet to initialize the generator and discriminator. This design enables seamless initialization with the pre-trained diffusion model. We postulate that the internal features within the diffusion model contain rich information of the intricate interplay between textual and visual data. This initialization strategy significantly streamlines the training.

The figure below illustrates the training procedure. After initialization, a noisy image is sent to the generator for one-step diffusion. The result is evaluated against ground truth with a reconstruction loss, similar to diffusion model training. We then add noise to the output and send it to the discriminator, whose result is evaluated with a GAN loss, effectively adopting the GAN to model a denoising step. By using pre-trained weights to initialize the generator and the discriminator, the training becomes a fine-tuning process, which converges in less than 10K iterations.

Illustration of DiffusionGAN fine-tuning.


Below we show example images generated by our MobileDiffusion with DiffusionGAN one-step sampling. With such a compact model (520M parameters in total), MobileDiffusion can generate high-quality diverse images for various domains.

Images generated by our MobileDiffusion

We measured the performance of our MobileDiffusion on both iOS and Android devices, using different runtime optimizers. The latency numbers are reported below. We see that MobileDiffusion is very efficient and can run within half a second to generate a 512x512 image. This lightning speed potentially enables many interesting use cases on mobile devices.

Latency measurements (s) on mobile devices.


With superior efficiency in terms of latency and size, MobileDiffusion has the potential to be a very friendly option for mobile deployments, given its capability to enable a rapid image generation experience while typing text prompts. We will ensure that any application of this technology is in line with Google’s responsible AI practices.


We would like to thank our collaborators and contributors who helped bring MobileDiffusion on-device: Zhisheng Xiao, Yanwu Xu, Jiuqiang Tang, Haolin Jia, Lutz Justen, Daniel Fenner, Ronald Wotzlaw, Jianing Wei, Raman Sarokin, Juhyun Lee, Andrei Kulik, Chuo-Ling Chang, and Matthias Grundmann.


MediaPipe FaceStylizer: On-device real-time few-shot face stylization

In recent years, we have witnessed rising interest among consumers and researchers in integrated augmented reality (AR) experiences using real-time face feature generation and editing functions in mobile applications, including short videos, virtual reality, and gaming. As a result, there is a growing demand for lightweight, yet high-quality face generation and editing models, which are often based on generative adversarial network (GAN) techniques. However, the majority of GAN models suffer from high computational complexity and the need for a large training dataset. In addition, it is important to employ GAN models responsibly.

In this post, we introduce MediaPipe FaceStylizer, an efficient design for few-shot face stylization that addresses the aforementioned model complexity and data efficiency challenges while being guided by Google’s responsible AI Principles. The model consists of a face generator and a face encoder used as GAN inversion to map the image into latent code for the generator. We introduce a mobile-friendly synthesis network for the face generator with an auxiliary head that converts features to RGB at each level of the generator to generate high quality images from coarse to fine granularities. We also carefully designed the loss functions for the aforementioned auxiliary heads and combined them with the common GAN loss functions to distill the student generator from the teacher StyleGAN model, resulting in a lightweight model that maintains high generation quality. The proposed solution is available in open source through MediaPipe. Users can fine-tune the generator to learn a style from one or a few images using MediaPipe Model Maker, and deploy to on-device face stylization applications with the customized model using MediaPipe FaceStylizer.

Few-shot on-device face stylization

An end-to-end pipeline

Our goal is to build a pipeline that lets users adapt the MediaPipe FaceStylizer to different styles by fine-tuning the model with a few examples. To enable this, we built the pipeline with a GAN inversion encoder and an efficient face generator model (see below). The encoder and generator pipeline can then be adapted to different styles via a few-shot learning process. The user first sends one or a few similar samples of the style images to MediaPipe Model Maker to fine-tune the model. The fine-tuning process freezes the encoder module and only fine-tunes the generator. The training process samples multiple latent codes close to the encoding output of the input style images as the input to the generator. The generator is then trained to reconstruct an image of a person’s face in the style of the input style image by optimizing a joint adversarial loss function that also accounts for style and content. With such a fine-tuning process, the MediaPipe FaceStylizer can adapt to the customized style, which approximates the user’s input. It can then be applied to stylize test images of real human faces.

Generator: BlazeStyleGAN

The StyleGAN model family has been widely adopted for face generation and various face editing tasks. To support efficient on-device face generation, we based the design of our generator on StyleGAN. This generator, which we call BlazeStyleGAN, is similar to StyleGAN in that it also contains a mapping network and a synthesis network. However, since the synthesis network of StyleGAN is the major contributor to the model’s high computation complexity, we designed and employed a more efficient synthesis network. The improved efficiency and generation quality are achieved by:

  1. Reducing the latent feature dimension in the synthesis network to a quarter of the resolution of the counterpart layers in the teacher StyleGAN,
  2. Designing multiple auxiliary heads to transform the downscaled feature to the image domain to form a coarse-to-fine image pyramid to evaluate the perceptual quality of the reconstruction, and
  3. Skipping all but the final auxiliary head at inference time.

With the newly designed architecture, we train the BlazeStyleGAN model by distilling it from a teacher StyleGAN model. We use a multi-scale perceptual loss and adversarial loss in the distillation to transfer the high fidelity generation capability from the teacher model to the student BlazeStyleGAN model and also to mitigate the artifacts from the teacher model.

More details of the model architecture and training scheme can be found in our paper.

Visual comparison between face samples generated by StyleGAN and BlazeStyleGAN. The images on the first row are generated by the teacher StyleGAN. The images on the second row are generated by the student BlazeStyleGAN. The face generated by BlazeStyleGAN has similar visual quality to the image generated by the teacher model. Some results demonstrate the student BlazeStyleGAN suppresses the artifacts from the teacher model in the distillation.

The figure above shows some sample results of BlazeStyleGAN. Compared with the face images generated by the teacher StyleGAN model (top row), the images generated by the student BlazeStyleGAN (bottom row) maintain high visual quality and further reduce artifacts produced by the teacher, owing to the loss function design in our distillation.

An encoder for efficient GAN inversion

To support image-to-image stylization, we also introduced an efficient GAN inversion as the encoder to map input images to the latent space of the generator. The encoder is defined by a MobileNet V2 backbone and trained with natural face images. The loss is defined as a combination of image perceptual quality loss, which measures the content difference, style similarity and embedding distance, as well as the L1 loss between the input images and reconstructed images.

On-device performance

We documented model complexity in terms of parameter count and FLOPs in the following table. Compared to the teacher StyleGAN (33.2M parameters), BlazeStyleGAN (generator) significantly reduces the model complexity, with only 2.01M parameters and 1.28G FLOPs for an output resolution of 256x256. Compared to StyleGAN-1024 (generating images of size 1024x1024), BlazeStyleGAN-1024 reduces both model size and computation complexity by roughly 94% (from 33.17M to 2.07M parameters and from 74.3G to 4.70G FLOPs) with no notable quality difference, and can even suppress the artifacts from the teacher StyleGAN model.

Model     Image Size     #Params (M)     FLOPs (G)
StyleGAN     1024     33.17     74.3
BlazeStyleGAN     1024     2.07     4.70
BlazeStyleGAN     512     2.05     1.57
BlazeStyleGAN     256     2.01     1.28
Encoder     256     1.44     0.60

Model complexity measured by parameter numbers and FLOPs.
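The reductions at 1024x1024 resolution can be recomputed directly from the table above. The following sketch only does arithmetic on the tabulated numbers; the helper name is hypothetical.

```python
def reduction_pct(student: float, teacher: float) -> float:
    """Percentage by which the student is smaller than the teacher."""
    return 100.0 * (1.0 - student / teacher)

# (parameters in millions, FLOPs in billions), copied from the table above
stylegan_1024 = (33.17, 74.3)
blazestylegan_1024 = (2.07, 4.70)

params_reduction = reduction_pct(blazestylegan_1024[0], stylegan_1024[0])
flops_reduction = reduction_pct(blazestylegan_1024[1], stylegan_1024[1])
print(f"parameters: -{params_reduction:.1f}%, FLOPs: -{flops_reduction:.1f}%")
```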

We benchmarked the inference time of the MediaPipe FaceStylizer on various high-end mobile devices and report the results in the table below. Both BlazeStyleGAN-256 and BlazeStyleGAN-512 achieved real-time performance on the GPUs of all tested devices, and BlazeStyleGAN-256 can run in less than 10 ms on a high-end phone’s GPU. BlazeStyleGAN-256 can also achieve real-time performance on the CPUs of the iOS devices.

Device     BlazeStyleGAN-256 (ms)     Encoder-256 (ms)
iPhone 11     12.14     11.48
iPhone 12     11.99     12.25
iPhone 13 Pro     7.22     5.41
Pixel 6     12.24     11.23
Samsung Galaxy S10     17.01     12.70
Samsung Galaxy S20     8.95     8.20

Latency benchmark of the BlazeStyleGAN, face encoder, and the end-to-end pipeline on various mobile devices.
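To relate these latencies to frame rates, the sketch below converts the tabulated numbers into frames per second. It assumes that the encoder and generator run one after the other and that their sum approximates the pipeline latency (ignoring pre- and post-processing), and it uses 30 frames per second as a hypothetical real-time threshold; neither assumption comes from the post.

```python
# (BlazeStyleGAN-256 ms, Encoder-256 ms), copied from the table above
latency_ms = {
    "iPhone 11": (12.14, 11.48),
    "iPhone 12": (11.99, 12.25),
    "iPhone 13 Pro": (7.22, 5.41),
    "Pixel 6": (12.24, 11.23),
    "Samsung Galaxy S10": (17.01, 12.70),
    "Samsung Galaxy S20": (8.95, 8.20),
}

REAL_TIME_FPS = 30.0  # assumed threshold, not from the post

fps_by_device = {}
for device, (generator_ms, encoder_ms) in latency_ms.items():
    total_ms = generator_ms + encoder_ms  # assumed sequential execution
    fps_by_device[device] = 1000.0 / total_ms
    print(f"{device}: {total_ms:.2f} ms, {fps_by_device[device]:.1f} fps")

print(all(fps >= REAL_TIME_FPS for fps in fps_by_device.values()))
```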

Fairness evaluation

The model was trained on a highly diverse dataset of human faces and is expected to be fair across different human faces. The fairness evaluation shows that the model performs well and in a balanced manner across gender, skin tone, and age.

Face stylization visualization

Some face stylization results are demonstrated in the following figure. The images in the top row (in orange boxes) represent the style images used to fine-tune the model. The images in the left column (in the green boxes) are the natural face images used for testing. The 2x4 matrix of images represents the output of the MediaPipe FaceStylizer which is blending outputs between the natural faces on the left-most column and the corresponding face styles on the top row. The results demonstrate that our solution can achieve high-quality face stylization for several popular styles.

Sample results of our MediaPipe FaceStylizer.

MediaPipe Solutions

The MediaPipe FaceStylizer is going to be released to public users in MediaPipe Solutions. Users can leverage MediaPipe Model Maker to train a customized face stylization model using their own style images. After training, the exported bundle of TFLite model files can be deployed to applications across platforms (Android, iOS, Web, Python, etc.) using the MediaPipe Tasks FaceStylizer API in just a few lines of code.


This work is made possible through a collaboration spanning several teams across Google. We’d like to acknowledge contributions from Omer Tov, Yang Zhao, Andrey Vakunov, Fei Deng, Ariel Ephrat, Inbar Mosseri, Lu Wang, Chuo-Ling Chang, Tingbo Hou, and Matthias Grundmann.


On-device diffusion plugins for conditioned text-to-image generation

In recent years, diffusion models have shown great success in text-to-image generation, achieving high image quality and improved inference performance, and expanding our creative inspiration. Nevertheless, it is still challenging to efficiently control the generation, especially with conditions that are difficult to describe with text.

Today, we announce MediaPipe diffusion plugins, which enable controllable text-to-image generation to be run on-device. Expanding upon our prior work on GPU inference for on-device large generative models, we introduce new low-cost solutions for controllable text-to-image generation that can be plugged into existing diffusion models and their Low-Rank Adaptation (LoRA) variants.

Text-to-image generation with control plugins running on-device.


With diffusion models, image generation is modeled as an iterative denoising process. Starting from a noise image, at each step, the diffusion model gradually denoises the image to reveal an image of the target concept. Research shows that leveraging language understanding via text prompts can greatly improve image generation. For text-to-image generation, the text embedding is connected to the model via cross-attention layers. Yet, some information is difficult to describe by text prompts, e.g., the position and pose of an object. To address this problem, researchers add additional models into the diffusion to inject control information from a condition image.

Common approaches for controlled text-to-image generation include Plug-and-Play, ControlNet, and T2I Adapter. Plug-and-Play applies a widely used denoising diffusion implicit model (DDIM) inversion approach that reverses the generation process starting from an input image to derive an initial noise input, and then employs a copy of the diffusion model (860M parameters for Stable Diffusion 1.5) to encode the condition from an input image. Plug-and-Play extracts spatial features with self-attention from the copied diffusion, and injects them into the text-to-image diffusion. ControlNet creates a trainable copy of the encoder of a diffusion model, which connects via a convolution layer with zero-initialized parameters to encode conditioning information that is conveyed to the decoder layers. However, as a result, the size is large, half that of the diffusion model (430M parameters for Stable Diffusion 1.5). T2I Adapter is a smaller network (77M parameters) and achieves similar effects in controllable generation. T2I Adapter only takes the condition image as input, and its output is shared across all diffusion iterations. Yet, the adapter model is not designed for portable devices.

The MediaPipe diffusion plugins

To make conditioned generation efficient, customizable, and scalable, we design the MediaPipe diffusion plugin as a separate network that is:

  • Pluggable: It can be easily connected to a pre-trained base model.
  • Trained from scratch: It does not use pre-trained weights from the base model.
  • Portable: It runs outside the base model on mobile devices, with negligible cost compared to the base model inference.

Method               Parameter Size    Pluggable    From Scratch    Portable
Plug-and-Play        860M*             ✔️           -               -
ControlNet           430M*             ✔️           -               -
T2I Adapter          77M               ✔️           ✔️              -
MediaPipe Plugin     6M                ✔️           ✔️              ✔️

Comparison of Plug-and-Play, ControlNet, T2I Adapter, and the MediaPipe diffusion plugin.
* The number varies depending on the particulars of the diffusion model.

The MediaPipe diffusion plugin is a portable on-device model for text-to-image generation. It extracts multiscale features from a conditioning image, which are added to the encoder of a diffusion model at corresponding levels. When connecting to a text-to-image diffusion model, the plugin model can provide an extra conditioning signal to the image generation. We design the plugin network to be a lightweight model with only 6M parameters. It uses depth-wise convolutions and inverted bottlenecks from MobileNetv2 for fast inference on mobile devices.
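
To give a feel for why depth-wise convolutions keep the plugin small, the following sketch compares the weight count of a standard convolution with that of a depth-wise separable one (a depth-wise filter per channel followed by a 1x1 point-wise convolution). The layer shape (3x3 kernel, 128 input and output channels) is an illustrative assumption and is not taken from the plugin architecture; biases are ignored.

```python
def standard_conv_params(kernel, c_in, c_out):
    """Weights in a standard kernel x kernel convolution (no bias)."""
    return kernel * kernel * c_in * c_out


def depthwise_separable_params(kernel, c_in, c_out):
    """Weights in a depth-wise conv followed by a 1x1 point-wise conv (no bias)."""
    depthwise = kernel * kernel * c_in  # one kernel x kernel filter per input channel
    pointwise = c_in * c_out            # 1x1 convolution that mixes channels
    return depthwise + pointwise


# Hypothetical layer shape, chosen only for illustration.
KERNEL, C_IN, C_OUT = 3, 128, 128

std_params = standard_conv_params(KERNEL, C_IN, C_OUT)
sep_params = depthwise_separable_params(KERNEL, C_IN, C_OUT)
reduction = std_params / sep_params

print(std_params, sep_params, round(reduction, 2))
```

For this assumed shape, the separable layer needs roughly one-eighth of the weights of the standard layer, which is the kind of saving that makes a 6M-parameter control network plausible.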

Overview of the MediaPipe diffusion model plugin. The plugin is a separate network, whose output can be plugged into a pre-trained text-to-image generation model. Features extracted by the plugin are applied to the associated downsampling layer of the diffusion model (blue).

Unlike ControlNet, we inject the same control features in all diffusion iterations. That is, we only run the plugin once for one image generation, which saves computation. We illustrate some intermediate results of a diffusion process below. The control is effective at every diffusion step and enables controlled generation even at early steps. More iterations improve the alignment of the image with the text prompt and generate more detail.

Illustration of the generation process using the MediaPipe diffusion plugin.


In this work, we developed plugins for a diffusion-based text-to-image generation model with MediaPipe Face Landmark, MediaPipe Holistic Landmark, depth maps, and Canny edges. For each task, we select about 100K images from a web-scale image-text dataset, and compute control signals using the corresponding MediaPipe solutions. We use refined captions from PaLI for training the plugins.

Face Landmark

The MediaPipe Face Landmarker task computes 478 landmarks (with attention) of a human face. We use the drawing utils in MediaPipe to render a face, including face contour, mouth, eyes, eyebrows, and irises, with different colors. The following table shows randomly generated samples by conditioning on face mesh and prompts. For comparison, both ControlNet and the plugin can control text-to-image generation under the given conditions.

Face-landmark plugin for text-to-image generation, compared with ControlNet.

Holistic Landmark

The MediaPipe Holistic Landmarker task includes landmarks of body pose, hands, and face mesh. Below, we generate various stylized images by conditioning on the holistic features.

Holistic-landmark plugin for text-to-image generation.

Depth


Depth-plugin for text-to-image generation.

Canny Edge

Canny-edge plugin for text-to-image generation.


We conduct a quantitative study of the face landmark plugin to demonstrate the model's performance. The evaluation dataset contains 5K human images. We compare the generation quality as measured by the widely used metrics, Fréchet Inception Distance (FID) and CLIP scores. The base model is a pre-trained text-to-image diffusion model. We use Stable Diffusion v1.5 here.

As shown in the following table, both ControlNet and the MediaPipe diffusion plugin produce much better sample quality than the base model, in terms of FID and CLIP scores. Unlike ControlNet, which needs to run at every diffusion step, the MediaPipe plugin only runs once for each image generated. We measured the performance of the three models on a server machine (with Nvidia V100 GPU) and a mobile phone (Galaxy S23). On the server, we run all three models with 50 diffusion steps, and on mobile, we run 20 diffusion steps using the MediaPipe image generation app. Compared with ControlNet, the MediaPipe plugin shows a clear advantage in inference efficiency while preserving the sample quality.

Model                      FID↓     CLIP↑    Inference Time (s)
                                             Nvidia V100    Galaxy S23
Base                       10.32    0.26     5.0            11.5
Base + ControlNet          6.51     0.31     7.4 (+48%)     18.2 (+58.3%)
Base + MediaPipe Plugin    6.50     0.30     5.0 (+0.2%)    11.8 (+2.6%)

Quantitative comparison on FID, CLIP, and inference time.

We test the performance of the plugin on a wide range of mobile devices from mid-tier to high-end. We list the results on some representative devices in the following table, covering both Android and iOS.

Device       Android                                        iOS
             Pixel 4    Pixel 6    Pixel 7    Galaxy S23    iPhone 12 Pro    iPhone 13 Pro
Time (ms)    128        68         50         48            73               63

Inference time (ms) of the plugin on different mobile devices.


In this work, we present the MediaPipe diffusion plugin, a portable plugin for conditioned text-to-image generation. It injects features extracted from a condition image into a diffusion model and thereby controls the image generation. Portable plugins can be connected to pre-trained diffusion models running on servers or devices. By running text-to-image generation and plugins fully on-device, we enable more flexible applications of generative AI.


We’d like to thank all team members who contributed to this work: Raman Sarokin and Juhyun Lee for the GPU inference solution; Khanh LeViet, Chuo-Ling Chang, Andrei Kulik, and Matthias Grundmann for leadership. Special thanks to Jiuqiang Tang, Joe Zou and Lu Wang, who made this technology and all the demos run on-device.

Source: Google AI Blog

Distributed differential privacy for federated learning

Federated learning is a distributed way of training machine learning (ML) models where data is locally processed and only focused model updates and metrics that are intended for immediate aggregation are shared with a server that orchestrates training. This allows the training of models on locally available signals without exposing raw data to servers, increasing user privacy. In 2021, we announced that we are using federated learning to train Smart Text Selection models, an Android feature that helps users select and copy text easily by predicting what text they want to select and then automatically expanding the selection for them.

Since that launch, we have worked to improve the privacy guarantees of this technology by carefully combining secure aggregation (SecAgg) and a distributed version of differential privacy. In this post, we describe how we built and deployed the first federated learning system that provides formal privacy guarantees to all user data before it becomes visible to an honest-but-curious server, meaning a server that follows the protocol but could try to gain insights about users from data it receives. The Smart Text Selection models trained with this system have reduced memorization by more than two-fold, as measured by standard empirical testing methods.

Scaling secure aggregation

Data minimization is an important privacy principle behind federated learning. It refers to focused data collection, early aggregation, and minimal data retention required during training. While every device participating in a federated learning round computes a model update, the orchestrating server is only interested in their average. Therefore, in a world that optimizes for data minimization, the server would learn nothing about individual updates and only receive an aggregate model update. This is precisely what the SecAgg protocol achieves, under rigorous cryptographic guarantees.
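
The core idea behind secure aggregation can be illustrated with a toy pairwise-masking scheme: every pair of clients shares a random mask that one adds and the other subtracts, so the masks cancel in the sum while each individual upload looks random. This is only a sketch under simplifying assumptions: the masks come from a single seeded generator rather than a key agreement between clients, there is no dropout recovery, and the modulus and function names are hypothetical. It is not the production SecAgg protocol.

```python
import random

MODULUS = 2 ** 16  # illustrative size of the finite group


def pairwise_masks(num_clients, dim, seed=0):
    """Build one mask per client such that all masks sum to zero mod MODULUS."""
    rng = random.Random(seed)
    masks = [[0] * dim for _ in range(num_clients)]
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            shared = [rng.randrange(MODULUS) for _ in range(dim)]
            for d in range(dim):
                masks[i][d] = (masks[i][d] + shared[d]) % MODULUS  # client i adds
                masks[j][d] = (masks[j][d] - shared[d]) % MODULUS  # client j subtracts
    return masks


def mask_update(update, mask):
    """What a client uploads: its integer update plus its mask, mod MODULUS."""
    return [(u + m) % MODULUS for u, m in zip(update, mask)]


def server_sum(masked_updates):
    """The server only ever adds up masked vectors."""
    dim = len(masked_updates[0])
    return [sum(m[d] for m in masked_updates) % MODULUS for d in range(dim)]


updates = [[1, 2, 3], [10, 20, 30], [100, 200, 300]]
masks = pairwise_masks(len(updates), dim=3)
masked = [mask_update(u, m) for u, m in zip(updates, masks)]
total = server_sum(masked)
print(total)  # the masks cancel, leaving the true sum
```

Note that in this toy version each client handles one mask per peer, so its work grows linearly with the number of clients.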

Important to this work, two recent advancements have improved the efficiency and scalability of SecAgg at Google:

  • An improved cryptographic protocol: Until recently, a significant bottleneck in SecAgg was client computation, as the work required on each device scaled linearly with the total number of clients (N) participating in the round. In the new protocol, client computation now scales logarithmically in N. This, along with similar gains in server costs, results in a protocol able to handle larger rounds. Having more users participate in each round improves privacy, both empirically and formally.
  • Optimized client orchestration: SecAgg is an interactive protocol, where participating devices progress together. An important feature of the protocol is that it is robust to some devices dropping out. If a client does not send a response in a predefined time window, then the protocol can continue without that client’s contribution. We have deployed statistical methods to effectively auto-tune such a time window in an adaptive way, resulting in improved protocol throughput.

The above improvements made it easier and faster to train Smart Text Selection with stronger data minimization guarantees.

Aggregating everything via secure aggregation

A typical federated training system not only involves aggregating model updates but also metrics that describe the performance of the local training. These are important for understanding model behavior and debugging potential training issues. In federated training for Smart Text Selection, all model updates and metrics are aggregated via SecAgg. This behavior is statically asserted using TensorFlow Federated, and locally enforced in Android’s Private Compute Core secure environment. As a result, this enhances privacy even more for users training Smart Text Selection, because unaggregated model updates and metrics are not visible to any part of the server infrastructure.

Differential privacy

SecAgg helps minimize data exposure, but it does not necessarily produce aggregates that guarantee against revealing anything unique to an individual. This is where differential privacy (DP) comes in. DP is a mathematical framework that sets a limit on an individual's influence on the outcome of a computation, such as the parameters of an ML model. This is accomplished by bounding the contribution of any individual user and adding noise during the training process to produce a probability distribution over output models. DP comes with a parameter (ε) that quantifies how much the distribution could change when adding or removing the training examples of any individual user (the smaller the better).
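
The two ingredients named above, bounding each user's contribution and adding noise, can be sketched as follows. This is a minimal illustration that assumes L2 clipping and Gaussian noise scaled to the clipping norm; the function names, the noise multiplier, and the toy updates are hypothetical, and the sketch does not compute the resulting ε.

```python
import math
import random


def clip_l2(update, clip_norm):
    """Scale the update down so that its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(x * x for x in update))
    factor = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [x * factor for x in update]


def dp_average(updates, clip_norm, noise_multiplier, rng):
    """Average clipped updates after adding Gaussian noise to their sum."""
    clipped = [clip_l2(u, clip_norm) for u in updates]
    sigma = noise_multiplier * clip_norm  # noise is calibrated to the clipping bound
    dim = len(updates[0])
    noisy_sum = [
        sum(c[d] for c in clipped) + rng.gauss(0.0, sigma) for d in range(dim)
    ]
    return [s / len(updates) for s in noisy_sum]


rng = random.Random(0)
updates = [[3.0, 4.0], [0.3, 0.4]]

# With no noise, the result is just the average of the clipped updates.
noiseless = dp_average(updates, clip_norm=1.0, noise_multiplier=0.0, rng=rng)
# With noise, any single user's influence is masked by the added randomness.
noisy = dp_average(updates, clip_norm=1.0, noise_multiplier=1.0, rng=rng)
print(noiseless, noisy)
```

Clipping caps what any one user can contribute, and the noise scale is tied to that cap, which is what lets the added randomness hide an individual's presence.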

Recently, we announced a new method of federated training that enforces formal and meaningfully strong DP guarantees in a centralized manner, where a trusted server controls the training process. This protects against external attackers who may attempt to analyze the model. However, this approach still relies on trust in the central server. To provide even greater privacy protections, we have created a system that uses distributed differential privacy (DDP) to enforce DP in a distributed manner, integrated within the SecAgg protocol.

Distributed differential privacy

DDP is a technology that offers DP guarantees with respect to an honest-but-curious server coordinating training. It works by having each participating device clip and noise its update locally, and then aggregating these noisy clipped updates through the new SecAgg protocol described above. As a result, the server only sees the noisy sum of the clipped updates.

However, the combination of local noise addition and use of SecAgg presents significant challenges in practice:

  • An improved discretization method: One challenge is properly representing model parameters as integers in SecAgg's finite group with integer modular arithmetic, which can inflate the norm of the discretized model and require more noise for the same privacy level. For example, randomized rounding to the nearest integers could inflate the user's contribution by a factor equal to the number of model parameters. We addressed this by scaling the model parameters, applying a random rotation, and rounding to nearest integers. We also developed an approach for auto-tuning the discretization scale during training. This led to an even more efficient and accurate integration between DP and SecAgg.
  • Optimized discrete noise addition: Another challenge is devising a scheme for choosing an arbitrary number of bits per model parameter without sacrificing end-to-end privacy guarantees, which depend on how the model updates are clipped and noised. To address this, we added integer noise in the discretized domain and analyzed the DP properties of sums of integer noise vectors using the distributed discrete Gaussian and distributed Skellam mechanisms.

An overview of federated learning with distributed differential privacy.
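
The local steps described above (clip, scale, round to integers, add integer noise, then sum in a finite group) can be sketched roughly as follows. Several simplifications are assumptions of this sketch: the random rotation is omitted, the noise is a symmetric Skellam sample drawn as the difference of two Poisson variables, a plain modular sum stands in for SecAgg, and all names and parameter values (scale 64, a 12-bit group) are illustrative.

```python
import math
import random


def clip_l2(update, clip_norm):
    """Scale the update down so that its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(x * x for x in update))
    factor = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [x * factor for x in update]


def poisson(lam, rng):
    """Knuth's Poisson sampler; adequate for the small rates used here."""
    if lam <= 0:
        return 0
    threshold = math.exp(-lam)
    count, product = 0, rng.random()
    while product > threshold:
        count += 1
        product *= rng.random()
    return count


def encode_update(update, clip_norm, scale, lam, modulus, rng):
    """Client side: clip, scale, round to nearest integer, add Skellam noise."""
    encoded = []
    for x in clip_l2(update, clip_norm):
        q = math.floor(x * scale + 0.5)              # round to nearest integer
        q += poisson(lam, rng) - poisson(lam, rng)   # symmetric integer noise
        encoded.append(q % modulus)                  # element of the finite group
    return encoded


def aggregate(encoded_updates, modulus):
    """Stand-in for SecAgg: the server only learns this modular sum."""
    dim = len(encoded_updates[0])
    return [sum(e[d] for e in encoded_updates) % modulus for d in range(dim)]


def decode_sum(total, scale, modulus):
    """Map group elements back to signed values and undo the scaling."""
    signed = [t - modulus if t >= modulus // 2 else t for t in total]
    return [s / scale for s in signed]


rng = random.Random(0)
MODULUS = 2 ** 12  # 12 bits per parameter
updates = [[0.5, -0.25], [0.25, 0.25]]

noiseless = [encode_update(u, 1.0, 64, 0.0, MODULUS, rng) for u in updates]
exact_sum = decode_sum(aggregate(noiseless, MODULUS), 64, MODULUS)

noisy = [encode_update(u, 1.0, 64, 2.0, MODULUS, rng) for u in updates]
noisy_sum = decode_sum(aggregate(noisy, MODULUS), 64, MODULUS)
print(exact_sum, noisy_sum)
```

The modulus must be large enough that the noisy sum does not wrap around; choosing the scale and the number of bits together is the tuning problem the two bullets above describe.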

We tested our DDP solution on a variety of benchmark datasets and in production, and validated that we can match the accuracy of central DP with a SecAgg finite group of 12 bits per model parameter. This meant that we were able to achieve added privacy advantages while also reducing memory and communication bandwidth. To demonstrate this, we applied this technology to train and launch Smart Text Selection models. This was done with an appropriate amount of noise chosen to maintain model quality. All Smart Text Selection models trained with federated learning now come with DDP guarantees that apply to both the model updates and metrics seen by the server during training. We have also open sourced the implementation in TensorFlow Federated.

Empirical privacy testing

While DDP adds formal privacy guarantees to Smart Text Selection, those formal guarantees are relatively weak (a finite but large ε, in the hundreds). However, any finite ε is an improvement over a model with no formal privacy guarantee for two reasons: 1) a finite ε moves the model into a regime where further privacy improvements can be quantified; and 2) even large values of ε can indicate a substantial decrease in the ability to reconstruct training data from the trained model. To get a more concrete understanding of the empirical privacy advantages, we performed thorough analyses by applying the Secret Sharer framework to Smart Text Selection models. Secret Sharer is a model auditing technique that can be used to measure the degree to which models unintentionally memorize their training data.
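
As a rough illustration of how such an audit scores memorization, the sketch below ranks an inserted canary against reference candidates by model loss and converts the rank to the exposure measure log2(number of candidates) minus log2(rank). The loss values are made up for the example, and the helper names are hypothetical; a real audit would obtain losses from the trained model.

```python
import math


def canary_rank(canary_loss, reference_losses):
    """Rank 1 means the canary has the lowest loss among all candidates."""
    return 1 + sum(1 for loss in reference_losses if loss < canary_loss)


def exposure(rank, num_candidates):
    """Exposure in bits: higher means the canary is more strongly memorized."""
    return math.log2(num_candidates) - math.log2(rank)


# Made-up losses: the canary plus four reference candidates never seen in training.
reference_losses = [2.0, 3.5, 1.0, 4.2]
rank = canary_rank(canary_loss=1.5, reference_losses=reference_losses)
bits = exposure(rank, num_candidates=len(reference_losses) + 1)
print(rank, round(bits, 3))

# Doubling the rank of a canary lowers its exposure by exactly one bit.
drop = exposure(2, 1024) - exposure(4, 1024)
```

A higher rank therefore means the canary is harder to distinguish from unseen candidates, that is, less memorization.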

To perform Secret Sharer analyses for Smart Text Selection, we set up control experiments which collect gradients using SecAgg. The treatment experiments use distributed differential privacy aggregators with different amounts of noise.

We found that even low amounts of noise reduce memorization meaningfully, more than doubling the Secret Sharer rank metric for relevant canaries compared to the baseline. This means that even though the DP ε is large, we empirically verified that these amounts of noise already help reduce memorization for this model. However, to further improve on this and to get stronger formal guarantees, we aim to use even larger noise multipliers in the future.

Next steps

We developed and deployed the first federated learning and distributed differential privacy system that comes with formal DP guarantees with respect to an honest-but-curious server. While offering substantial additional protections, a fully malicious server might still be able to get around the DDP guarantees either by manipulating the public key exchange of SecAgg or by injecting a sufficient number of "fake" malicious clients that don’t add the prescribed noise into the aggregation pool. We are excited to address these challenges by continuing to strengthen the DP guarantee and its scope.


The authors would like to thank Adria Gascon for significant impact on the blog post itself, as well as the people who helped develop these ideas and bring them to practice: Ken Liu, Jakub Konečný, Brendan McMahan, Naman Agarwal, Thomas Steinke, Christopher Choquette, Adria Gascon, James Bell, Zheng Xu, Asela Gunawardana, Kallista Bonawitz, Mariana Raykova, Stanislav Chiknavaryan, Tancrède Lepoint, Shanshan Wu, Yu Xiao, Zachary Charles, Chunxiang Zheng, Daniel Ramage, Galen Andrew, Hugo Song, Chang Li, Sofia Neata, Ananda Theertha Suresh, Timon Van Overveldt, Zachary Garrett, Wennan Zhu, and Lukas Zilka. We’d also like to thank Tom Small for creating the animated figure.

Source: Google AI Blog

Who Said What? Recorder’s On-device Solution for Labeling Speakers

In 2019 we launched Recorder, an audio recording app for Pixel phones that helps users create, manage, and edit audio recordings. It leverages recent developments in on-device machine learning to transcribe speech, recognize audio events, suggest tags for titles, and help users navigate transcripts.

Nonetheless, some Recorder users found it difficult to navigate long recordings that have multiple speakers because it's not clear who said what. During the Made By Google event this year, we announced the "speaker labels" feature for the Recorder app. This opt-in feature annotates a recording transcript with unique and anonymous labels for each speaker (e.g., "Speaker 1", "Speaker 2", etc.) in real time during the recording. It significantly improves the readability and usability of the recording transcripts. This feature is powered by Google's new speaker diarization system named Turn-to-Diarize, which was first presented at ICASSP 2022.

Left: Recorder transcript without speaker labels. Right: Recorder transcript with speaker labels.

System Architecture

Our speaker diarization system leverages several highly optimized machine learning models and algorithms to allow diarizing hours of audio in a real-time streaming fashion with limited computational resources on mobile devices. The system mainly consists of three components: a speaker turn detection model that detects a change of speaker in the input speech, a speaker encoder model that extracts voice characteristics from each speaker turn, and a multi-stage clustering algorithm that annotates speaker labels to each speaker turn in a highly efficient way. All components run fully on the device.

Architecture of the Turn-to-Diarize system.

Detecting Speaker Turns

The first component of our system is a speaker turn detection model based on a Transformer Transducer (T-T), which converts the acoustic features into text transcripts augmented with a special token <st> representing a speaker turn. Unlike preceding customized systems that use role-specific tokens (e.g., <doctor> and <patient>) for conversations, this model is more generic and can be trained on and deployed to various application domains.
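
To make the role of the <st> token concrete, the following sketch splits a token sequence into speaker turns wherever the special token occurs. The helper name and the sample transcript are illustrative; the real system also tracks timing information for each turn, which this sketch leaves out.

```python
def split_speaker_turns(tokens, turn_token="<st>"):
    """Group transcript tokens into turns separated by the speaker-turn token."""
    turns, current = [], []
    for token in tokens:
        if token == turn_token:
            if current:  # ignore empty turns, e.g. a leading <st>
                turns.append(current)
            current = []
        else:
            current.append(token)
    if current:
        turns.append(current)
    return turns


tokens = "hello how are you <st> good thanks <st> great".split()
turns = split_speaker_turns(tokens)
print(turns)
```

Each resulting turn contains speech from a single speaker, which is the unit the later stages operate on.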

In most applications, the output of a diarization system is not directly shown to users, but combined with a separate automatic speech recognition (ASR) system that is trained to achieve a lower word error rate. Therefore, for the diarization system, we are relatively more tolerant of word token errors than of errors in the <st> token. Based on this intuition, we propose a new token-level loss function that allows us to train a small speaker turn detection model with high accuracy on predicted <st> tokens. Combined with edit-based minimum Bayes risk (EMBR) training, this new loss function significantly improved the interval-based F1 score on seven evaluation datasets.

Extracting Voice Characteristics

Once the audio recording has been segmented into homogeneous speaker turns, we use a speaker encoder model to extract an embedding vector (i.e., d-vector) to represent the voice characteristics of each speaker turn. This approach has several advantages over prior work that extracts embedding vectors from small fixed-length segments. First, it avoids extracting an embedding from a segment containing speech from multiple speakers. At the same time, each embedding covers a relatively large time range that contains sufficient signals from the speaker. It also reduces the total number of embeddings to be clustered, thus making the clustering step less expensive. These embeddings are processed entirely on-device until speaker labeling of the transcript is completed, and then deleted.

Multi-Stage Clustering

After the audio recording is represented by a sequence of embedding vectors, the last step is to cluster these embedding vectors, and assign a speaker label to each. However, since audio recordings from the Recorder app can be as short as a few seconds or as long as 18 hours, it is critical for the clustering algorithm to handle sequences of drastically different lengths.

For this we propose a multi-stage clustering strategy to leverage the benefits of different clustering algorithms. First, we use the speaker turn detection outputs to determine whether there are at least two different speakers in the recording. For short sequences, we use agglomerative hierarchical clustering (AHC) as the fallback algorithm. For medium-length sequences, we use spectral clustering as our main algorithm, and use the eigen-gap criterion for accurate speaker count estimation. For long sequences, we reduce computational cost by using AHC to pre-cluster the sequence before feeding it to the main algorithm. During the streaming, we keep a dynamic cache of previous AHC cluster centroids that can be reused for future clustering calls. This mechanism allows us to enforce an upper bound on the entire system with constant time and space complexity.
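
The dispatch by sequence length and the pre-clustering step can be sketched as below. This is a simplified illustration: the length thresholds are hypothetical, the agglomerative step is a naive centroid-merging loop using cosine distance, and the spectral clustering stage itself is not implemented.

```python
import math


def cosine_distance(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm if norm > 0 else 1.0


def ahc_precluster(embeddings, max_clusters):
    """Merge the two closest centroids until at most max_clusters remain."""
    clusters = [(list(e), 1) for e in embeddings]  # (centroid, member count)
    while len(clusters) > max_clusters:
        best = None
        for i in range(len(clusters)):
            for j in range(i + 1, len(clusters)):
                dist = cosine_distance(clusters[i][0], clusters[j][0])
                if best is None or dist < best[0]:
                    best = (dist, i, j)
        _, i, j = best
        (ci, ni), (cj, nj) = clusters[i], clusters[j]
        merged = [(x * ni + y * nj) / (ni + nj) for x, y in zip(ci, cj)]
        clusters[i] = (merged, ni + nj)
        del clusters[j]
    return clusters


def choose_stage(num_embeddings, short_max=10, main_max=100):
    """Pick a clustering path; both thresholds are hypothetical."""
    if num_embeddings <= short_max:
        return "ahc_fallback"
    if num_embeddings <= main_max:
        return "spectral"
    return "ahc_precluster_then_spectral"


embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
reduced = ahc_precluster(embeddings, max_clusters=2)
counts = sorted(count for _, count in reduced)
print(choose_stage(len(embeddings)), counts)
```

Because the pre-clustering step caps the number of centroids passed to the main algorithm, the cost of the expensive stage stays bounded no matter how long the recording runs.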

This multi-stage clustering strategy is a critical optimization for on-device applications where the budget for CPU, memory, and battery is very small, and allows the system to run in a low power mode even after diarizing hours of audio. As a tradeoff between quality and efficiency, the upper bound of the computational cost can be flexibly configured for devices with different computational resources.

Diagram of the multi-stage clustering strategy.

Correction and Customization

In our real-time streaming speaker diarization system, as the model consumes more audio input, it accumulates confidence on predicted speaker labels, and may occasionally make corrections to previously predicted low-confidence speaker labels. The Recorder app automatically updates the speaker labels on the screen during recording to reflect the latest and most accurate predictions.

At the same time, the Recorder app's UI allows the user to rename the anonymous speaker labels (e.g., "Speaker 2") to customized labels (e.g., "car dealer") for better readability and easier memorization for the user within each recording.

Recorder allows the user to rename the speaker labels for better readability.

Future Work

Currently, our diarization system mostly runs on the CPU block of Google Tensor, Google's custom-built chip that powers more recent Pixel phones. We are working on delegating more computations to the TPU block, which will further reduce the overall power consumption of the diarization system. Another future work direction is to leverage multilingual capabilities of speaker encoder and speech recognition models to expand this feature to more languages.


The work described in this post represents joint efforts from multiple teams within Google. Contributors include Quan Wang, Yiling Huang, Evan Clark, Qi Cao, Han Lu, Guanlong Zhao, Wei Xia, Hasim Sak, Alvin Zhou, Jason Pelecanos, Luiza Timariu, Allen Su, Fan Zhang, Hugh Love, Kristi Bradford, Vincent Peng, Raff Tsai, Richard Chou, Yitong Lin, Ann Lu, Kelly Tsai, Hannah Bowman, Tracy Wu, Taral Joglekar, Dharmesh Mokani, Ajay Dudani, Ignacio Lopez Moreno, Diego Melendo Casado, Nino Tasca, Alex Gruenstein.

Source: Google AI Blog

Efficient Sequence Modeling for On-Device ML

The increasing demand for machine learning (ML) model inference on-device (for mobile devices, tablets, etc.) is driven by the rise of compute-intensive applications, the need to keep certain data on device for privacy and security reasons, and the desire to provide services when a network connection may not be available. However, on-device inference introduces a myriad of challenges, ranging from modeling to platform support requirements. These challenges relate to how different architectures are designed to optimize memory and computation, while still trying to maintain the quality of the model. From a platform perspective, the issue is identifying operations and building on top of them in a way that can generalize well across different product use cases.

In previous research, we combined a novel technique for generating embeddings (called projection-based embeddings) with efficient architectures like QRNN (pQRNN) and proved them to be competent for a number of classification problems. Augmenting these with distillation techniques provides an additional bump in end-to-end quality. Although this is an effective approach, it is not scalable to bigger and more extensive vocabularies (i.e., all possible Unicode or word tokens that can be fed to the model). Additionally, the output from the projection operation itself doesn’t contain trainable weights to take advantage of pre-training the model.

Token-free models presented in ByT5 are a good starting point for on-device modeling that can address pre-training and scalability issues without the need to increase the size of the model. This is possible because these approaches treat text inputs as a stream of bytes (each byte has a value that ranges from 0 to 255) that can reduce the vocabulary size for the embedding tables from ~30,000 to 256. Although ByT5 presents a compelling alternative for on-device modeling, going from word-level representation to byte stream representation increases the sequence lengths linearly; with an average word length of four characters and a single character having up to four bytes, the byte sequence length increases proportionally to the word length. This can lead to a significant increase in inference latency and computational costs.
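
The trade-off can be seen in a few lines: encoding text as UTF-8 bytes keeps every ID in the range 0 to 255, but the sequence becomes much longer than the word-level one. The sample string is arbitrary and chosen only to include two non-ASCII characters.

```python
def encode_bytes(text):
    """Represent text as a sequence of UTF-8 byte values (each in 0..255)."""
    return list(text.encode("utf-8"))


sample = "h\u00e9llo w\u00f6rld"  # "héllo wörld"
byte_ids = encode_bytes(sample)

num_words = len(sample.split())
num_chars = len(sample)
num_bytes = len(byte_ids)
print(num_words, num_chars, num_bytes)
```

Here two words become 13 byte positions, since each accented character takes two bytes in UTF-8.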

We address this problem by developing and releasing three novel byte-stream sequence models for the SeqFlowLite library (ByteQRNN, ByteTransformer and ByteFunnelTransformer), all of which can be pre-trained on unsupervised data and can be fine-tuned for specific tasks. These models leverage recent innovations introduced by Charformer, including a fast character Transformer-based model that uses a gradient-based subword tokenization (GBST) approach to operate directly at the byte level, as well as a “soft” tokenization approach, which allows us to learn token boundaries and reduce sequence lengths. In this post, we focus on ByteQRNN and demonstrate that the performance of a pre-trained ByteQRNN model is comparable to BERT, despite being 300x smaller.

Sequence Model Architecture

We leverage pQRNN, ByT5 and Charformer along with platform optimizations, such as in-training quantization (which tracks minimum and maximum float values for model activations and weights for quantizing the inference model) that reduces model sizes to one-fourth, to develop an end-to-end model called ByteQRNN (shown below). First, we use a ByteSplitter operation to split the input string into a byte stream and feed it to a smaller embedding table that has a vocabulary size of 259 (256 + 3 additional meta tokens).
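
The min/max tracking mentioned above is what makes 8-bit quantization possible: once the range of a tensor is known, each float can be mapped to one of 256 levels. The sketch below shows one common affine scheme under that assumption; it is not the SeqFlowLite implementation, and the weights are made-up values.

```python
def quantize_uint8(values):
    """Map floats to 8-bit codes using the observed min and max of the tensor."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / 255.0 or 1.0  # avoid a zero scale for constant tensors
    codes = [min(255, max(0, round((v - lo) / scale))) for v in values]
    return codes, lo, scale


def dequantize(codes, lo, scale):
    """Recover approximate floats from the 8-bit codes."""
    return [lo + c * scale for c in codes]


weights = [-1.0, -0.2, 0.0, 0.4, 1.0]
codes, lo, scale = quantize_uint8(weights)
restored = dequantize(codes, lo, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(codes, round(max_error, 5))
```

Each code occupies one byte where a float32 value takes four, and the rounding error per value is at most half a quantization step.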

The output from the embedding layer is fed to the GBST layer, which is equipped with in-training quantization and combines byte-level representations with the efficiency of subword tokenization while enabling end-to-end learning of latent subwords. We “soft” tokenize the byte stream sequences by enumerating and combining each subword block length with scores (computed with a quantized dense layer) at each strided token position (i.e., at token positions that are selected at regular intervals). Next, we downsample the byte stream to manageable sequence length and feed it to the encoder layer.
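
A simplified view of this soft tokenization, following the general GBST recipe from Charformer, is sketched below: for every candidate block size the sequence is mean-pooled into blocks, each position scores its candidate blocks, a softmax over block sizes mixes them, and the result is downsampled. The assumptions are significant: the scoring weights are a fixed stand-in for the learned (quantized) dense layer, scores are computed at every position rather than only at strided positions, and nothing is quantized.

```python
import math


def mean_vec(vectors):
    dim = len(vectors[0])
    return [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]


def block_representations(x, block_size):
    """Mean-pool non-overlapping blocks, then repeat so the length matches x."""
    out = []
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        out.extend([mean_vec(block)] * len(block))
    return out


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def gbst(x, max_block_size, score_weights, downsample_rate):
    """Soft-tokenize a sequence of byte embeddings and shorten it."""
    dim = len(x[0])
    candidates = [block_representations(x, b) for b in range(1, max_block_size + 1)]
    mixed = []
    for i in range(len(x)):
        options = [c[i] for c in candidates]  # one candidate per block size
        scores = [sum(w * v for w, v in zip(score_weights, o)) for o in options]
        probs = softmax(scores)               # soft choice over block sizes
        mixed.append(
            [sum(p * o[d] for p, o in zip(probs, options)) for d in range(dim)]
        )
    return [
        mean_vec(mixed[s:s + downsample_rate])
        for s in range(0, len(mixed), downsample_rate)
    ]


byte_embeddings = [[float(i), float(i % 3)] for i in range(8)]
shortened = gbst(byte_embeddings, max_block_size=4,
                 score_weights=[0.5, -0.5], downsample_rate=4)
print(len(byte_embeddings), "->", len(shortened))
```

With a downsampling rate of 4, eight byte positions are reduced to two, which is how the encoder is shielded from the longer byte-level sequences.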

The output from the GBST layer can be downsampled to a lower sequence length for efficient encoder computation or can be used by an encoder, like Funnel Transformer, which pools the query length and reduces the self-attention computation to create the ByteFunnelTransformer model. The encoder in the end-to-end model can be replaced with any other encoder layer, such as the Transformer from the SeqFlowLite library, to create a ByteTransformer model.

A diagram of a generic end-to-end sequence model using byte stream input. The ByteQRNN model uses a QRNN encoder from the SeqFlowLite library.

In addition to the input embeddings (i.e., the output from the embedding layer described above), we go a step further to build an effective sequence-to-sequence (seq2seq) model. We do so by taking ByteQRNN and adding a Transformer-based decoder model along with a quantized beam search (or tree exploration) to go with it. The quantized beam search module reduces the inference latency when generating decoder outputs by computing the most likely beams (i.e., possible output sequences) using the logarithmic sum of previous and current probabilities and returns the resulting top beams. Here the system uses a more efficient 8-bit integer (uint8) format, compared to a typical single-precision floating-point format (float32) model.
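
The scoring rule described above, summing log-probabilities along each candidate sequence and keeping the top beams, can be illustrated with a small float-based beam search. The toy next-token table is invented for the example, and the sketch does not reproduce the 8-bit quantized arithmetic of the actual module.

```python
import math

EOS = "<eos>"

# Invented next-token probabilities, keyed by the last token of the prefix.
TOY_MODEL = {
    None: {"a": 0.6, "b": 0.4},
    "a": {"x": 0.55, "y": 0.45},
    "b": {EOS: 0.9, "x": 0.1},
    "x": {EOS: 1.0},
    "y": {EOS: 1.0},
}


def next_log_probs(prefix):
    last = prefix[-1] if prefix else None
    return {tok: math.log(p) for tok, p in TOY_MODEL[last].items()}


def beam_search(beam_size, max_len):
    beams = [([], 0.0)]  # (token sequence, summed log-probability)
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq and seq[-1] == EOS:
                candidates.append((seq, score))  # finished beams carry over
                continue
            for token, logp in next_log_probs(seq).items():
                candidates.append((seq + [token], score + logp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
        if all(seq[-1] == EOS for seq, _ in beams):
            break
    return beams


beams = beam_search(beam_size=2, max_len=4)
best_seq, best_score = beams[0]
print(best_seq, round(math.exp(best_score), 2))
```

In this toy table a greedy decoder would start with "a" and end with probability 0.33, while keeping two beams recovers the more likely sequence that starts with "b".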

The decoder Transformer model uses a merged attention sublayer (MAtt) to reduce the complexity of the decoder self-attention from quadratic to linear, thereby lowering the end-to-end latency. For each decoding step, MAtt uses a fixed-size cache for decoder self-attention compared to the increasing cache size of a traditional transformer decoder. The following figure illustrates how the beam search module interacts with the decoder layer to generate output tokens on-device using an edge device (e.g., mobile phones, tablets, etc.).

A comparison of cloud server decoding and on-device (edge device) implementation. Left: Cloud server beam search employs a Transformer-based decoder model with quadratic time self-attention in float32, which has an increasing cache size for each decoding step. Right: The edge device implementation employs a quantized beam search module along with a fixed-size cache and a linear time self-attention computation.

After developing ByteQRNN, we evaluate its performance on the civil_comments dataset using the area under the curve (AUC) metric and compare it to a pre-trained ByteQRNN and BERT (shown below). We demonstrate that the fine-tuned ByteQRNN improves the overall quality and brings its performance closer to the BERT models, despite being 300x smaller. Since SeqFlowLite models support in-training quantization that reduces model sizes to one-fourth, the resulting models scale well to low-compute devices. We chose multilingual data sources that relate to the task for pre-training both BERT and byte stream models to achieve the best possible performance.

Comparison of ByteQRNN with fine-tuned ByteQRNN and BERT on the civil_comments dataset.

Following up on our previous work with pQRNN, we evaluate byte stream models for on-device use to enable pre-training and thereby improve model performance for on-device deployment. We present an evaluation for ByteQRNN with and without pre-training and demonstrate that the performance of the pre-trained ByteQRNN is comparable to BERT, despite being 300x smaller. In addition to ByteQRNN, we are also releasing ByteTransformer and ByteFunnelTransformer, two models which use different encoders, along with the merged attention decoder model and the beam search driver to run the inference through the SeqFlowLite library. We hope these models will provide researchers and product developers with valuable resources for future on-device deployments.

We would like to thank Khoa Trinh, Jeongwoo Ko, Peter Young and Yicheng Fan for helping with open-sourcing and evaluating the model. Thanks to Prabhu Kaliamoorthi for all the brainstorming and ideation. Thanks to Vinh Tran, Jai Gupta and Yi Tay for their help with pre-training byte stream models. Thanks to Ruoxin Sang, Haoyu Zhang, Ce Zheng, Chuanhao Zhuge and Jieying Luo for helping with the TPU training. Many thanks to Erik Vee, Ravi Kumar and the Learn2Compress leadership for sponsoring the project and their support and encouragement. Finally, we would like to thank Tom Small for the animated figure used in this post.

Source: Google AI Blog

TRILLsson: Small, Universal Speech Representations for Paralinguistic Tasks

In recent years, we have seen dramatic improvements on lexical tasks such as automatic speech recognition (ASR). However, machine systems still struggle to understand paralinguistic aspects — such as tone, emotion, whether a speaker is wearing a mask, etc. Understanding these aspects represents one of the remaining difficult problems in machine hearing. In addition, state-of-the-art results often come from ultra-large models trained on private data, making them impractical to run on mobile devices or to release publicly.

In “Universal Paralinguistic Speech Representations Using Self-Supervised Conformers”, to appear in ICASSP 2022, we introduce CAP12, the 12th layer of a 600M parameter model trained on the YT-U training dataset using self-supervision. We demonstrate that the CAP12 model outperforms nearly all previous results in our paralinguistic benchmark, sometimes by large margins, even though previous results are often task-specific. In “TRILLsson: Distilled Universal Paralinguistic Speech Representations”, we introduce the small, performant, publicly-available TRILLsson models and demonstrate how we reduced the size of the high-performing CAP12 model by 6x-100x while maintaining 90-96% of the performance. To create TRILLsson, we apply knowledge distillation on appropriately-sized audio chunks and use different architecture types to train smaller, faster networks that are small enough to run on mobile devices.

1M-Hour Dataset to Train Ultra-Large Self-Supervised Models
We leverage the YT-U training dataset to train the ultra-large, self-supervised CAP12 model. The YT-U dataset is a highly varied, 900k+ hour dataset that contains audio of various topics, background conditions, and speaker acoustic properties.

Video categories by length (outer) and number (inner), demonstrating the variety in the YT-U dataset (figure from BigSSL)

We then modify a Wav2Vec 2.0 self-supervised training paradigm, which can solve tasks using raw data without labels, and combine it with ultra-large Conformer models. Because self-supervised training doesn't require labels, we can take full advantage of YT-U by scaling up our models to some of the largest model sizes ever trained, including 600M, 1B, and 8B parameters.

NOSS: A Benchmark for Paralinguistic Tasks
We demonstrate that an intermediate representation of one of the previous models contains a state-of-the-art representation for paralinguistic speech. We call the 600M parameter Conformer model without relative attention Conformer Applied to Paralinguistics (CAP). We exhaustively search through all intermediate representations of six ultra-large models and find that layer 12 (CAP12) outperforms previous representations by significant margins.

To measure the quality of the roughly 300 candidate paralinguistic speech representations, we evaluate on an expanded version of the NOn-Semantic Speech (NOSS) benchmark, which is a collection of well-studied paralinguistic speech tasks, such as speech emotion recognition, language identification, and speaker identification. These tasks focus on paralinguistic aspects of speech, which require evaluating speech features on the order of 1 second or longer, rather than lexical features, which require 100ms or shorter. We then add to the benchmark a mask-wearing task introduced at Interspeech 2020, a fake speech detection task (ASVSpoof 2019), a task to detect the level of dysarthria from project Euphonia, and an additional speech emotion recognition task (IEMOCAP). By expanding the benchmark and increasing the diversity of the tasks, we empirically demonstrate that CAP12 is even more generally useful than previous representations.

Simple linear models on time-averaged CAP12 representations even outperform complex, task-specific models on five out of eight paralinguistic tasks. This is surprising because comparable models sometimes use additional modalities (e.g., vision and speech, or text and speech) as well. Furthermore, CAP12 is exceptionally good at emotion recognition tasks. CAP12 embeddings also outperform all other embeddings on all other tasks with only a single exception: for one embedding from a supervised network on the dysarthria detection task.

Model               Voxceleb  Voxforge  Speech Commands  ASVSpoof2019∗∗  Euphonia#  CREMA-D  IEMOCAP
Prev SoTA           -         95.4      97.9             5.11            45.9       74.0     67.6+
TRILL               12.6      84.5      77.6             74.6            48.1       65.7     54.3
ASR Embedding       5.2       98.9      96.1             11.2            54.5       71.8     65.4
Wav2Vec2 layer 6††  17.9      98.5      95.0             6.7             48.2       77.4     65.8
CAP12               51.0      99.7      97.0             2.5             51.5       88.2     75.0
Test performance on the NOSS Benchmark and extended tasks. “Prev SoTA” indicates the previous best performing state-of-the-art model, which has arbitrary complexity, but all other rows are linear models on time-averaged input. Filtered according to YouTube’s privacy guidelines. ∗∗ Uses equal error rate [20]. # The only non-public dataset. We exclude it from aggregate scores. Audio and visual features used in previous state-of-the-art models. + The previous state-of-the-art model performed cross-validation. For our evaluation, we hold out two specific speakers as a test. †† Wav2Vec 2.0 model from HuggingFace. Best overall layer was layer 6.

TRILLsson: Small, High Quality, Publicly Available Models
Similar to FRILL, our next step was to make an on-device, publicly available version of CAP12. This involved using knowledge distillation to train smaller, faster, mobile-friendly architectures. We experimented with EfficientNet, Audio Spectrogram Transformer (AST), and ResNet. These model types are very different, and cover both fixed-length and arbitrary-length inputs. EfficientNet comes from a neural architecture search over vision models to find simultaneously performant and efficient model structures. AST models are transformers adapted to audio inputs. ResNet is a standard architecture that has shown good performance across many different models.

We trained models that performed on average 90-96% as well as CAP12, despite being 1%-15% of the size and trained using only 6% of the data. Interestingly, we found that different architecture types performed better at different sizes. ResNet models performed best at the low end, EfficientNet in the middle, and AST models at the larger end.

Aggregate embedding performance vs. model size for various student model architectures and sizes. We demonstrate that ResNet architectures perform best for small sizes, EfficientNetV2 performs best in the midsize model range, up to the largest model size tested, after which the larger AST models are best.

We perform knowledge distillation with the goal of matching a student, with a fixed-size input, to the output of a teacher, with a variable-size input, for which there are two methods of generating student targets: global matching and local matching. Global matching produces distillation targets by generating CAP12 embeddings for an entire audio clip, and then requires that a student match the target from just a small segment of audio (e.g., 2 seconds). Local matching requires that the student network match the average CAP12 embedding just over the smaller portion of the audio that the student sees. In our work, we focused on local matching.

Two types of generating distillation targets for sequences. Left: Global matching uses the average CAP12 embedding over the whole clip for the target for each local chunk. Right: Local matching uses CAP12 embeddings averaged just over local clips as the distillation target.
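To make the difference between the two target types concrete, here is a minimal Python sketch. It is an illustration under stated assumptions, not the training code behind TRILLsson: the function names, the toy frame values, and the mean-squared-error loss are all hypothetical choices, and the teacher output is modeled as a plain list of per-frame embedding vectors.

```python
def mean_vector(frames):
    """Average a list of equal-length embedding vectors element-wise."""
    count = len(frames)
    return [sum(column) / count for column in zip(*frames)]


def global_target(teacher_frames):
    """Global matching: one target averaged over the whole clip."""
    return mean_vector(teacher_frames)


def local_target(teacher_frames, start, length):
    """Local matching: target averaged only over the chunk the student sees."""
    return mean_vector(teacher_frames[start:start + length])


def distillation_loss(student_embedding, target):
    """Mean squared error between student output and target (assumed loss)."""
    return sum((s - t) ** 2 for s, t in zip(student_embedding, target)) / len(target)


# Toy teacher output: 4 frames of 2-dimensional embeddings (made-up numbers).
teacher_frames = [[0.0, 2.0], [2.0, 2.0], [4.0, 0.0], [6.0, 0.0]]

# The student only sees the first two frames of the clip.
whole_clip_target = global_target(teacher_frames)       # [3.0, 1.0]
chunk_target = local_target(teacher_frames, 0, 2)       # [1.0, 2.0]
```

In this toy clip the two targets differ noticeably, because the second half of the audio (which the student never sees) pulls the global average away from the chunk the student is actually given.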

Observation of Bimodality and Future Directions
Paralinguistic information shows an unexpected bimodal distribution. For the CAP model that operates on 500 ms input segments, and two of the full-input Conformer models, intermediate representations gradually increase in paralinguistic information, then decrease, then increase again, and finally lose this information towards the output layer. Surprisingly, this pattern is also seen when exploring the intermediate representations of networks trained on retinal images.

500 ms inputs to CAP show a relatively pronounced bimodal distribution of paralinguistic information across layers.
Two of the conformer models with full inputs show a bimodal distribution of paralinguistic information across layers.

We hope that smaller, faster models for paralinguistic speech unlock new applications in speech recognition, text-to-speech generation, and understanding user intent. We also expect that smaller models will be more easily interpretable, which will allow researchers to understand what aspects of speech are important for paralinguistics. Finally, we hope that our open-sourced speech representations are used by the community to improve paralinguistic speech tasks and user understanding in private or small datasets.

I'd like to thank my co-authors Aren Jansen, Wei Han, Daniel Park, Yu Zhang, and Subhashini Venugopalan for their hard work and creativity on this project. I'd also like to thank the members of the large collaboration for the BigSSL work, without which these projects would not be possible. The team includes James Qin, Anmol Gulati, Yuanzhong Xu, Yanping Huang, Shibo Wang, Zongwei Zhou, Bo Li, Min Ma, William Chan, Jiahui Yu, Yongqiang Wang, Liangliang Cao, Khe Chai Sim, Bhuvana Ramabhadran, Tara N. Sainath, Françoise Beaufays, Zhifeng Chen, Quoc V. Le, Chung-Cheng Chiu, Ruoming Pang, and Yonghui Wu.

Source: Google AI Blog

Federated Learning with Formal Differential Privacy Guarantees

In 2017, Google introduced federated learning (FL), an approach that enables mobile devices to collaboratively train machine learning (ML) models while keeping the raw training data on each user's device, decoupling the ability to do ML from the need to store the data in the cloud. Since its introduction, Google has continued to actively engage in FL research and deployed FL to power many features in Gboard, including next word prediction, emoji suggestion and out-of-vocabulary word discovery. Federated learning is improving the “Hey Google” detection models in Assistant, suggesting replies in Google Messages, predicting text selections, and more.

While FL allows ML without raw data collection, differential privacy (DP) provides a quantifiable measure of data anonymization, and when applied to ML can address concerns about models memorizing sensitive user data. This too has been a top research priority, and has yielded one of the first production uses of DP for analytics with RAPPOR in 2014, our open-source DP library, Pipeline DP, and TensorFlow Privacy.

Through a multi-year, multi-team effort spanning fundamental research and product integration, today we are excited to announce that we have deployed a production ML model using federated learning with a rigorous differential privacy guarantee. For this proof-of-concept deployment, we utilized the DP-FTRL algorithm to train a recurrent neural network to power next-word-prediction for Spanish-language Gboard users. To our knowledge, this is the first production neural network trained directly on user data announced with a formal DP guarantee (technically ρ=0.81 zero-Concentrated-Differential-Privacy, zCDP, discussed in detail below). Further, the federated approach offers complementary data minimization advantages, and the DP guarantee protects all of the data on each device, not just individual training examples.

Data Minimization and Anonymization in Federated Learning
Along with fundamentals like transparency and consent, the privacy principles of data minimization and anonymization are important in ML applications that involve sensitive data.

Federated learning systems structurally incorporate the principle of data minimization. FL only transmits minimal updates for a specific model training task (focused collection), limits access to data at all stages, processes individuals’ data as early as possible (early aggregation), and discards both collected and processed data as soon as possible (minimal retention).

Another principle that is important for models trained on user data is anonymization, meaning that the final model should not memorize information unique to a particular individual's data, e.g., phone numbers, addresses, credit card numbers. However, FL on its own does not directly tackle this problem.

The mathematical concept of DP allows one to formally quantify this principle of anonymization. Differentially private training algorithms add random noise during training to produce a probability distribution over output models, and ensure that this distribution doesn't change too much given a small change to the training data; ρ-zCDP quantifies how much the distribution could possibly change. We call this example-level DP when adding or removing a single training example changes the output distribution on models in a provably minimal way.

Showing that deep learning with example-level differential privacy was even possible in the simpler setting of centralized training was a major step forward in 2016. The DP-SGD algorithm achieved this, and the key was amplifying the privacy guarantee by leveraging the randomness in sampling training examples ("amplification-via-sampling").

However, when users can contribute multiple examples to the training dataset, example-level DP is not necessarily strong enough to ensure the users’ data isn't memorized. Instead, we have designed algorithms for user-level DP, which requires that the output distribution of models doesn't change much even if we add/remove all of the training examples from any one user (or all the examples from any one device in our application). Fortunately, because FL summarizes all of a user's training data as a single model update, federated algorithms are well-suited to offering user-level DP guarantees.

Both limiting the contributions from one user and adding noise can come at the expense of model accuracy, however, so maintaining model quality while also providing strong DP guarantees is a key research focus.
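The two ingredients named above, limiting each user's contribution and adding noise, can be pictured as a clip-then-noise aggregation step. The sketch below is a simplified illustration and not the production mechanism: the function names, the L2 clipping rule, and the way the noise scale is derived from a `noise_multiplier` are assumptions chosen for clarity.

```python
import math
import random


def clip_update(update, clip_norm):
    """Scale one user's update so its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(x * x for x in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [x * scale for x in update]


def noisy_average(user_updates, clip_norm, noise_multiplier, rng):
    """Clip each user's update, sum them, add Gaussian noise, then average."""
    dim = len(user_updates[0])
    total = [0.0] * dim
    for update in user_updates:
        clipped = clip_update(update, clip_norm)
        total = [t + c for t, c in zip(total, clipped)]
    # Noise is calibrated to the most any single user can change the sum.
    stddev = noise_multiplier * clip_norm
    noised = [t + rng.gauss(0.0, stddev) for t in total]
    return [x / len(user_updates) for x in noised]


# Example: two users, one with an unusually large update that gets clipped.
rng = random.Random(0)
average = noisy_average([[3.0, 4.0], [0.0, 0.5]], clip_norm=1.0,
                        noise_multiplier=1.0, rng=rng)
```

Clipping bounds how far any one user can move the sum, which is what lets a fixed amount of noise mask that user's presence. It is also where accuracy can be lost: large but informative updates are shrunk, and the noise does not shrink with them.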

The Challenging Path to Federated Learning with Differential Privacy
In 2018, we introduced the DP-FedAvg algorithm, which extended the DP-SGD approach to the federated setting with user-level DP guarantees, and in 2020 we deployed this algorithm to mobile devices for the first time. This approach ensures the training mechanism is not too sensitive to any one user's data, and empirical privacy auditing techniques rule out some forms of memorization.

The amplification-via-sampling argument is essential to providing a strong DP guarantee for DP-FedAvg. However, in a real-world cross-device FL system, ensuring that devices are subsampled precisely and uniformly at random from a large population would be complex and hard to verify. One challenge is that devices choose when to connect (or "check in") based on many external factors (e.g., requiring that the device is idle, on unmetered WiFi, and charging), and the number of available devices can vary substantially.

Achieving a formal privacy guarantee requires a protocol that does all of the following:

  • Makes progress on training even as the set of devices available varies significantly with time.
  • Maintains privacy guarantees even in the face of unexpected or arbitrary changes in device availability.
  • For efficiency, allows client devices to locally decide whether they will check in to the server in order to participate in training, independent of other devices.

Initial work on privacy amplification via random check-ins highlighted these challenges and introduced a feasible protocol, but it would have required complex changes to our production infrastructure to deploy. Further, as with the amplification-via-sampling analysis of DP-SGD, the privacy amplification possible with random check-ins depends on a large number of devices being available. For example, if only 1000 devices are available for training, and participation of at least 1000 devices is needed in each training step, that requires either 1) including all devices currently available and paying a large privacy cost since there is no randomness in the selection, or 2) pausing the protocol and not making progress until more devices are available.

Achieving Provable Differential Privacy for Federated Learning with DP-FTRL
To address this challenge, the DP-FTRL algorithm is built on two key observations: 1) the convergence of gradient-descent-style algorithms depends primarily not on the accuracy of individual gradients, but the accuracy of cumulative sums of gradients; and 2) we can provide accurate estimates of cumulative sums with a strong DP guarantee by utilizing negatively correlated noise, added by the aggregating server: essentially, adding noise to one gradient and subtracting that same noise from a later gradient. DP-FTRL accomplishes this efficiently using the Tree Aggregation algorithm [1, 2].
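The cumulative-sum idea can be sketched with the binary tree mechanism. What follows is a minimal, scalar Python illustration rather than the production DP-FTRL implementation: the class name `TreeAggregator`, its methods, and the noise scale are hypothetical, and real training aggregates clipped model-update vectors instead of single numbers.

```python
import random


class TreeAggregator:
    """Noisy running sums via the binary tree mechanism (scalar sketch)."""

    def __init__(self, noise_stddev, seed=0):
        self._rng = random.Random(seed)
        self._stddev = noise_stddev
        self._node_noise = {}  # (level, index) -> noise, drawn once per node
        self._true_sum = 0.0
        self._step = 0

    def _noise_for(self, level, index):
        key = (level, index)
        if key not in self._node_noise:
            self._node_noise[key] = self._rng.gauss(0.0, self._stddev)
        return self._node_noise[key]

    def nodes_covering_prefix(self, t):
        """Tree nodes whose leaves exactly tile steps 1..t (one per 1-bit of t)."""
        nodes, start = [], 0
        for level in reversed(range(t.bit_length())):
            size = 1 << level
            if t & size:
                nodes.append((level, start // size))
                start += size
        return nodes

    def add(self, value):
        """Add one step's value and return the noisy cumulative sum so far."""
        self._step += 1
        self._true_sum += value
        noise = sum(self._noise_for(level, index)
                    for level, index in self.nodes_covering_prefix(self._step))
        return self._true_sum + noise


# Example: noisy running sums of five unit "gradients".
aggregator = TreeAggregator(noise_stddev=0.5, seed=42)
noisy_sums = [aggregator.add(1.0) for _ in range(5)]
```

Because each tree node's noise is drawn once and reused, the noise in consecutive running sums is correlated: noise that enters at one step can drop out at a later step when a larger node replaces several smaller ones. The running sum at step t carries noise from at most about log2(t) + 1 nodes, rather than one fresh noise term per step.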

The graphic below illustrates how estimating cumulative sums rather than individual gradients can help. We look at how the noise introduced by DP-FTRL and DP-SGD influences model training, compared to the true gradients (without added noise; in black) which step one unit to the right on each iteration. The individual DP-FTRL gradient estimates (blue), based on cumulative sums, have larger mean-squared-error than the individually-noised DP-SGD estimates (orange), but because the DP-FTRL noise is negatively correlated, some of it cancels out from step to step, and the overall learning trajectory stays closer to the true gradient descent steps.

To provide a strong privacy guarantee, we limit the number of times a user contributes an update. Fortunately, sampling-without-replacement is relatively easy to implement in production FL infrastructure: each device can remember locally which models it has contributed to in the past, and choose to not connect to the server for any later rounds for those models.
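The device-side bookkeeping described here can be very small. The following Python sketch is purely illustrative: `ParticipationLog`, its method names, and the string model identifiers are hypothetical and do not correspond to any real Gboard or TensorFlow Federated API.

```python
class ParticipationLog:
    """Device-local record of which models this device has contributed to."""

    def __init__(self, max_contributions=1):
        self._max = max_contributions
        self._counts = {}  # model identifier -> contributions made so far

    def may_check_in(self, model_id):
        """True if the device is still allowed to contribute to this model."""
        return self._counts.get(model_id, 0) < self._max

    def record_contribution(self, model_id):
        """Remember a contribution; refuse if the limit is already reached."""
        if not self.may_check_in(model_id):
            raise RuntimeError("contribution limit reached for " + model_id)
        self._counts[model_id] = self._counts.get(model_id, 0) + 1


# Example: after one contribution the device stops checking in for that model.
log = ParticipationLog(max_contributions=1)
log.record_contribution("nwp-model")
still_eligible = log.may_check_in("nwp-model")  # False
```

The decision is made entirely on the device from local state, so the server does not need to track which devices have already participated.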

Production Training Details and Formal DP Statements
For the production DP-FTRL deployment introduced above, each eligible device maintains a local training cache consisting of user keyboard input, and when participating computes an update to the model which makes it more likely to suggest the next word the user actually typed, based on what has been typed so far. We ran DP-FTRL on this data to train a recurrent neural network with ~1.3M parameters. Training ran for 2000 rounds over six days, with 6500 devices participating per round. To allow for the DP guarantee, devices participated in training at most once every 24 hours. Model quality improved over the previous DP-FedAvg trained model, which offered empirically-tested privacy advantages over non-DP models, but lacked a meaningful formal DP guarantee.

The training mechanism we used is available in open-source in TensorFlow Federated and TensorFlow Privacy, and with the parameters used in our production deployment it provides a meaningfully strong privacy guarantee. Our analysis gives ρ=0.81 zCDP at the user level (treating all the data on each device as a different user), where smaller numbers correspond to better privacy in a mathematically precise way. As a comparison, this is stronger than the ρ=2.63 zCDP guarantee chosen by the 2020 US Census.
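For readers more used to (ε, δ)-DP, a ρ-zCDP guarantee can be translated with the standard bound ε = ρ + 2·sqrt(ρ·ln(1/δ)). The Python sketch below applies that bound; the function name is hypothetical, and the value of δ is an illustrative assumption, since no δ is stated for this deployment. Tighter conversions exist, so this should be read as an upper bound rather than the exact accounting.

```python
import math


def zcdp_to_epsilon(rho, delta):
    """Upper bound on epsilon for a rho-zCDP mechanism at a given delta.

    Uses the standard conversion: rho-zCDP implies
    (rho + 2 * sqrt(rho * ln(1 / delta)), delta)-DP.
    """
    if rho < 0:
        raise ValueError("rho must be non-negative")
    if not 0.0 < delta < 1.0:
        raise ValueError("delta must be in (0, 1)")
    return rho + 2.0 * math.sqrt(rho * math.log(1.0 / delta))


# delta = 1e-10 is an assumed, illustrative value.
assumed_delta = 1e-10
epsilon_gboard = zcdp_to_epsilon(0.81, assumed_delta)
epsilon_census = zcdp_to_epsilon(2.63, assumed_delta)
print(epsilon_gboard, epsilon_census)
```

Under this bound and the assumed δ, ρ=0.81 corresponds to an ε of roughly 9.4, and the smaller ρ always yields the smaller ε, which matches the ordering of the two guarantees compared above.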

Next Steps
While we have reached the milestone of deploying a production FL model using a mechanism that provides a meaningfully small zCDP, our research journey continues. We are still far from being able to say this approach is possible (let alone practical) for most ML models or product applications, and other approaches to private ML exist. For example, membership inference tests and other empirical privacy auditing techniques can provide complementary safeguards against leakage of users’ data. Most importantly, we see training models with user-level DP with even a very large zCDP as a substantial step forward, because it requires training with a DP mechanism that bounds the sensitivity of the model to any one user's data. Further, it smooths the road to later training models with improved privacy guarantees as better algorithms or more data become available. We are excited to continue the journey toward maximizing the value that ML can deliver while minimizing potential privacy costs to those who contribute training data.

The authors would like to thank Alex Ingerman and Om Thakkar for significant impact on the blog post itself, as well as the teams at Google that helped develop these ideas and bring them to practice:

  • Core research team: Galen Andrew, Borja Balle, Peter Kairouz, Daniel Ramage, Shuang Song, Thomas Steinke, Andreas Terzis, Om Thakkar, Zheng Xu
  • FL infrastructure team: Katharine Daly, Stefan Dierauf, Hubert Eichner, Igor Pisarev, Timon Van Overveldt, Chunxiang Zheng
  • Gboard team: Angana Ghosh, Xu Liu, Yuanbo Zhang
  • Speech team: Françoise Beaufays, Mingqing Chen, Rajiv Mathews, Vidush Mukund, Igor Pisarev, Swaroop Ramaswamy, Dan Zivkovic

Source: Google AI Blog

A Scalable Approach for Partially Local Federated Learning

Federated learning enables users to train a model without sending raw data to a central server, thus avoiding the collection of privacy-sensitive data. Often this is done by learning a single global model for all users, even though the users may differ in their data distributions. For example, users of a mobile keyboard application may collaborate to train a suggestion model but have different preferences for the suggestions. This heterogeneity has motivated algorithms that can personalize a global model for each user.

However, in some settings privacy considerations may prohibit learning a fully global model. Consider models with user-specific embeddings, such as matrix factorization models for recommender systems. Training a fully global federated model would involve sending user embedding updates to a central server, which could potentially reveal the preferences encoded in the embeddings. Even for models without user-specific embeddings, having some parameters be completely local to user devices would reduce server-client communication and responsibly personalize those parameters to each user.

Left: A matrix factorization model with a user matrix P and items matrix Q. The user embedding for a user u (P_u) and item embedding for item i (Q_i) are trained to predict the user’s rating for that item (R_ui). Right: Applying federated learning approaches to learn a global model can involve sending updates for P_u to a central server, potentially leaking individual user preferences.

In “Federated Reconstruction: Partially Local Federated Learning”, presented at NeurIPS 2021, we introduce an approach that enables scalable partially local federated learning, where some model parameters are never aggregated on the server. For matrix factorization, this approach trains a recommender model while keeping user embeddings local to each user device. For other models, this approach trains a portion of the model to be completely personal for each user while avoiding communication of these parameters. We successfully deployed partially local federated learning to Gboard, resulting in better recommendations for hundreds of millions of keyboard users. We’re also releasing a TensorFlow Federated tutorial demonstrating how to use Federated Reconstruction.

Federated Reconstruction
Previous approaches for partially local federated learning used stateful algorithms, which require user devices to store a state across rounds of federated training. Specifically, these approaches required devices to store local parameters across rounds. However, these algorithms tend to degrade in large-scale federated learning settings. In these cases, the majority of users do not participate in training, and users who do participate likely only do so once, resulting in a state that is rarely available and can get stale across rounds. Also, all users who do not participate are left without trained local parameters, preventing practical applications.

Federated Reconstruction is stateless and avoids the need for user devices to store local parameters by reconstructing them whenever needed. When a user participates in training, before updating any globally aggregated model parameters, they randomly initialize and train their local parameters using gradient descent on local data with global parameters frozen. They can then calculate updates to global parameters with local parameters frozen. A round of Federated Reconstruction training is depicted below.

Models are partitioned into global and local parameters. For each round of Federated Reconstruction training: (1) The server sends the current global parameters g to each user i; (2) Each user i freezes g and reconstructs their local parameters l_i; (3) Each user i freezes l_i and updates g to produce g_i; (4) Users’ g_i are averaged to produce the global parameters for the next round. Steps (2) and (3) generally use distinct parts of the local data.
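The four steps of a round can be sketched for a tiny matrix factorization model. This is a simplified Python illustration under stated assumptions, not the released TensorFlow Federated code: the function names, squared-error loss, learning rates, step counts, and the constant (rather than random) initialization of the user embedding are all choices made here to keep the example short and deterministic.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def reconstruct_local(global_items, support, dim, steps=50, lr=0.1):
    """Step (2): freeze item embeddings g and fit a fresh user embedding l_i."""
    user = [0.1] * dim  # constant init for determinism; an assumption
    for _ in range(steps):
        grad = [0.0] * dim
        for item, rating in support:
            err = dot(user, global_items[item]) - rating
            grad = [g + err * q for g, q in zip(grad, global_items[item])]
        user = [u - lr * g / len(support) for u, g in zip(user, grad)]
    return user


def update_global(global_items, user, query, lr=0.1):
    """Step (3): freeze l_i and take a gradient step on g to produce g_i."""
    updated = {item: list(vec) for item, vec in global_items.items()}
    for item, rating in query:
        err = dot(user, global_items[item]) - rating
        updated[item] = [q - lr * err * u for q, u in zip(updated[item], user)]
    return updated


def federated_reconstruction_round(global_items, clients, dim):
    """Steps (1)-(4); only item embeddings are returned to the server."""
    client_models = []
    for support, query in clients:  # step (1): each client receives g
        user = reconstruct_local(global_items, support, dim)
        client_models.append(update_global(global_items, user, query))
    return {  # step (4): average the clients' g_i
        item: [sum(m[item][k] for m in client_models) / len(client_models)
               for k in range(dim)]
        for item in global_items
    }


# Toy data: each client holds (support, query) lists of (item, rating) pairs.
items = {"a": [1.0, 0.0], "b": [0.0, 1.0]}
clients = [([("a", 2.0)], [("b", 1.0)])]
new_items = federated_reconstruction_round(items, clients, dim=2)
```

The user embedding exists only inside the per-client loop and is discarded afterwards, which mirrors the stateless design: nothing local is stored across rounds or sent to the server. Using separate support and query examples for steps (2) and (3) follows the note above that the two steps generally use distinct parts of the local data.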

This simple approach avoids the challenges of previous methods. It does not assume users have a state from previous rounds of training, enabling large-scale training, and local parameters are always freshly reconstructed, preventing staleness. Users unseen during training can still get trained models and perform inference by simply reconstructing local parameters using local data.

Federated Reconstruction trains better performing models for unseen users compared to other approaches. For a matrix factorization task with unseen users, the approach significantly outperforms both centralized training and baseline Federated Averaging.

Method                RMSE ↓  Accuracy ↑
Centralized           1.36    40.8%
FedAvg                0.934   40.0%
FedRecon (this work)  0.907   43.3%
Root-mean-square-error (lower is better) and accuracy for a matrix factorization task with unseen users. Centralized training and Federated Averaging (FedAvg) both reveal privacy-sensitive user embeddings to a central server, while Federated Reconstruction (FedRecon) avoids this.

These results can be explained via a connection to meta learning (i.e., learning to learn); Federated Reconstruction trains global parameters that lead to fast and accurate reconstruction of local parameters for unseen users. That is, Federated Reconstruction is learning to learn local parameters. In practice, we observe that just one gradient descent step can yield successful reconstruction, even for models with about one million local parameters.

Federated Reconstruction also provides a way to personalize models for heterogeneous users while reducing communication of model parameters — even for models without user-specific embeddings. To evaluate this, we apply Federated Reconstruction to personalize a next word prediction language model and observe a substantial increase in performance, attaining accuracy on par with other personalization methods despite reduced communication. Federated Reconstruction also outperforms other personalization methods when executed at a fixed communication level.

Method                Accuracy ↑  Communication ↓
FedYogi               24.3%       Whole Model
FedYogi + Finetuning  30.8%       Whole Model
FedRecon (this work)  30.7%       Partial Model
Accuracy and server-client communication for a next word prediction task without user-specific embeddings. FedYogi communicates all model parameters, while FedRecon avoids this.

Real-World Deployment in Gboard
To validate the practicality of Federated Reconstruction in large-scale settings, we deployed the algorithm to Gboard, a mobile keyboard application with hundreds of millions of users. Gboard users use expressions (e.g., GIFs, stickers) to communicate with others. Users have highly heterogeneous preferences for these expressions, making the setting a good fit for using matrix factorization to predict new expressions a user might want to share.

Gboard users can communicate with expressions, preferences for which are highly personal.

We trained a matrix factorization model over user-expression co-occurrences using Federated Reconstruction, keeping user embeddings local to each Gboard user. We then deployed the model to Gboard users, leading to a 29.3% increase in click-through-rate for expression recommendations. Since most Gboard users were unseen during federated training, Federated Reconstruction played a key role in this deployment.

Further Explorations
We’ve presented Federated Reconstruction, a method for partially local federated learning. Federated Reconstruction enables personalization to heterogeneous users while reducing communication of privacy-sensitive parameters. We scaled the approach to Gboard in alignment with our AI Principles, improving recommendations for hundreds of millions of users.

For a technical walkthrough of Federated Reconstruction for matrix factorization, check out the TensorFlow Federated tutorial. We’ve also released general-purpose TensorFlow Federated libraries and open-source code for running experiments.

Karan Singhal, Hakim Sidahmed, Zachary Garrett, Shanshan Wu, Keith Rush, and Sushant Prakash co-authored the paper. Thanks to Wei Li, Matt Newton, and Yang Lu for their partnership on Gboard deployment. We’d also like to thank Brendan McMahan, Lin Ning, Zachary Charles, Warren Morningstar, Daniel Ramage, Jakub Konecný, Alex Ingerman, Blaise Agüera y Arcas, Jay Yagnik, Bradley Green, and Ewa Dominowska for their helpful comments and support.

Source: Google AI Blog