Distilling step-by-step: Outperforming larger language models with less training data and smaller model sizes

Large language models (LLMs) have enabled a new data-efficient learning paradigm in which they can solve unseen tasks via zero-shot or few-shot prompting. However, LLMs are challenging to deploy in real-world applications because of their sheer size. For instance, serving a single 175-billion-parameter LLM requires at least 350GB of GPU memory on specialized infrastructure, and today's state-of-the-art LLMs contain over 500 billion parameters. Such computational requirements put LLMs out of reach for many research teams, especially for applications that require low latency.
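The 350GB figure can be reproduced with back-of-the-envelope arithmetic. The sketch below assumes 16-bit (fp16/bf16) weights at 2 bytes per parameter and counts only the weights, not activations, key-value caches, or framework overhead, so it gives a lower bound rather than a full serving-cost model. The function name `weight_memory_gb` is our own.

```python
# Lower-bound estimate of GPU memory needed just to store model weights.
# Assumption: 16-bit weights (fp16/bf16), i.e. 2 bytes per parameter.
# Activations, KV caches, and framework overhead are ignored.

BYTES_PER_PARAM_16BIT = 2


def weight_memory_gb(num_params: float, bytes_per_param: int = BYTES_PER_PARAM_16BIT) -> float:
    """Memory in GB (1 GB = 1e9 bytes) occupied by the weights alone."""
    return num_params * bytes_per_param / 1e9


for name, params in [("175B LLM", 175e9), ("540B PaLM", 540e9), ("770M T5", 770e6)]:
    print(f"{name}: {weight_memory_gb(params):,.1f} GB")
```

Under these assumptions the 175B model needs 350.0 GB for its weights alone, while a 770M T5 model needs about 1.5 GB.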

To circumvent these deployment challenges, practitioners often choose to deploy smaller specialized models instead. These smaller models are trained using one of two common paradigms: fine-tuning or distillation. Fine-tuning updates a pre-trained smaller model (e.g., BERT or T5) using downstream manually annotated data. Distillation trains the same smaller models with labels generated by a larger LLM. Unfortunately, to achieve performance comparable to LLMs, fine-tuning methods require human-generated labels, which are expensive and tedious to obtain, while distillation requires large amounts of unlabeled data, which can also be hard to collect.

In “Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes”, presented at ACL 2023, we set out to tackle this trade-off between model size and the cost of collecting training data. We introduce distilling step-by-step, a simple new mechanism for training smaller task-specific models that outperform few-shot prompted LLMs while using much less training data than standard fine-tuning or distillation. We demonstrate that distilling step-by-step enables a 770M-parameter T5 model to outperform the few-shot prompted 540B PaLM model using only 80% of the examples in a benchmark dataset: a more than 700x reduction in model size, achieved with much less training data than standard approaches require.

While LLMs offer strong zero- and few-shot performance, they are challenging to serve in practice. On the other hand, traditional ways of training small task-specific models require a large amount of training data. Distilling step-by-step provides a new paradigm that reduces both the deployed model size and the amount of data required for training.


Distilling step-by-step

The key idea of distilling step-by-step is to extract informative natural language rationales (i.e., intermediate reasoning steps) from LLMs, which can in turn be used to train small models in a more data-efficient way. Specifically, natural language rationales explain the connections between input questions and their corresponding outputs. For example, when asked, “Jesse's room is 11 feet long and 15 feet wide. If she already has 16 square feet of carpet, how much more carpet does she need to cover the whole floor?”, an LLM can be prompted with the few-shot chain-of-thought (CoT) prompting technique to provide intermediate rationales, such as “Area = length * width. Jesse’s room has 11 * 15 square feet.”, which better explain the connection from the input to the final answer, “(11 * 15) - 16”. These rationales can contain relevant task knowledge, such as “Area = length * width”, that small models might otherwise need a lot of data to learn. We use these extracted rationales as additional, richer supervision to train small models, alongside the standard task labels.

Overview of distilling step-by-step: First, we use CoT prompting to extract rationales from an LLM. We then use the generated rationales to train small task-specific models within a multi-task learning framework, in which we prepend task prefixes to the input examples and train the model to produce different outputs depending on the given task prefix.

Distilling step-by-step consists of two main stages. In the first stage, we leverage few-shot CoT prompting to extract rationales from LLMs. Specifically, given a task, we prepare few-shot exemplars in the LLM input prompt, where each exemplar is a triplet containing (1) an input, (2) a rationale, and (3) an output. Given the prompt, an LLM can mimic the triplet demonstrations to generate a rationale for any new input. For instance, in a commonsense question answering task, given the input question “Sammy wanted to go to where the people are. Where might he go? Answer Choices: (a) populated areas, (b) race track, (c) desert, (d) apartment, (e) roadblock”, the exemplar provides the correct answer, “(a) populated areas”, paired with a rationale that better connects the question to the answer: “The answer must be a place with a lot of people. Of the above choices, only populated areas have a lot of people.” Thanks to their in-context learning ability, LLMs prompted with such CoT exemplars can output corresponding rationales for unseen inputs.

We use few-shot CoT prompting, where each exemplar contains both a rationale (highlighted in green) and a label (highlighted in blue), to elicit rationales from an LLM on new input examples. The example is from a commonsense question answering task.
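A minimal sketch of the first stage, assuming a simple `Q:`/`A:` text template in which each demonstration ends with "So the answer is <label>." The template, the `Exemplar` class, and the helper names are hypothetical illustrations, not the prompts used in the paper, and the LLM call itself is omitted. The parser shows how a completion that follows the template could be split back into a rationale and a label.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Hypothetical marker separating the rationale from the label in a completion.
ANSWER_MARKER = "So the answer is"


@dataclass
class Exemplar:
    """One few-shot demonstration: (1) input, (2) rationale, (3) output."""
    question: str
    rationale: str
    answer: str


def format_exemplar(ex: Exemplar) -> str:
    """Render a triplet with an assumed template (the paper's wording may differ)."""
    return f"Q: {ex.question}\nA: {ex.rationale} {ANSWER_MARKER} {ex.answer}.\n"


def build_cot_prompt(exemplars: List[Exemplar], new_question: str) -> str:
    """Concatenate the demonstrations, then leave the answer open for the LLM."""
    demos = "\n".join(format_exemplar(ex) for ex in exemplars)
    return f"{demos}\nQ: {new_question}\nA:"


def split_rationale_and_label(completion: str) -> Tuple[str, str]:
    """Split a completion that follows the template into (rationale, label)."""
    rationale, sep, label = completion.rpartition(ANSWER_MARKER)
    if not sep:
        raise ValueError("completion does not follow the exemplar format")
    return rationale.strip(), label.strip().rstrip(".")


sammy = Exemplar(
    question="Sammy wanted to go to where the people are. Where might he go? "
    "Answer Choices: (a) populated areas, (b) race track, (c) desert, "
    "(d) apartment, (e) roadblock",
    rationale="The answer must be a place with a lot of people. "
    "Of the above choices, only populated areas have a lot of people.",
    answer="(a) populated areas",
)
print(build_cot_prompt([sammy], "<new unlabeled question>"))
```

Splitting on the last occurrence of the marker (`rpartition`) keeps the parser working even if the rationale happens to use the same phrase earlier.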

After the rationales are extracted, in the second stage, we incorporate the rationales in training small models by framing the training process as a multi-task problem. Specifically, we train the small model with a novel rationale generation task in addition to the standard label prediction task. The rationale generation task enables the model to learn to generate the intermediate reasoning steps for the prediction, and guides the model to better predict the resultant label. We prepend task prefixes (i.e., [label] and [rationale] for label prediction and rationale generation, respectively) to the input examples for the model to differentiate the two tasks.
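The second stage can be sketched as turning each (input, rationale, label) triple into two sequence-to-sequence examples that share the same input but carry different task prefixes. The `[label]` and `[rationale]` prefixes come from the text above; the exact spacing, the `Seq2SeqExample` type, and the weighted-sum loss with a `rationale_weight` hyperparameter are our assumptions. The per-task losses would come from a seq2seq model such as T5, which is not shown.

```python
from typing import List, NamedTuple

LABEL_PREFIX = "[label]"
RATIONALE_PREFIX = "[rationale]"


class Seq2SeqExample(NamedTuple):
    source: str  # model input, including the task prefix
    target: str  # expected model output for that task


def make_multitask_examples(text: str, label: str, rationale: str) -> List[Seq2SeqExample]:
    """Turn one (input, LLM rationale, label) triple into two prefixed examples."""
    return [
        Seq2SeqExample(f"{LABEL_PREFIX} {text}", label),          # label prediction task
        Seq2SeqExample(f"{RATIONALE_PREFIX} {text}", rationale),  # rationale generation task
    ]


def multitask_loss(label_loss: float, rationale_loss: float, rationale_weight: float = 1.0) -> float:
    """Combine the two task losses; the weighting scheme is an assumption."""
    return label_loss + rationale_weight * rationale_loss


for ex in make_multitask_examples(
    "Jesse's room is 11 feet long and 15 feet wide. ...",
    "(11 * 15) - 16",
    "Area = length * width. Jesse's room has 11 * 15 square feet.",
):
    print(ex)
```

Both examples would be mixed into the same training set, so a single small model learns to switch behavior based on the prefix.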


Experimental setup

In the experiments, we use a 540B PaLM model as the LLM and T5 models as the task-specific downstream models. For CoT prompting, we use the original CoT prompts when available and curate our own examples for new datasets. We conduct the experiments on four benchmark datasets across three NLP tasks: e-SNLI and ANLI for natural language inference, CQA for commonsense question answering, and SVAMP for arithmetic math word problems. We include two sets of baseline methods. For comparison to few-shot prompted LLMs, we compare to few-shot CoT prompting with a 540B PaLM model. For comparison to standard task-specific model training, the paper compares to both standard fine-tuning and standard distillation; in this blog post, we focus on the comparisons to standard fine-tuning for illustration purposes.


Less training data

Compared to standard fine-tuning, distilling step-by-step achieves better performance using much less training data. For instance, on the e-SNLI dataset, we achieve better performance than standard fine-tuning when using only 12.5% of the full dataset (shown in the upper left quadrant below). Similarly, we achieve dataset size reductions of 75%, 25%, and 20% on ANLI, CQA, and SVAMP, respectively.

Distilling step-by-step compared to standard fine-tuning using 220M T5 models on varying sizes of human-labeled datasets. On all datasets, distilling step-by-step outperforms standard fine-tuning trained on the full dataset while using far fewer training examples.


Smaller deployed model size

Compared to few-shot CoT prompted LLMs, distilling step-by-step achieves better performance using much smaller model sizes. For instance, on the e-SNLI dataset, we achieve better performance than 540B PaLM by using a 220M T5 model. On ANLI, we achieve better performance than 540B PaLM by using a 770M T5 model, which is over 700X smaller. Note that on ANLI, the same 770M T5 model struggles to match PaLM’s performance using standard fine-tuning.

We perform distilling step-by-step and standard fine-tuning on varying sizes of T5 models and compare their performance to LLM baselines, i.e., Few-shot CoT and PINTO Tuning. Distilling step-by-step outperforms the LLM baselines using much smaller models, e.g., over 700× smaller models on ANLI. Standard fine-tuning fails to match the LLMs’ performance using the same model size.
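The size reductions quoted above follow directly from the parameter counts. This short check divides the 540B PaLM parameter count by the 220M and 770M T5 sizes mentioned in the post; the helper name is hypothetical.

```python
# Parameter counts as stated in the post.
PALM_PARAMS = 540e9
T5_PARAMS = {"T5 (220M)": 220e6, "T5 (770M)": 770e6}


def size_reduction(large_params: float, small_params: float) -> float:
    """How many times fewer parameters the small model has than the large one."""
    return large_params / small_params


for name, params in T5_PARAMS.items():
    print(f"{name}: {size_reduction(PALM_PARAMS, params):.1f}x smaller than 540B PaLM")
```

The 770M model comes out at roughly 701x smaller, consistent with "over 700X"; the 220M model used on e-SNLI is roughly 2,455x smaller.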


Distilling step-by-step outperforms few-shot LLMs with smaller models using less data

Finally, we explore the smallest model sizes and the least amount of data for which distilling step-by-step outperforms PaLM’s few-shot performance. For instance, on ANLI, we surpass the performance of the 540B PaLM using a 770M T5 model, and this smaller model uses only 80% of the full dataset. Meanwhile, standard fine-tuning cannot catch up with PaLM’s performance even using 100% of the full dataset. This suggests that distilling step-by-step simultaneously reduces the model size and the amount of data required to outperform LLMs.

We show the minimum size of T5 models and the least amount of human-labeled examples required for distilling step-by-step to outperform the LLM’s few-shot CoT, found via a coarse-grained search. Distilling step-by-step outperforms few-shot CoT using not only much smaller models but also far fewer training examples than standard fine-tuning.
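The post does not spell out the coarse-grained search. One plausible reading, sketched below under that assumption, is a grid search that tries model sizes from smallest to largest and, for each size, training-set fractions from smallest to largest, stopping at the first configuration that beats the few-shot CoT baseline. `evaluate` is a hypothetical placeholder for training a T5 model with distilling step-by-step on a given data fraction and returning its accuracy.

```python
from typing import Callable, Optional, Sequence, Tuple


def smallest_winning_config(
    model_sizes: Sequence[str],
    data_fractions: Sequence[float],
    evaluate: Callable[[str, float], float],
    baseline_accuracy: float,
) -> Optional[Tuple[str, float]]:
    """Return the first (model size, data fraction) that beats the baseline.

    Both sequences are assumed to be sorted in ascending order, so the result
    prefers the smallest model and, for that model, the least training data.
    Returns None if no configuration in the grid beats the baseline.
    """
    for size in model_sizes:
        for fraction in data_fractions:
            if evaluate(size, fraction) > baseline_accuracy:
                return size, fraction
    return None
```

Because each grid point requires a full training run, a coarse grid keeps the search affordable at the cost of only bounding the true minimum.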

Conclusion

We propose distilling step-by-step, a novel mechanism that extracts rationales from LLMs as informative supervision for training small, task-specific models. We show that distilling step-by-step reduces both the amount of training data required to curate task-specific smaller models and the model size required to achieve, and even surpass, a few-shot prompted LLM’s performance. Overall, distilling step-by-step presents a resource-efficient paradigm that tackles the trade-off between model size and the training data required.


Availability on Google Cloud Platform

Distilling step-by-step is available for private preview on Vertex AI. If you are interested in trying it out, please contact [email protected] with your Google Cloud Project number and a summary of your use case.


Acknowledgements

This research was conducted by Cheng-Yu Hsieh, Chun-Liang Li, Chih-Kuan Yeh, Hootan Nakhost, Yasuhisa Fujii, Alexander Ratner, Ranjay Krishna, Chen-Yu Lee, and Tomas Pfister. Thanks to Xiang Zhang and Sergey Ioffe for their valuable feedback.

Source: Google AI Blog


Symbol tuning improves in-context learning in language models

A key feature of human intelligence is that humans can learn to perform new tasks by reasoning using only a few examples. Scaling up language models has unlocked a range of new applications and paradigms in machine learning, including the ability to perform challenging reasoning tasks via in-context learning. Language models, however, are still sensitive to the way that prompts are given, indicating that they are not reasoning in a robust manner. For instance, language models often require heavy prompt engineering or phrasing tasks as instructions, and they exhibit unexpected behaviors such as performance on tasks being unaffected even when shown incorrect labels.

In “Symbol tuning improves in-context learning in language models”, we propose a simple fine-tuning procedure that we call symbol tuning, which can improve in-context learning by emphasizing input–label mappings. We experiment with symbol tuning across Flan-PaLM models and observe benefits across various settings.

  • Symbol tuning boosts performance on unseen in-context learning tasks and is much more robust to underspecified prompts, such as those without instructions or without natural language labels.
  • Symbol-tuned models are much stronger at algorithmic reasoning tasks.
  • Finally, symbol-tuned models show large improvements in following flipped labels presented in-context, meaning that they are more capable of using in-context information to override prior knowledge.

An overview of symbol tuning, in which models are fine-tuned on tasks where natural language labels are replaced with arbitrary symbols. Symbol tuning relies on the intuition that when instructions and relevant labels are not available, models must use in-context examples to learn the task.

Motivation

Instruction tuning is a common fine-tuning method that has been shown to improve performance and allow models to better follow in-context examples. One shortcoming, however, is that models are not forced to learn to use the examples because the task is redundantly defined in the evaluation example via instructions and natural language labels. For example, on the left in the figure above, although the examples can help the model understand the task (sentiment analysis), they are not strictly necessary since the model could ignore the examples and just read the instruction that indicates what the task is.

In symbol tuning, the model is fine-tuned on examples where the instructions are removed and natural language labels are replaced with semantically unrelated labels (e.g., “Foo,” “Bar,” etc.). In this setup, the task is unclear without looking at the in-context examples. For example, on the right in the figure above, multiple in-context examples would be needed to figure out the task. Because symbol tuning teaches the model to reason over the in-context examples, symbol-tuned models should perform better on tasks that require reasoning between in-context examples and their labels.

Datasets and task types used for symbol tuning.

Symbol-tuning procedure

We selected 22 publicly available natural language processing (NLP) datasets for our symbol-tuning procedure. These tasks have been widely used in the past, and we chose only classification-type tasks since our method requires discrete labels. We then remap each label to a random label from a set of ~30K arbitrary labels drawn from three categories: integers, character combinations, and words.
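The label remapping can be sketched as below, assuming a classification dataset given as (text, label) pairs and a pool of arbitrary symbols mixing integers, character combinations, and words. The prompt omits any instruction and replaces each natural language label with a randomly sampled symbol, so only the in-context examples reveal the task. The `Input:`/`Output:` template, the function name, and the choice of one fresh mapping per prompt are our assumptions, not details from the paper.

```python
import random
from typing import Dict, Sequence, Tuple


def symbol_tune_prompt(
    examples: Sequence[Tuple[str, str]],
    query: str,
    symbol_pool: Sequence[str],
    rng: random.Random,
) -> Tuple[str, Dict[str, str]]:
    """Build an instruction-free prompt whose labels are arbitrary symbols."""
    labels = sorted({label for _, label in examples})
    # Sample without replacement so each original label gets a distinct symbol.
    symbols = rng.sample(list(symbol_pool), k=len(labels))
    mapping = dict(zip(labels, symbols))
    blocks = [f"Input: {text}\nOutput: {mapping[label]}" for text, label in examples]
    blocks.append(f"Input: {query}\nOutput:")  # the model must infer the answer symbol
    return "\n\n".join(blocks), mapping


demo = [("I loved this movie.", "positive"), ("Terrible plot.", "negative"), ("Great acting.", "positive")]
prompt, mapping = symbol_tune_prompt(demo, "What a waste of time.", ["Foo", "Bar", "7421", "xq"], random.Random(0))
print(mapping)
print(prompt)
```

During tuning, the training target for the query would be the symbol mapped from its true label, which forces the model to read the input–label pairs to succeed.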

For our experiments, we symbol tune Flan-PaLM, the instruction-tuned variants of PaLM. We use three sizes of Flan-PaLM models: Flan-PaLM-8B, Flan-PaLM-62B, and Flan-PaLM-540B. We also tested Flan-cont-PaLM-62B (Flan-PaLM-62B pre-trained on 1.3T tokens instead of 780B tokens), which we abbreviate as 62B-c.

We use a set of ∼300K arbitrary symbols from three categories (integers, character combinations, and words). ∼30K symbols are used during tuning and the rest are held out for evaluation.

Experimental setup

We want to evaluate a model’s ability to perform unseen tasks, so we cannot evaluate on tasks used in symbol tuning (22 datasets) or used during instruction tuning (1.8K tasks). Hence, we choose 11 NLP datasets that were not used during fine-tuning.


In-context learning

In the symbol-tuning procedure, models must learn to reason with in-context examples to perform tasks successfully, because prompts are modified so that tasks cannot simply be learned from relevant labels or instructions. Symbol-tuned models should therefore perform better in settings where tasks are unclear and require reasoning between in-context examples and their labels. To explore these settings, we define four in-context learning settings, based on the availability of instructions and relevant labels, that vary the amount of reasoning between inputs and labels required to learn the task.

Depending on the availability of instructions and relevant natural language labels, models may need to do varying amounts of reasoning with in-context examples. When these features are not available, models must reason with the given in-context examples to successfully perform the task.
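The four settings form a 2×2 grid: with or without an instruction, and with relevant natural language labels or with arbitrary symbols. The hypothetical helper below formats the same examples under each setting so the difference is easy to see; the prompt template and names are illustrative assumptions.

```python
from itertools import product
from typing import Dict, Iterator, Optional, Sequence, Tuple


def build_icl_prompt(
    examples: Sequence[Tuple[str, str]],
    query: str,
    instruction: Optional[str],
    symbol_map: Optional[Dict[str, str]],
) -> str:
    """Format one prompt.

    instruction=None drops the instruction; symbol_map=None keeps the relevant
    natural language labels, otherwise each label is shown as its symbol.
    """
    parts = [instruction] if instruction else []
    for text, label in examples:
        shown = symbol_map[label] if symbol_map else label
        parts.append(f"Input: {text}\nOutput: {shown}")
    parts.append(f"Input: {query}\nOutput:")
    return "\n\n".join(parts)


def four_settings(
    examples: Sequence[Tuple[str, str]],
    query: str,
    instruction: str,
    symbol_map: Dict[str, str],
) -> Iterator[Tuple[str, str]]:
    """Yield (setting name, prompt) for each instruction/label combination."""
    for use_instruction, use_relevant_labels in product([True, False], repeat=2):
        name = (
            ("with instruction" if use_instruction else "no instruction")
            + ", "
            + ("relevant labels" if use_relevant_labels else "arbitrary labels")
        )
        prompt = build_icl_prompt(
            examples,
            query,
            instruction if use_instruction else None,
            None if use_relevant_labels else symbol_map,
        )
        yield name, prompt
```

The "no instruction, arbitrary labels" prompt is the one in which the task can only be inferred from the in-context examples.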

Symbol tuning improves performance across all settings for models 62B and larger, with small improvements in settings with relevant natural language labels (+0.8% to +4.2%) and substantial improvements in settings without relevant natural language labels (+5.5% to +15.5%). Strikingly, when relevant labels are unavailable, symbol-tuned Flan-PaLM-8B outperforms Flan-PaLM-62B, and symbol-tuned Flan-PaLM-62B outperforms Flan-PaLM-540B. This performance difference suggests that symbol tuning can allow much smaller models to perform as well as large models on these tasks (effectively saving ∼10X inference compute).

Large-enough symbol-tuned models are better at in-context learning than baselines, especially in settings where relevant labels are not available. Performance is shown as average model accuracy (%) across eleven tasks.

Algorithmic reasoning

We also experiment on algorithmic reasoning tasks from BIG-Bench. There are two main groups of tasks: 1) list functions, which require identifying a transformation function (e.g., remove the last element in a list) between input and output lists containing non-negative integers; and 2) simple Turing concepts, which require reasoning with binary strings to learn the concept that maps an input to an output (e.g., swapping 0s and 1s in a string).
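To make the two task families concrete, the sketch below implements the two example rules mentioned in the text (removing the last list element, and swapping 0s and 1s) and formats a few input/output pairs as an in-context prompt in which the rule itself is never named. The prompt format and function names are our own; BIG-Bench's actual task formatting may differ.

```python
from typing import Any, Callable, List, Sequence


def remove_last(xs: List[int]) -> List[int]:
    """List-function example: drop the last element of the list."""
    return xs[:-1]


def swap_bits(s: str) -> str:
    """Simple Turing concept example: swap every 0 and 1 in a binary string."""
    return s.translate(str.maketrans("01", "10"))


def make_icl_prompt(rule: Callable[[Any], Any], inputs: Sequence[Any], query: Any) -> str:
    """Show input/output pairs produced by an unnamed rule, then ask about `query`."""
    demos = [f"Input: {x}\nOutput: {rule(x)}" for x in inputs]
    demos.append(f"Input: {query}\nOutput:")
    return "\n\n".join(demos)


print(make_icl_prompt(remove_last, [[3, 1, 4], [1, 5]], [9, 2, 6]))
print(make_icl_prompt(swap_bits, ["01", "110"], "0011"))
```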

On the list function and simple Turing concept tasks, symbol tuning results in average performance improvements of 18.2% and 15.3%, respectively. Additionally, Flan-cont-PaLM-62B with symbol tuning outperforms Flan-PaLM-540B on the list function tasks on average, which is equivalent to a ∼10x reduction in inference compute. These improvements suggest that symbol tuning strengthens the model’s ability to learn in-context for unseen task types, as symbol tuning did not include any algorithmic data.

Symbol-tuned models achieve higher performance on list function tasks and simple Turing concept tasks. (A–E): categories of list function tasks. (F): simple Turing concepts task.

Flipped labels

In the flipped-label experiment, labels of in-context and evaluation examples are flipped, meaning that prior knowledge and input-label mappings disagree (e.g., sentences containing positive sentiment labeled as “negative sentiment”), thereby allowing us to study whether models can override prior knowledge. Previous work has shown that while pre-trained models (without instruction tuning) can, to some extent, follow flipped labels presented in-context, instruction tuning degraded this ability.
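A minimal sketch of the flipped-label setup for a binary sentiment task, assuming (text, label) pairs. Both the in-context and the evaluation examples are flipped, and predictions are scored against the flipped labels, so a high score means the model followed the in-context mapping rather than its prior knowledge. The function names and the exact scoring helper are hypothetical.

```python
from typing import Dict, List, Tuple

# Binary sentiment labels and their flipped counterparts.
FLIP: Dict[str, str] = {
    "positive sentiment": "negative sentiment",
    "negative sentiment": "positive sentiment",
}


def flip_labels(examples: List[Tuple[str, str]], flip: Dict[str, str] = FLIP) -> List[Tuple[str, str]]:
    """Return copies of the examples with every label swapped via `flip`."""
    return [(text, flip[label]) for text, label in examples]


def flipped_label_accuracy(predictions: List[str], flipped_eval: List[Tuple[str, str]]) -> float:
    """Fraction of predictions that match the flipped (not the original) labels."""
    if len(predictions) != len(flipped_eval):
        raise ValueError("need one prediction per evaluation example")
    hits = sum(pred == label for pred, (_, label) in zip(predictions, flipped_eval))
    return hits / len(flipped_eval)
```

A model relying on prior knowledge would predict the original labels and score near zero here, which is what makes this metric a test of overriding priors.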

We see a similar trend across all model sizes: symbol-tuned models are much more capable of following flipped labels than instruction-tuned models. After symbol tuning, Flan-PaLM-8B sees an average improvement across all datasets of 26.5%, Flan-PaLM-62B sees an improvement of 33.7%, and Flan-PaLM-540B sees an improvement of 34.0%. Additionally, symbol-tuned models achieve average performance similar to or better than that of pre-training-only models.

Symbol-tuned models are much better at following flipped labels presented in-context than instruction-tuned models are.

Conclusion

We presented symbol tuning, a new method of tuning models on tasks where natural language labels are remapped to arbitrary symbols. Symbol tuning is based on the intuition that when models cannot use instructions or relevant labels to determine a presented task, they must do so by learning from in-context examples instead. We tuned four language models using our symbol-tuning procedure, utilizing a tuning mixture of 22 datasets and approximately 30K arbitrary symbols as labels.

We first showed that symbol tuning improves performance on unseen in-context learning tasks, especially when prompts do not contain instructions or relevant labels. We also found that symbol-tuned models were much better at algorithmic reasoning tasks, despite the lack of numerical or algorithmic data in the symbol-tuning procedure. Finally, in an in-context learning setting where inputs have flipped labels, symbol tuning (for some datasets) restores the ability to follow flipped labels that was lost during instruction tuning.


Future work

Through symbol tuning, we aim to increase the degree to which models can examine and learn from input–label mappings during in-context learning. We hope that our results encourage further work towards improving language models’ ability to reason over symbols presented in-context.


Acknowledgements

The authors of this post are now part of Google DeepMind. This work was conducted by Jerry Wei, Le Hou, Andrew Lampinen, Xiangning Chen, Da Huang, Yi Tay, Xinyun Chen, Yifeng Lu, Denny Zhou, Tengyu Ma, and Quoc V. Le. We would like to thank our colleagues at Google Research and Google DeepMind for their advice and helpful discussions.

Source: Google AI Blog


Symbol tuning improves in-context learning in language models

A key feature of human intelligence is that humans can learn to perform new tasks by reasoning using only a few examples. Scaling up language models has unlocked a range of new applications and paradigms in machine learning, including the ability to perform challenging reasoning tasks via in-context learning. Language models, however, are still sensitive to the way that prompts are given, indicating that they are not reasoning in a robust manner. For instance, language models often require heavy prompt engineering or phrasing tasks as instructions, and they exhibit unexpected behaviors such as performance on tasks being unaffected even when shown incorrect labels.

In “Symbol tuning improves in-context learning in language models”, we propose a simple fine-tuning procedure that we call symbol tuning, which can improve in-context learning by emphasizing input–label mappings. We experiment with symbol tuning across Flan-PaLM models and observe benefits across various settings.

  • Symbol tuning boosts performance on unseen in-context learning tasks and is much more robust to underspecified prompts, such as those without instructions or without natural language labels.
  • Symbol-tuned models are much stronger at algorithmic reasoning tasks.
  • Finally, symbol-tuned models show large improvements in following flipped-labels presented in-context, meaning that they are more capable of using in-context information to override prior knowledge.
An overview of symbol tuning, where models are fine-tuned on tasks where natural language labels are replaced with arbitrary symbols. Symbol tuning relies on the intuition that when instruction and relevant labels are not available, models must use in-context examples to learn the task.

Motivation

Instruction tuning is a common fine-tuning method that has been shown to improve performance and allow models to better follow in-context examples. One shortcoming, however, is that models are not forced to learn to use the examples because the task is redundantly defined in the evaluation example via instructions and natural language labels. For example, on the left in the figure above, although the examples can help the model understand the task (sentiment analysis), they are not strictly necessary since the model could ignore the examples and just read the instruction that indicates what the task is.

In symbol tuning, the model is fine-tuned on examples where the instructions are removed and natural language labels are replaced with semantically-unrelated labels (e.g., “Foo,” “Bar,” etc.). In this setup, the task is unclear without looking at the in-context examples. For example, on the right in the figure above, multiple in-context examples would be needed to figure out the task. Because symbol tuning teaches the model to reason over the in-context examples, symbol-tuned models should have better performance on tasks that require reasoning between in-context examples and their labels.

Datasets and task types used for symbol tuning.

Symbol-tuning procedure

We selected 22 publicly-available natural language processing (NLP) datasets that we use for our symbol-tuning procedure. These tasks have been widely used in the past, and we only chose classification-type tasks since our method requires discrete labels. We then remap labels to a random label from a set of ~30K arbitrary labels selected from one of three categories: integers, character combinations, and words.

For our experiments, we symbol tune Flan-PaLM, the instruction-tuned variants of PaLM. We use three different sizes of Flan-PaLM models: Flan-PaLM-8B, Flan-PaLM-62B, and Flan-PaLM-540B. We also tested Flan-cont-PaLM-62B (Flan-PaLM-62B at 1.3T tokens instead of 780B tokens), which we abbreviate as 62B-c.

We use a set of ∼300K arbitrary symbols from three categories (integers, character combinations, and words). ∼30K symbols are used during tuning and the rest are held out for evaluation.

Experimental setup

We want to evaluate a model’s ability to perform unseen tasks, so we cannot evaluate on tasks used in symbol tuning (22 datasets) or used during instruction tuning (1.8K tasks). Hence, we choose 11 NLP datasets that were not used during fine-tuning.


In-context learning

In the symbol-tuning procedure, models must learn to reason with in-context examples in order to successfully perform tasks because prompts are modified to ensure that tasks cannot simply be learned from relevant labels or instructions. Symbol-tuned models should perform better in settings where tasks are unclear and require reasoning between in-context examples and their labels. To explore these settings, we define four in-context learning settings that vary the amount of reasoning required between inputs and labels in order to learn the task (based on the availability of instructions/relevant labels)

Depending on the availability of instructions and relevant natural language labels, models may need to do varying amounts of reasoning with in-context examples. When these features are not available, models must reason with the given in-context examples to successfully perform the task.

Symbol tuning improves performance across all settings for models 62B and larger, with small improvements in settings with relevant natural language labels (+0.8% to +4.2%) and substantial improvements in settings without relevant natural language labels (+5.5% to +15.5%). Strikingly, when relevant labels are unavailable, symbol-tuned Flan-PaLM-8B outperforms FlanPaLM-62B, and symbol-tuned Flan-PaLM-62B outperforms Flan-PaLM-540B. This performance difference suggests that symbol tuning can allow much smaller models to perform as well as large models on these tasks (effectively saving ∼10X inference compute).

Large-enough symbol-tuned models are better at in-context learning than baselines, especially in settings where relevant labels are not available. Performance is shown as average model accuracy (%) across eleven tasks.

Algorithmic reasoning

We also experiment on algorithmic reasoning tasks from BIG-Bench. There are two main groups of tasks: 1) List functions — identify a transformation function (e.g., remove the last element in a list) between input and output lists containing non-negative integers; and 2) simple turing concepts — reason with binary strings to learn the concept that maps an input to an output (e.g., swapping 0s and 1s in a string).

On the list function and simple turing concept tasks, symbol tuning results in an average performance improvement of 18.2% and 15.3%, respectively. Additionally, Flan-cont-PaLM-62B with symbol tuning outperforms Flan-PaLM-540B on the list function tasks on average, which is equivalent to a ∼10x reduction in inference compute. These improvements suggest that symbol tuning strengthens the model’s ability to learn in-context for unseen task types, as symbol tuning did not include any algorithmic data.

Symbol-tuned models achieve higher performance on list function tasks and simple turing concept tasks. (A–E): categories of list functions tasks. (F): simple turing concepts task.

Flipped labels

In the flipped-label experiment, labels of in-context and evaluation examples are flipped, meaning that prior knowledge and input-label mappings disagree (e.g., sentences containing positive sentiment labeled as “negative sentiment”), thereby allowing us to study whether models can override prior knowledge. Previous work has shown that while pre-trained models (without instruction tuning) can, to some extent, follow flipped labels presented in-context, instruction tuning degraded this ability.

We see that there is a similar trend across all model sizes — symbol-tuned models are much more capable of following flipped labels than instruction-tuned models. We found that after symbol tuning, Flan-PaLM-8B sees an average improvement across all datasets of 26.5%, Flan-PaLM-62B sees an improvement of 33.7%, and Flan-PaLM-540B sees an improvement of 34.0%. Additionally, symbol-tuned models achieve similar or better than average performance as pre-training–only models.

Symbol-tuned models are much better at following flipped labels presented in-context than instruction-tuned models are.

Conclusion

We presented symbol tuning, a new method of tuning models on tasks where natural language labels are remapped to arbitrary symbols. Symbol tuning is based off of the intuition that when models cannot use instructions or relevant labels to determine a presented task, it must do so by instead learning from in-context examples. We tuned four language models using our symbol-tuning procedure, utilizing a tuning mixture of 22 datasets and approximately 30K arbitrary symbols as labels.

We first showed that symbol tuning improves performance on unseen in-context learning tasks, especially when prompts do not contain instructions or relevant labels. We also found that symbol-tuned models were much better at algorithmic reasoning tasks, despite the lack of numerical or algorithmic data in the symbol-tuning procedure. Finally, in an in-context learning setting where inputs have flipped labels, symbol tuning (for some datasets) restores the ability to follow flipped labels that was lost during instruction tuning.


Future work

Through symbol tuning, we aim to increase the degree to which models can examine and learn from input–label mappings during in-context learning. We hope that our results encourage further work towards improving language models’ ability to reason over symbols presented in-context.


Acknowledgements

The authors of this post are now part of Google DeepMind. This work was conducted by Jerry Wei, Le Hou, Andrew Lampinen, Xiangning Chen, Da Huang, Yi Tay, Xinyun Chen, Yifeng Lu, Denny Zhou, Tengyu Ma, and Quoc V. Le. We would like to thank our colleagues at Google Research and Google DeepMind for their advice and helpful discussions.

Source: Google AI Blog


Symbol tuning improves in-context learning in language models

A key feature of human intelligence is that humans can learn to perform new tasks by reasoning using only a few examples. Scaling up language models has unlocked a range of new applications and paradigms in machine learning, including the ability to perform challenging reasoning tasks via in-context learning. Language models, however, are still sensitive to the way that prompts are given, indicating that they are not reasoning in a robust manner. For instance, language models often require heavy prompt engineering or phrasing tasks as instructions, and they exhibit unexpected behaviors such as performance on tasks being unaffected even when shown incorrect labels.

In “Symbol tuning improves in-context learning in language models”, we propose a simple fine-tuning procedure that we call symbol tuning, which can improve in-context learning by emphasizing input–label mappings. We experiment with symbol tuning across Flan-PaLM models and observe benefits across various settings.

  • Symbol tuning boosts performance on unseen in-context learning tasks and is much more robust to underspecified prompts, such as those without instructions or without natural language labels.
  • Symbol-tuned models are much stronger at algorithmic reasoning tasks.
  • Finally, symbol-tuned models show large improvements in following flipped-labels presented in-context, meaning that they are more capable of using in-context information to override prior knowledge.
An overview of symbol tuning, where models are fine-tuned on tasks where natural language labels are replaced with arbitrary symbols. Symbol tuning relies on the intuition that when instructions and relevant labels are not available, models must use in-context examples to learn the task.

Motivation

Instruction tuning is a common fine-tuning method that has been shown to improve performance and allow models to better follow in-context examples. One shortcoming, however, is that models are not forced to learn to use the examples because the task is redundantly defined in the evaluation example via instructions and natural language labels. For example, on the left in the figure above, although the examples can help the model understand the task (sentiment analysis), they are not strictly necessary since the model could ignore the examples and just read the instruction that indicates what the task is.

In symbol tuning, the model is fine-tuned on examples where the instructions are removed and natural language labels are replaced with semantically-unrelated labels (e.g., “Foo,” “Bar,” etc.). In this setup, the task is unclear without looking at the in-context examples. For example, on the right in the figure above, multiple in-context examples would be needed to figure out the task. Because symbol tuning teaches the model to reason over the in-context examples, symbol-tuned models should have better performance on tasks that require reasoning between in-context examples and their labels.

Datasets and task types used for symbol tuning.

Symbol-tuning procedure

We selected 22 publicly-available natural language processing (NLP) datasets that we use for our symbol-tuning procedure. These tasks have been widely used in the past, and we only chose classification-type tasks since our method requires discrete labels. We then remap labels to a random label from a set of ~30K arbitrary labels selected from one of three categories: integers, character combinations, and words.

For our experiments, we symbol tune Flan-PaLM, the instruction-tuned variants of PaLM. We use three different sizes of Flan-PaLM models: Flan-PaLM-8B, Flan-PaLM-62B, and Flan-PaLM-540B. We also tested Flan-cont-PaLM-62B (Flan-PaLM-62B at 1.3T tokens instead of 780B tokens), which we abbreviate as 62B-c.

We use a set of ∼300K arbitrary symbols from three categories (integers, character combinations, and words). ∼30K symbols are used during tuning and the rest are held out for evaluation.
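As a minimal sketch of the remapping step, the snippet below splits a toy symbol pool into tuning and held-out symbols, then rewrites a two-example sentiment prompt with arbitrary labels and no instruction. The symbol list, the `Input:`/`Output:` prompt format, and the helper names are hypothetical illustrations, not the authors' pipeline.

```python
import random

# Hypothetical toy pool mixing the three categories: integers,
# character combinations, and words.
SYMBOL_POOL = ["7412", "91", "xkq", "plb", "Foo", "Bar", "zeta", "ochre"]


def split_symbols(pool, n_tune, rng):
    """Shuffle the pool and split it into tuning symbols and held-out symbols."""
    shuffled = list(pool)
    rng.shuffle(shuffled)
    return shuffled[:n_tune], shuffled[n_tune:]


def remap_labels(examples, symbols, rng):
    """Replace each natural language label with a randomly drawn arbitrary symbol.

    Every occurrence of the same original label gets the same symbol, so the
    input-label mapping can still be learned from the in-context examples.
    """
    labels = sorted({label for _, label in examples})
    mapping = dict(zip(labels, rng.sample(symbols, len(labels))))
    return [(text, mapping[label]) for text, label in examples], mapping


def build_prompt(examples, query):
    """Format the in-context examples (no task instruction), then the query."""
    blocks = [f"Input: {text}\nOutput: {label}" for text, label in examples]
    blocks.append(f"Input: {query}\nOutput:")
    return "\n\n".join(blocks)


rng = random.Random(0)
tune_symbols, heldout_symbols = split_symbols(SYMBOL_POOL, n_tune=4, rng=rng)
demos = [("A delightful film.", "positive"), ("Dull and far too long.", "negative")]
remapped, mapping = remap_labels(demos, tune_symbols, rng)
prompt = build_prompt(remapped, "I loved every minute.")
print(prompt)
```

Because the labels are drawn at random, the model cannot fall back on what "positive" or "negative" mean and has to infer the mapping from the examples themselves.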

Experimental setup

We want to evaluate a model’s ability to perform unseen tasks, so we cannot evaluate on tasks used in symbol tuning (22 datasets) or used during instruction tuning (1.8K tasks). Hence, we choose 11 NLP datasets that were not used during fine-tuning.


In-context learning

In the symbol-tuning procedure, models must learn to reason with in-context examples in order to successfully perform tasks because prompts are modified to ensure that tasks cannot simply be learned from relevant labels or instructions. Symbol-tuned models should perform better in settings where tasks are unclear and require reasoning between in-context examples and their labels. To explore these settings, we define four in-context learning settings that vary the amount of reasoning required between inputs and labels in order to learn the task (based on the availability of instructions/relevant labels).

Depending on the availability of instructions and relevant natural language labels, models may need to do varying amounts of reasoning with in-context examples. When these features are not available, models must reason with the given in-context examples to successfully perform the task.
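The four settings can be made concrete with a small prompt builder that toggles the instruction on or off and swaps relevant labels for unrelated ones. The instruction text, the label names "Foo"/"Bar", and the prompt format are hypothetical choices for illustration.

```python
from itertools import product

# Hypothetical instruction and label names for a binary sentiment task.
INSTRUCTION = "Classify the sentiment of the input."
RELEVANT_LABELS = {"pos": "positive", "neg": "negative"}
UNRELATED_LABELS = {"pos": "Foo", "neg": "Bar"}


def make_prompt(demos, query, with_instruction, relevant_labels):
    """Build one prompt for a given (instruction, label-type) setting."""
    names = RELEVANT_LABELS if relevant_labels else UNRELATED_LABELS
    blocks = [INSTRUCTION] if with_instruction else []
    blocks += [f"Input: {text}\nOutput: {names[y]}" for text, y in demos]
    blocks.append(f"Input: {query}\nOutput:")
    return "\n\n".join(blocks)


demos = [("Great acting.", "pos"), ("Boring plot.", "neg")]
settings = {
    (with_instruction, relevant): make_prompt(demos, "Loved it.", with_instruction, relevant)
    for with_instruction, relevant in product([True, False], repeat=2)
}

# The hardest setting: no instruction and no relevant labels, so the task
# can only be inferred from the in-context examples.
print(settings[(False, False)])
```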

Symbol tuning improves performance across all settings for models 62B and larger, with small improvements in settings with relevant natural language labels (+0.8% to +4.2%) and substantial improvements in settings without relevant natural language labels (+5.5% to +15.5%). Strikingly, when relevant labels are unavailable, symbol-tuned Flan-PaLM-8B outperforms Flan-PaLM-62B, and symbol-tuned Flan-PaLM-62B outperforms Flan-PaLM-540B. This performance difference suggests that symbol tuning can allow much smaller models to perform as well as large models on these tasks (effectively saving ∼10X inference compute).

Large-enough symbol-tuned models are better at in-context learning than baselines, especially in settings where relevant labels are not available. Performance is shown as average model accuracy (%) across eleven tasks.

Algorithmic reasoning

We also experiment on algorithmic reasoning tasks from BIG-Bench. There are two main groups of tasks: 1) list functions, which require identifying a transformation function (e.g., remove the last element in a list) between input and output lists containing non-negative integers; and 2) simple Turing concepts, which require reasoning with binary strings to learn the concept that maps an input to an output (e.g., swapping 0s and 1s in a string).
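The two example transformations mentioned above can be written as hidden functions that generate input-output pairs; the model only ever sees the pairs and must infer the function. The helper names and the printed format are hypothetical and do not reproduce BIG-Bench's actual prompts.

```python
import random


def remove_last(xs):
    """List function from the example above: drop the final element."""
    return xs[:-1]


def swap_bits(s):
    """Simple Turing concept from the example above: swap 0s and 1s."""
    return s.translate(str.maketrans("01", "10"))


def make_examples(fn, inputs):
    """Pair each input with the output of the hidden transformation."""
    return [(x, fn(x)) for x in inputs]


rng = random.Random(1)
lists = [[rng.randint(0, 9) for _ in range(rng.randint(2, 5))] for _ in range(3)]
for inp, out in make_examples(remove_last, lists):
    print(f"Input: {inp}  Output: {out}")
for inp, out in make_examples(swap_bits, ["0110", "111", "1000"]):
    print(f"Input: {inp}  Output: {out}")
```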

On the list function and simple Turing concept tasks, symbol tuning results in an average performance improvement of 18.2% and 15.3%, respectively. Additionally, Flan-cont-PaLM-62B with symbol tuning outperforms Flan-PaLM-540B on the list function tasks on average, which is equivalent to a ∼10x reduction in inference compute. These improvements suggest that symbol tuning strengthens the model’s ability to learn in-context for unseen task types, as symbol tuning did not include any algorithmic data.

Symbol-tuned models achieve higher performance on list function tasks and simple Turing concept tasks. (A–E): categories of list function tasks. (F): simple Turing concepts task.

Flipped labels

In the flipped-label experiment, labels of in-context and evaluation examples are flipped, meaning that prior knowledge and input-label mappings disagree (e.g., sentences containing positive sentiment labeled as “negative sentiment”), thereby allowing us to study whether models can override prior knowledge. Previous work has shown that while pre-trained models (without instruction tuning) can, to some extent, follow flipped labels presented in-context, instruction tuning degraded this ability.
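A minimal sketch of the flipped-label setup for a binary task, assuming hypothetical label names: every label in the prompt is swapped, and predictions are scored against the flipped labels, so a model only scores well if it follows the in-context mapping rather than its prior knowledge.

```python
# Label swap for a binary sentiment task (hypothetical label names).
FLIP = {"positive": "negative", "negative": "positive"}


def flip_labels(examples):
    """Swap every label so the in-context mapping contradicts prior knowledge."""
    return [(text, FLIP[label]) for text, label in examples]


def flipped_accuracy(predictions, original_labels):
    """Score predictions against the flipped labels: a model that follows the
    in-context mapping should answer "negative" for a positive sentence."""
    targets = [FLIP[y] for y in original_labels]
    return sum(p == t for p, t in zip(predictions, targets)) / len(targets)


demos = flip_labels([("A joy to watch.", "positive"), ("A waste of time.", "negative")])
print(demos)  # [('A joy to watch.', 'negative'), ('A waste of time.', 'positive')]
```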

We see a similar trend across all model sizes: symbol-tuned models are much more capable of following flipped labels than instruction-tuned models. After symbol tuning, Flan-PaLM-8B sees an average improvement across all datasets of 26.5%, Flan-PaLM-62B sees an improvement of 33.7%, and Flan-PaLM-540B sees an improvement of 34.0%. Additionally, symbol-tuned models achieve average performance similar to or better than that of pre-training–only models.

Symbol-tuned models are much better at following flipped labels presented in-context than instruction-tuned models are.

Conclusion

We presented symbol tuning, a new method of tuning models on tasks where natural language labels are remapped to arbitrary symbols. Symbol tuning is based on the intuition that when models cannot use instructions or relevant labels to determine a presented task, they must instead do so by learning from in-context examples. We tuned four language models using our symbol-tuning procedure, utilizing a tuning mixture of 22 datasets and approximately 30K arbitrary symbols as labels.

We first showed that symbol tuning improves performance on unseen in-context learning tasks, especially when prompts do not contain instructions or relevant labels. We also found that symbol-tuned models were much better at algorithmic reasoning tasks, despite the lack of numerical or algorithmic data in the symbol-tuning procedure. Finally, in an in-context learning setting where inputs have flipped labels, symbol tuning (for some datasets) restores the ability to follow flipped labels that was lost during instruction tuning.


Future work

Through symbol tuning, we aim to increase the degree to which models can examine and learn from input–label mappings during in-context learning. We hope that our results encourage further work towards improving language models’ ability to reason over symbols presented in-context.


Acknowledgements

The authors of this post are now part of Google DeepMind. This work was conducted by Jerry Wei, Le Hou, Andrew Lampinen, Xiangning Chen, Da Huang, Yi Tay, Xinyun Chen, Yifeng Lu, Denny Zhou, Tengyu Ma, and Quoc V. Le. We would like to thank our colleagues at Google Research and Google DeepMind for their advice and helpful discussions.

Source: Google AI Blog


Pic2Word: Mapping pictures to words for zero-shot composed image retrieval

Image retrieval plays a crucial role in search engines. Typically, users rely on either an image or text as a query to retrieve a desired target image. However, text-based retrieval has its limitations, as describing the target image accurately using words can be challenging. For instance, when searching for a fashion item, users may want an item whose specific attribute, e.g., the color of a logo or the logo itself, is different from what they find on a website. Yet searching for the item in an existing search engine is not trivial since precisely describing the fashion item by text can be challenging. To address this, composed image retrieval (CIR) retrieves images based on a query that combines both an image and a text sample that provides instructions on how to modify the image to fit the intended retrieval target. Thus, CIR allows precise retrieval of the target image by combining image and text.

However, CIR methods require large amounts of labeled data, i.e., triplets of a 1) query image, 2) description, and 3) target image. Collecting such labeled data is costly, and models trained on this data are often tailored to a specific use case, limiting their ability to generalize to different datasets.

To address these challenges, in “Pic2Word: Mapping Pictures to Words for Zero-shot Composed Image Retrieval”, we propose a task called zero-shot CIR (ZS-CIR). In ZS-CIR, we aim to build a single CIR model that performs a variety of CIR tasks, such as object composition, attribute editing, or domain conversion, without requiring labeled triplet data. Instead, we propose to train a retrieval model using large-scale image-caption pairs and unlabeled images, which are considerably easier to collect than supervised CIR datasets at scale. To encourage reproducibility and further advance this space, we also release the code.

Description of existing composed image retrieval model.
We train a composed image retrieval model using image-caption data only. Our model retrieves images aligned with the composition of the query image and text.

Method overview

We propose to leverage the language capabilities of the language encoder in the contrastive language-image pre-trained model (CLIP), which excels at generating semantically meaningful language embeddings for a wide range of textual concepts and attributes. To that end, we use a lightweight mapping sub-module in CLIP that is designed to map an input picture (e.g., a photo of a cat) from the image embedding space to a word token (e.g., “cat”) in the textual input space. The mapping network is optimized with the vision-language contrastive loss to ensure that the visual and text embedding spaces are as close as possible given a pair of an image and its textual description. The query image can then be treated as if it were a word, which enables the flexible and seamless composition of query image features and text descriptions by the language encoder. We call our method Pic2Word and provide an overview of its training process in the figure below. We want the mapped token s to represent the input image in the form of a word token. We then train the mapping network so that the language embedding p, computed from the mapped token, reconstructs the image embedding. Specifically, we optimize the contrastive loss proposed in CLIP, computed between the visual embedding v and the textual embedding p.

Training of the mapping network (fM) using unlabeled images only. We optimize only the mapping network with a frozen visual and text encoder.
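To make the training objective concrete, here is a minimal pure-Python sketch of a CLIP-style symmetric contrastive (InfoNCE) loss between visual embeddings v and textual embeddings p. The one-layer `mapping_network` (playing the role of fM), the stand-in `frozen_text_encoder` (which only normalizes its input instead of running CLIP's text encoder on a prompt containing the pseudo-token s), and the fixed temperature of 0.07 are assumptions for illustration. In the method described above, only the mapping network's parameters would be updated by gradient descent; this sketch only evaluates the loss.

```python
import math


def normalize(x):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(a * a for a in x))
    return [a / norm for a in x]


def mapping_network(W, v):
    """Hypothetical one-layer mapping: image embedding v -> pseudo-token s."""
    return [sum(w * a for w, a in zip(row, v)) for row in W]


def frozen_text_encoder(s):
    """Stand-in for CLIP's frozen text encoder applied to a prompt containing s.
    It only normalizes s so that the sketch runs without CLIP."""
    return normalize(s)


def log_softmax(row):
    m = max(row)
    log_sum = m + math.log(sum(math.exp(r - m) for r in row))
    return [r - log_sum for r in row]


def contrastive_loss(V, P, temperature=0.07):
    """Symmetric InfoNCE: V[i] and P[i] are a positive pair, all others negatives."""
    V = [normalize(v) for v in V]
    P = [normalize(p) for p in P]
    n = len(V)
    logits = [[sum(a * b for a, b in zip(v, p)) / temperature for p in P] for v in V]
    image_to_text = -sum(log_softmax(logits[i])[i] for i in range(n)) / n
    columns = [[logits[i][j] for i in range(n)] for j in range(n)]
    text_to_image = -sum(log_softmax(columns[j])[j] for j in range(n)) / n
    return (image_to_text + text_to_image) / 2


V = [[1.0, 0.0], [0.0, 1.0]]  # toy visual embeddings v
for name, W in [("identity", [[1.0, 0.0], [0.0, 1.0]]), ("swapped", [[0.0, 1.0], [1.0, 0.0]])]:
    P = [frozen_text_encoder(mapping_network(W, v)) for v in V]  # textual embeddings p
    print(name, round(contrastive_loss(V, P), 4))
```

A mapping that keeps each image aligned with its own pseudo-token (the identity case) yields a near-zero loss, while one that swaps them is heavily penalized, which is the signal that would drive training of the mapping network.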

Given the trained mapping network, we can regard an image as a word token and pair it with the text description to flexibly compose the joint image-text query as shown in the figure below.

With the trained mapping network, we regard the image as a word token and pair it with the text description to flexibly compose the joint image-text query.

Evaluation

We conduct experiments to evaluate Pic2Word’s performance on a variety of CIR tasks.


Domain conversion

We first evaluate the compositional capability of the proposed method on domain conversion: given an image and the desired new image domain (e.g., sculpture, origami, cartoon, toy), the output of the system should be an image with the same content but in the new desired image domain or style. As illustrated below, we evaluate the ability to compose the category information and domain description given as an image and text, respectively. We evaluate the conversion from real images to four domains using ImageNet and ImageNet-R.

To compare with approaches that do not require supervised training data, we pick three approaches: (i) image only performs retrieval only with visual embedding, (ii) text only employs only text embedding, and (iii) image + text averages the visual and text embedding to compose the query. The comparison with (iii) shows the importance of composing image and text using a language encoder. We also compare with Combiner, which trains the CIR model on Fashion-IQ or CIRR.
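The three training-free baselines differ only in how the query embedding is formed before nearest-neighbor retrieval by cosine similarity. A minimal sketch with made-up 3-dimensional embeddings follows; the toy axes are idealized, so this example says nothing about how the baselines compare on real CLIP embeddings, and all helper names are hypothetical.

```python
import math


def normalize(x):
    norm = math.sqrt(sum(a * a for a in x))
    return [a / norm for a in x]


def cosine(a, b):
    return sum(x * y for x, y in zip(normalize(a), normalize(b)))


def average_query(image_emb, text_emb):
    """Baseline (iii): average the L2-normalized image and text embeddings."""
    return [(x + y) / 2 for x, y in zip(normalize(image_emb), normalize(text_emb))]


def retrieve(query, gallery, k=10):
    """Indices of the k gallery embeddings most similar to the query."""
    ranked = sorted(range(len(gallery)), key=lambda i: cosine(query, gallery[i]), reverse=True)
    return ranked[:k]


# Toy 3-d embeddings with axes (cat content, origami style, other).
image_emb = [1.0, 0.0, 0.0]  # query photo of a cat
text_emb = [0.0, 1.0, 0.0]   # text "origami"
gallery = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.7, 0.7, 0.0], [0.0, 0.0, 1.0]]

print("image only:", retrieve(image_emb, gallery, k=1))
print("text only:", retrieve(text_emb, gallery, k=1))
print("image + text:", retrieve(average_query(image_emb, text_emb), gallery, k=1))
```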

We aim to convert the domain of the input query image into the one described with text, e.g., origami.

As shown in the figure below, our proposed approach outperforms baselines by a large margin.

Results on composed image retrieval for domain conversion (recall@10, i.e., the percentage of relevant instances in the first 10 images retrieved).
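The caption's definition of recall@10 can be read literally as the share of the top 10 results that are relevant; many CIR benchmarks (e.g., Fashion-IQ) instead count a query as a success if a target appears anywhere in the top K. A sketch of both readings, averaged over queries and reported as percentages, is below; which reading produced the reported numbers is not stated here, and the toy rankings are made up.

```python
def top_k_relevant_share(rankings, relevant_sets, k=10):
    """Literal reading of the caption: percentage of the top-k results that are
    relevant, averaged over queries."""
    shares = [len(set(ranked[:k]) & relevant) / k for ranked, relevant in zip(rankings, relevant_sets)]
    return 100 * sum(shares) / len(shares)


def hit_rate_at_k(rankings, relevant_sets, k=10):
    """Common CIR reading of Recall@K: percentage of queries with at least one
    relevant item among the top-k results."""
    hits = [bool(set(ranked[:k]) & relevant) for ranked, relevant in zip(rankings, relevant_sets)]
    return 100 * sum(hits) / len(hits)


rankings = [[1, 2, 3], [4, 5, 6]]  # retrieved gallery ids per query (toy data)
relevant_sets = [{1, 3}, {9}]      # relevant ids per query
print(top_k_relevant_share(rankings, relevant_sets, k=3), hit_rate_at_k(rankings, relevant_sets, k=3))
```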

Fashion attribute composition

Next, we evaluate the composition of fashion attributes, such as the color of the clothing, the logo, and the sleeve length, using the Fashion-IQ dataset. The figure below illustrates the desired output given the query.

Overview of CIR for fashion attributes.

In the figure below, we present a comparison with baselines, including supervised baselines that use triplets to train the CIR model: (i) CB uses the same architecture as our approach, and (ii) CIRPLANT, ARTEMIS, and MAAF use a smaller backbone, such as ResNet50. Comparing against these approaches shows how well our zero-shot approach performs on this task.

Although CB outperforms our approach, our method performs better than supervised baselines with smaller backbones. This result suggests that by utilizing a robust CLIP model, we can train a highly effective CIR model without requiring annotated triplets.

Results on composed image retrieval for the Fashion-IQ dataset (recall@10, i.e., the percentage of relevant instances in the first 10 images retrieved; higher is better). Light blue bars denote models trained using triplets. Note that our approach performs on par with these supervised baselines with shallow (smaller) backbones.

Qualitative results

We show several examples in the figure below. Compared to a baseline method that does not require supervised training data (text + image feature averaging), our approach does a better job of correctly retrieving the target image.

Qualitative results on diverse query images and text description.

Conclusion and future work

In this article, we introduce Pic2Word, a method for mapping pictures to words for ZS-CIR. We propose to convert the image into a word token to achieve a CIR model using only an image-caption dataset. Through a variety of experiments, we verify the effectiveness of the trained model on diverse CIR tasks, indicating that training on an image-caption dataset can build a powerful CIR model. One potential future research direction is utilizing caption data to train the mapping network, although we use only image data in the present work.


Acknowledgements

This research was conducted by Kuniaki Saito, Kihyuk Sohn, Xiang Zhang, Chun-Liang Li, Chen-Yu Lee, Kate Saenko, and Tomas Pfister. Also thanks to Zizhao Zhang and Sergey Ioffe for their valuable feedback.

Source: Google AI Blog


Unifying image-caption and image-classification datasets with prefix conditioning

Pre-training visual language (VL) models on web-scale image-caption datasets has recently emerged as a powerful alternative to traditional pre-training on image classification data. Image-caption datasets are considered to be more “open-domain” because they contain broader scene types and vocabulary words, which result in models with strong performance in few- and zero-shot recognition tasks. However, images with fine-grained class descriptions can be rare, and the class distribution can be imbalanced since image-caption datasets do not go through manual curation. By contrast, large-scale classification datasets, such as ImageNet, are often curated and can thus provide fine-grained categories with a balanced label distribution. While it may sound promising, directly combining caption and classification datasets for pre-training is often unsuccessful as it can result in biased representations that do not generalize well to various downstream tasks.

In “Prefix Conditioning Unifies Language and Label Supervision”, presented at CVPR 2023, we demonstrate a pre-training strategy that uses both classification and caption datasets to provide complementary benefits. First, we show that naïvely unifying the datasets results in sub-optimal performance on downstream zero-shot recognition tasks as the model is affected by dataset bias: the coverage of image domains and vocabulary words is different in each dataset. We address this problem during training through prefix conditioning, a novel simple and effective method that uses prefix tokens to disentangle dataset biases from visual concepts. This approach allows the language encoder to learn from both datasets while also tailoring feature extraction to each dataset. Prefix conditioning is a generic method that can be easily integrated into existing VL pre-training objectives, such as Contrastive Language-Image Pre-training (CLIP) or Unified Contrastive Learning (UniCL).


High-level idea

We note that classification datasets tend to be biased in at least two ways: (1) the images mostly contain single objects from restricted domains, and (2) the vocabulary is limited and lacks the linguistic flexibility required for zero-shot learning. For example, the class embedding of “a photo of a dog” optimized for ImageNet usually matches a single dog centered in the image, as is typical of ImageNet, and does not generalize well to other datasets containing images of multiple dogs in different spatial locations or a dog alongside other subjects.

By contrast, caption datasets contain a wider variety of scene types and vocabularies. As shown below, if a model simply learns from both datasets, the language embedding can entangle the biases of the image classification and caption datasets, which can reduce generalization in zero-shot classification. If we can disentangle the biases of the two datasets, we can use language embeddings tailored to the caption dataset to improve generalization.

Top: Language embedding entangling the bias from the image classification and caption datasets. Bottom: Language embeddings disentangling the bias of the two datasets.


Prefix conditioning

Prefix conditioning is partially inspired by prompt tuning, which prepends learnable tokens to the input token sequences to instruct a pre-trained model backbone to learn task-specific knowledge that can be used to solve downstream tasks. The prefix conditioning approach differs from prompt tuning in two ways: (1) it is designed to unify image-caption and classification datasets by disentangling the dataset bias, and (2) it is applied to VL pre-training while the standard prompt tuning is used to fine-tune models. Prefix conditioning is an explicit way to specifically steer the behavior of model backbones based on the type of datasets provided by users. This is especially helpful in production when the number of different types of datasets is known ahead of time.

During training, prefix conditioning learns a text token (prefix token) for each dataset type, which absorbs the bias of the dataset and allows the remaining text tokens to focus on learning visual concepts. Specifically, it prepends prefix tokens for each dataset type to the input tokens that inform the language and visual encoder of the input data type (e.g., classification vs. caption). Prefix tokens are trained to learn the dataset-type-specific bias, which enables us to disentangle that bias in language representations and utilize the embedding learned on the image-caption dataset during test time, even without an input caption.

We utilize prefix conditioning for CLIP using a language and visual encoder. During test time, we employ the prefix used for the image-caption dataset since the dataset is supposed to cover broader scene types and vocabulary words, leading to better performance in zero-shot recognition.

Illustration of the Prefix Conditioning.
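A minimal sketch of the input-side mechanics, assuming one prefix embedding per dataset type prepended to the text token embeddings, and the template "a photo of a {label}" for classification examples. The toy embedding table, the dimension, and the helper names are hypothetical; in the actual method the prefix vectors are learned jointly with the encoders through the contrastive loss, which this sketch omits.

```python
import random

EMBED_DIM = 4
_rng = random.Random(0)

# One prefix embedding per dataset type. Randomly initialized here; during
# pre-training these vectors would be learned with the rest of the model.
PREFIX = {
    "classification": [_rng.gauss(0.0, 0.02) for _ in range(EMBED_DIM)],
    "caption": [_rng.gauss(0.0, 0.02) for _ in range(EMBED_DIM)],
}

_token_table = {}


def embed(words):
    """Toy token-embedding lookup standing in for the text encoder's embedding table."""
    vectors = []
    for word in words:
        if word not in _token_table:
            _token_table[word] = [_rng.gauss(0.0, 1.0) for _ in range(EMBED_DIM)]
        vectors.append(_token_table[word])
    return vectors


def text_input(example):
    """Prepend the prefix for the example's dataset type to its token embeddings."""
    if example["source"] == "classification":
        words = f"a photo of a {example['label']}".split()
    else:
        words = example["caption"].split()
    return [PREFIX[example["source"]]] + embed(words)


def zero_shot_input(class_name):
    """At test time, every class prompt uses the caption prefix."""
    return [PREFIX["caption"]] + embed(f"a photo of a {class_name}".split())


batch = [
    {"source": "classification", "label": "dog"},
    {"source": "caption", "caption": "two dogs playing on a beach"},
]
inputs = [text_input(example) for example in batch]
print([len(x) for x in inputs])  # [6, 7]
```

The only difference between training on classification data and zero-shot testing is which prefix sits in front of the otherwise identical prompt tokens.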


Experimental results

We apply prefix conditioning to two types of contrastive loss, CLIP and UniCL, and evaluate their performance on zero-shot recognition tasks compared to models trained with ImageNet21K (IN21K) and Conceptual 12M (CC12M). CLIP and UniCL models trained with two datasets using prefix conditioning show large improvements in zero-shot classification accuracy.

Zero-shot classification accuracy of models trained with only IN21K or CC12M compared to CLIP and UniCL models trained with both datasets using prefix conditioning (“Ours”).


Study on test-time prefix

The table below describes the performance change by the prefix used during test time. We demonstrate that by using the same prefix used for the classification dataset (“Prompt”), the performance on the classification dataset (IN-1K) improves. When using the same prefix used for the image-caption dataset (“Caption”), the performance on other datasets (Zero-shot AVG) improves. This analysis illustrates that if the prefix is tailored for the image-caption dataset, it achieves better generalization of scene types and vocabulary words.

Analysis of the prefix used for test-time.


Study on robustness to image distribution shift

We study the shift in image distribution using ImageNet variants. We see that the “Caption” prefix performs better than “Prompt” in ImageNet-R (IN-R) and ImageNet-Sketch (IN-S), but underperforms in ImageNet-V2 (IN-V2). This indicates that the “Caption” prefix achieves generalization on domains far from the classification dataset. Therefore, the optimal prefix probably differs by how far the test domain is from the classification dataset.

Analysis on the robustness to image-level distribution shift. IN: ImageNet, IN-V2: ImageNet-V2, IN-R: Art, Cartoon style ImageNet, IN-S: ImageNet Sketch.


Conclusion and future work

We introduce prefix conditioning, a technique for unifying image-caption and classification datasets for better zero-shot classification. We show that this approach leads to better zero-shot classification accuracy and that the prefix can control the bias in the language embedding. One limitation is that the prefix learned on the caption dataset is not necessarily optimal for zero-shot classification. Identifying the optimal prefix for each test dataset is an interesting direction for future work.


Acknowledgements

This research was conducted by Kuniaki Saito, Kihyuk Sohn, Xiang Zhang, Chun-Liang Li, Chen-Yu Lee, Kate Saenko, and Tomas Pfister. Thanks to Zizhao Zhang and Sergey Ioffe for their valuable feedback.

Source: Google AI Blog


Foundation models for reasoning on charts

Visual language is the form of communication that relies on pictorial symbols outside of text to convey information. It is ubiquitous in our digital life in the form of iconography, infographics, tables, plots, and charts, extending to the real world in street signs, comic books, food labels, etc. For that reason, having computers better understand this type of media can help with scientific communication and discovery, accessibility, and data transparency.

While computer vision models have made tremendous progress using learning-based solutions since the advent of ImageNet, the focus has been on natural images, where all sorts of tasks, such as classification, visual question answering (VQA), captioning, detection and segmentation, have been defined, studied and in some cases advanced to reach human performance. However, visual language has not garnered a similar level of attention, possibly because of the lack of large-scale training sets in this space. But over the last few years, new academic datasets have been created with the goal of evaluating question answering systems on visual language images, like PlotQA, InfographicsVQA, and ChartQA.

Example from ChartQA. Answering the question requires reading the information and computing the sum and the difference.

Existing models built for these tasks relied on integrating optical character recognition (OCR) information and their coordinates into larger pipelines, but the process is error-prone, slow, and generalizes poorly. These methods prevailed because existing end-to-end computer vision models based on convolutional neural networks (CNNs) or transformers pre-trained on natural images could not be easily adapted to visual language. Such models are ill-prepared for the challenges in answering questions on charts, including reading the relative height of bars or the angle of slices in pie charts, understanding axis scales, correctly mapping pictograms to their legend values by color, size, and texture, and finally performing numerical operations with the extracted numbers.

In light of these challenges, we propose “MatCha: Enhancing Visual Language Pretraining with Math Reasoning and Chart Derendering”. MatCha, which stands for math and charts, is a pixels-to-text foundation model (a pre-trained model with built-in inductive biases that can be fine-tuned for multiple applications) trained on two complementary tasks: (a) chart de-rendering and (b) math reasoning. In chart de-rendering, given a plot or chart, the image-to-text model is required to generate its underlying data table or the code used to render it. For math reasoning pre-training, we pick textual numerical reasoning datasets and render the input into images, which the image-to-text model needs to decode for answers. We also propose “DePlot: One-shot visual language reasoning by plot-to-table translation”, a model built on top of MatCha for one-shot reasoning on charts via translation to tables. With these methods we surpass the previous state of the art in ChartQA by more than 20% and match the best summarization systems that have 1000 times more parameters. Both papers will be presented at ACL2023.


Chart de-rendering

Plots and charts are usually generated by an underlying data table and a piece of code. The code defines the overall layout of the figure (e.g., type, direction, color/shape scheme) and the underlying data table establishes the actual numbers and their groupings. Both the data and code are sent to a compiler/rendering engine to create the final image. To understand a chart, one needs to discover the visual patterns in the image and effectively parse and group them to extract the key information. Reversing the plot rendering process demands all such capabilities and can thus serve as an ideal pre-training task.

A chart created from a table in the Airbus A380 Wikipedia page using random plotting options. The pre-training task for MatCha consists of recovering the source table or the source code from the image.
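For the table variant of de-rendering, the target the model must generate from the chart image is the source table flattened into text. A minimal sketch of such a serialization and its inverse follows; the separator tokens, the helper names, and the toy table are hypothetical, and the exact format MatCha uses is not specified here.

```python
def linearize_table(header, rows, row_sep=" <newline> ", col_sep=" | "):
    """Flatten a data table into a single target string for an image-to-text model."""
    lines = [col_sep.join(header)] + [col_sep.join(str(cell) for cell in row) for row in rows]
    return row_sep.join(lines)


def parse_table(text, row_sep=" <newline> ", col_sep=" | "):
    """Inverse of linearize_table, e.g. to compare a predicted table with the source."""
    lines = text.split(row_sep)
    return lines[0].split(col_sep), [line.split(col_sep) for line in lines[1:]]


target = linearize_table(["Category", "Value"], [["A", 12], ["B", 7]])
print(target)  # Category | Value <newline> A | 12 <newline> B | 7
```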

In practice, it is challenging to simultaneously obtain charts, their underlying data tables, and their rendering code. To collect sufficient pre-training data, we independently accumulate [chart, code] and [chart, table] pairs. For [chart, code], we crawl all GitHub IPython notebooks with appropriate licenses and extract blocks with figures. A figure and the code block right before it are saved as a [chart, code] pair. For [chart, table] pairs, we explored two sources. For the first source, synthetic data, we manually write code to convert web-crawled Wikipedia tables from the TaPas codebase to charts. We sampled from and combined several plotting options depending on the column types. In addition, we also add [chart, table] pairs generated in PlotQA to diversify the pre-training corpus. The second source is web-crawled [chart, table] pairs. We directly use the [chart, table] pairs crawled in the ChartQA training set, containing around 20k pairs in total from four websites: Statista, Pew, Our World in Data, and OECD.
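In a notebook file, a rendered figure is stored as an output of the code cell that produced it, so one reading of "a figure and the code block right before it" is to pair each PNG output with that cell's source. The sketch below assumes the Jupyter nbformat v4 JSON layout; license filtering, deduplication, and image decoding are omitted, and the function name is hypothetical.

```python
import json


def chart_code_pairs(notebook_json):
    """Return (png_payload, code) pairs from a notebook serialized as JSON text.

    Assumes the nbformat v4 layout: a top-level "cells" list whose code cells
    carry "source" (a string or a list of strings) and "outputs", with rendered
    figures stored under output["data"]["image/png"].
    """
    notebook = json.loads(notebook_json)
    pairs = []
    for cell in notebook.get("cells", []):
        if cell.get("cell_type") != "code":
            continue
        source = cell.get("source", "")
        if isinstance(source, list):
            source = "".join(source)
        for output in cell.get("outputs", []):
            png = output.get("data", {}).get("image/png")
            if png is not None:
                pairs.append((png, source))
    return pairs


example = {
    "cells": [
        {"cell_type": "markdown", "source": ["# Sales"]},
        {
            "cell_type": "code",
            "source": ["import matplotlib.pyplot as plt\n", "plt.bar(['A', 'B'], [3, 5])"],
            "outputs": [{"output_type": "display_data", "data": {"image/png": "<base64 PNG>"}}],
        },
        {"cell_type": "code", "source": "total = 8", "outputs": []},
    ]
}
print(chart_code_pairs(json.dumps(example)))
```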


Math reasoning

We incorporate numerical reasoning knowledge into MatCha by learning math reasoning skills from textual math datasets. We use two existing textual math reasoning datasets, MATH and DROP, for pre-training. MATH is synthetically created, containing two million training examples per module (type) of questions. DROP is a reading-comprehension–style QA dataset where the input is a paragraph context and a question.

To solve questions in DROP, the model needs to read the paragraph, extract relevant numbers and perform numerical computation. We found both datasets to be complementary. MATH contains a large number of questions across different categories, which helps us identify math operations needed to explicitly inject into the model. DROP’s reading-comprehension format resembles the typical QA format wherein models simultaneously perform information extraction and reasoning. In practice, we render inputs of both datasets into images. The model is trained to decode the answer.

To improve the math reasoning skills of MatCha we incorporate examples from MATH and DROP into the pre-training objective, by rendering the input text as images.

End-to-end results

We use a Pix2Struct model backbone, which is an image-to-text transformer tailored for website understanding, and pre-train it with the two tasks described above. We demonstrate the strengths of MatCha by fine-tuning it on several visual language tasks — tasks involving charts and plots for question answering and summarization where no access to the underlying table is possible. MatCha surpasses previous models’ performance by a large margin and also outperforms the previous state of the art, which assumes access to underlying tables.

In the figure below, we first evaluate two baseline models that incorporate information from an OCR pipeline, which until recently was the standard approach for working with charts. The first is based on T5, the second on VisionTaPas. We also compare against PaLI-17B, which is a large (~1000 times larger than the other models) image plus text-to-text transformer trained on a diverse set of tasks but with limited capabilities for reading text and other forms of visual language. Finally, we report the Pix2Struct and MatCha model results.

Experimental results on two chart QA benchmarks ChartQA & PlotQA (using relaxed accuracy) and a chart summarization benchmark chart-to-text (using BLEU4). MatCha surpasses the state of the art by a large margin on QA, compared to larger models, and matches these larger models on summarization.

For QA datasets, we use the official relaxed accuracy metric that allows for small relative errors in numerical outputs. For chart-to-text summarization, we report BLEU scores. MatCha achieves noticeably improved results compared to baselines for question answering, and comparable results to PaLI in summarization, where large size and extensive long text/captioning generation pre-training are advantageous for this kind of long-form text generation.
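A minimal sketch of a relaxed-accuracy scorer: numeric answers count as correct within a relative tolerance, and everything else needs an exact match. The 5% tolerance and the case-insensitive string comparison are assumptions here, not details stated above.

```python
def relaxed_match(prediction, target, tolerance=0.05):
    """Numeric answers may deviate by up to `tolerance` relative error;
    non-numeric answers must match exactly (ignoring case and surrounding spaces)."""
    try:
        pred_value, target_value = float(prediction), float(target)
    except ValueError:
        return prediction.strip().lower() == target.strip().lower()
    if target_value == 0:
        return pred_value == 0
    return abs(pred_value - target_value) / abs(target_value) <= tolerance


def relaxed_accuracy(predictions, targets, tolerance=0.05):
    """Fraction of predictions that pass relaxed_match."""
    return sum(relaxed_match(p, t, tolerance) for p, t in zip(predictions, targets)) / len(targets)


print(relaxed_accuracy(["104", "106", "Yes"], ["100", "100", "yes"]))
```

Here "104" passes (4% off), "106" fails (6% off), and "Yes" passes as a string match.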


Derendering plus large language model chains

Although fine-tuned MatCha models are extremely performant for their number of parameters, particularly on extractive tasks, we observed that they could still struggle with end-to-end complex reasoning (e.g., mathematical operations involving large numbers or multiple steps). Thus, we also propose a two-step method to tackle this: 1) a model reads a chart, then outputs the underlying table, 2) a large language model (LLM) reads this output and then tries to answer the question solely based on the textual input.
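The two-step method can be sketched as a small composition of two models. In this sketch, `chart_to_table` and `llm` are hypothetical stand-ins for a DePlot-style derendering model and a text-only LLM, and both the prompt wording and the table linearization (rows separated by newlines, cells by " | ") are assumptions rather than the exact format used in the papers.

```python
from typing import Callable


def answer_chart_question(
    image: bytes,
    question: str,
    chart_to_table: Callable[[bytes], str],
    llm: Callable[[str], str],
) -> str:
    """Two-stage chart QA: derender the chart, then reason over text only.

    `chart_to_table` stands in for a DePlot-style model and `llm` for any
    text-only language model; both are hypothetical callables here.
    """
    # Stage 1: the image model outputs a linearized data table.
    table = chart_to_table(image)
    # Stage 2: the LLM sees only the table and the question, never the pixels.
    prompt = (
        "Read the table below and answer the question.\n\n"
        f"{table}\n\n"
        f"Question: {question}\nAnswer:"
    )
    return llm(prompt).strip()
```

Because the second stage only ever sees text, the chart model and the LLM can be swapped independently, and prompting techniques developed for text-only LLMs can be layered on top without retraining the chart model.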

For the first model, we fine-tuned MatCha solely on the chart-to-table task, increasing the output sequence length so that it could recover all or most of the information in the chart. DePlot is the resulting model. In the second stage, any LLM (such as FlanPaLM or Codex) can be used for the task, and we can rely on standard methods for improving LLM performance, for example chain-of-thought and self-consistency. We also experimented with program-of-thoughts, where the model produces executable Python code to offload complex computations.
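Self-consistency and program-of-thoughts combine naturally: sample several programs from the LLM, execute each one, and take a majority vote over their results. The sketch below assumes (hypothetically) that every generated program stores its final answer in a variable named `ans`. It uses Python's built-in `exec`, which provides no sandboxing, so model-generated code should only be run in an isolated environment.

```python
from collections import Counter
from typing import Iterable, Optional


def run_program(code: str) -> Optional[str]:
    """Execute one generated program and return its `ans` variable as text.

    Programs that raise an exception or never define `ans` yield None.
    NOTE: exec() is not a sandbox; isolate untrusted model output.
    """
    namespace: dict = {}
    try:
        exec(code, namespace)
    except Exception:
        return None
    if "ans" not in namespace:
        return None
    return str(namespace["ans"])


def self_consistent_answer(programs: Iterable[str]) -> Optional[str]:
    """Majority vote over the answers of several sampled programs."""
    answers = [a for a in (run_program(p) for p in programs) if a is not None]
    if not answers:
        return None
    return Counter(answers).most_common(1)[0][0]
```

For example, given the programs `ans = 5 + 7`, `ans = sum([5, 7])`, `ans = 5 * 7` and `ans = undefined_name`, two programs agree on 12, one yields 35 and one raises an error, so `self_consistent_answer` returns `"12"`.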

An illustration of the DePlot+LLM method. This is a real example using FlanPaLM and Codex. The blue boxes are input to the LLM and the red boxes contain the answer generated by the LLMs. We highlight some of the key reasoning steps in each answer.

As shown in the example above, the DePlot model in combination with LLMs outperforms fine-tuned models by a significant margin, especially on the human-sourced portion of ChartQA, where the questions are more natural but demand more difficult reasoning. Furthermore, DePlot+LLM can do so without access to any training data.

We have released the new models and code at our GitHub repo, where you can try them out yourself in Colab. Check out the papers for MatCha and DePlot for more details on the experimental results. We hope that our results can benefit the research community and make the information in charts and plots more accessible to everyone.


Acknowledgements

This work was carried out by Fangyu Liu, Julian Martin Eisenschlos, Francesco Piccinno, Syrine Krichene, Chenxi Pang, Kenton Lee, Mandar Joshi, Wenhu Chen and Yasemin Altun from our Language Team as part of Fangyu's internship project. Nigel Collier from Cambridge was also a collaborator. We would like to thank Joshua Howland, Alex Polozov, Shrestha Basu Mallick, Massimo Nicosia and William Cohen for their valuable comments and suggestions.

Source: Google AI Blog


Resolving code review comments with ML

Code-change reviews are a critical part of the software development process at scale, taking a significant amount of the code authors’ and the code reviewers’ time. As part of this process, the reviewer inspects the proposed code and asks the author for code changes through comments written in natural language. At Google, we see millions of reviewer comments per year, and authors require an average of ~60 minutes of active shepherding time between sending changes for review and finally submitting the change. In our measurements, the required active work time that the code author must do to address reviewer comments grows almost linearly with the number of comments. However, with machine learning (ML), we have an opportunity to automate and streamline the code review process, e.g., by proposing code changes based on a comment’s text.

Today, we describe applying recent advances of large sequence models in a real-world setting to automatically resolve code review comments in the day-to-day development workflow at Google (publication forthcoming). As of today, code-change authors at Google address a substantial number of reviewer comments by applying an ML-suggested edit. We expect that to reduce time spent on code reviews by hundreds of thousands of hours annually at Google scale. Unsolicited, very positive feedback indicates that ML-suggested code edits increase Googlers' productivity and allow them to focus on more creative and complex tasks.


Predicting the code edit

We started by training a model that predicts code edits needed to address reviewer comments. The model is pre-trained on various coding tasks and related developer activities (e.g., renaming a variable, repairing a broken build, editing a file). It’s then fine-tuned for this specific task with reviewed code changes, the reviewer comments, and the edits the author performed to address those comments.

An example of an ML-suggested edit of refactorings that are spread within the code.

Google uses a monorepo, a single repository for all of its software artifacts, which allows our training dataset to include all unrestricted code used to build Google's most recent software, as well as previous versions.

To improve the model quality, we iterated on the training dataset. For example, we compared model performance on datasets with a single reviewer comment per file against datasets with multiple comments per file, and we experimented with classifiers, based on a small, curated dataset, to clean up the training data. We then chose the model with the best offline precision and recall metrics.


Serving infrastructure and user experience

We designed and implemented the feature on top of the trained model, focusing on the overall user experience and developer efficiency. As part of this, we explored different user experience (UX) alternatives through a series of user studies. We then refined the feature based on insights from an internal beta (i.e., a test of the feature in development) including user feedback (e.g., a “Was this helpful?” button next to the suggested edit).

The final model was calibrated for a target precision of 50%. That is, we tuned the model and the suggestions filtering so that 50% of suggested edits on our evaluation dataset are correct. In general, increasing the target precision reduces the number of shown suggested edits, and decreasing the target precision leads to more incorrect suggested edits. Incorrect suggested edits take up developers’ time and reduce their trust in the feature. We found that a target precision of 50% provides a good balance.
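The post does not detail how the calibration was done, so the following is only a generic sketch of one way to meet a precision target: sort suggestions on a held-out evaluation set by model confidence, then pick the lowest threshold whose precision still meets the target, which shows as many suggestions as possible. The function name and inputs are hypothetical, and the additional serving-time filtering heuristics are not modeled.

```python
from typing import Optional, Sequence


def threshold_for_target_precision(
    scores: Sequence[float],
    correct: Sequence[bool],
    target_precision: float = 0.5,
) -> Optional[float]:
    """Lowest confidence threshold whose precision meets the target.

    Suggestions with score >= threshold are shown. Returning the lowest
    qualifying threshold maximizes how many suggestions are shown, subject
    to the precision constraint on this evaluation set. Returns None if no
    threshold reaches the target.
    """
    ranked = sorted(zip(scores, correct), key=lambda pair: pair[0], reverse=True)
    best = None
    hits = 0
    for shown, (score, is_correct) in enumerate(ranked, start=1):
        hits += is_correct
        # Only cut after the last item of a run of equal scores, so tied
        # suggestions are either all shown or all hidden.
        if shown < len(ranked) and ranked[shown][0] == score:
            continue
        if hits / shown >= target_precision:
            best = score  # scores descend, so the last hit is the lowest
    return best
```

For scores `[0.9, 0.8, 0.7, 0.6, 0.5, 0.4]` with correctness `[True, False, True, False, False, False]`, precision at the successive cut points is 1/1, 1/2, 2/3, 2/4, 2/5 and 2/6, so a 50% target yields a threshold of 0.6. Raising the target raises the threshold and shows fewer suggestions, which is the trade-off described above.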

At a high level, for every new reviewer comment, we generate the model input in the same format that is used for training, query the model, and generate the suggested code edit. If the model is confident in the prediction and a few additional heuristics are satisfied, we send the suggested edit to downstream systems. The downstream systems, i.e., the code review frontend and the integrated development environment (IDE), expose the suggested edits to the user and log user interactions, such as preview and apply events. A dedicated pipeline collects these logs and generates aggregate insights, e.g., the overall acceptance rates as reported in this blog post.
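The aggregation step at the end of this pipeline can be illustrated with a small funnel computation over logged interactions. The event names ("shown", "previewed", "applied") and the log layout as (suggestion ID, event) pairs are assumptions for illustration, not the actual logging schema.

```python
from typing import Dict, Iterable, Set, Tuple


def funnel_rates(events: Iterable[Tuple[str, str]], num_comments: int) -> Dict[str, float]:
    """Aggregate logged (suggestion_id, event) pairs into funnel rates.

    Event names are hypothetical. Each suggestion counts at most once per
    stage, however many times the event was logged for it.
    """
    stages: Dict[str, Set[str]] = {"shown": set(), "previewed": set(), "applied": set()}
    for suggestion_id, event in events:
        if event in stages:
            stages[event].add(suggestion_id)
    shown = len(stages["shown"])
    previewed = len(stages["previewed"])
    applied = len(stages["applied"])

    def ratio(numerator: int, denominator: int) -> float:
        return numerator / denominator if denominator else 0.0

    return {
        "suggestion_rate": ratio(shown, num_comments),  # comments with a suggestion
        "preview_rate": ratio(previewed, shown),        # shown -> previewed
        "apply_rate": ratio(applied, previewed),        # previewed -> applied
        "addressed_rate": ratio(applied, num_comments), # comments resolved by ML
    }
```

Deduplicating per stage keeps repeated previews of the same suggestion from inflating the preview rate; the ratios mirror the quantities reported in the results, such as the share of comments with a suggestion and the share of previewed edits that are applied.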

Architecture of the ML-suggested edits infrastructure. We process code and infrastructure from multiple services, get the model predictions and surface the predictions in the code review tool and IDE.

The developer interacts with the ML-suggested edits in the code review tool and the IDE. Based on insights from the user studies, integration into the code review tool is best suited to a streamlined review experience. The IDE integration provides additional functionality and supports 3-way merging: if the author has made conflicting local changes on top of the reviewed code state (right in the figure below), the ML-suggested edit (left) is merged with them into the merge result (center).

3-way-merge UX in IDE.

Results

Offline evaluations indicate that the model addresses 52% of comments with a target precision of 50%. The online metrics of the beta and the full internal launch confirm these offline metrics, i.e., we see model suggestions above our target model confidence for around 50% of all relevant reviewer comments. 40% to 50% of all previewed suggested edits are applied by code authors.

We used the “not helpful” feedback during the beta to identify recurring failure patterns of the model. We implemented serving-time heuristics to filter these and, thus, reduce the number of shown incorrect predictions. With these changes, we traded quantity for quality and observed an increased real-world acceptance rate.

Code review tool UX. The suggestion is shown as part of the comment and can be previewed, applied and rated as helpful or not helpful.

Our beta launch showed a discoverability challenge: code authors only previewed ~20% of all generated suggested edits. We modified the UX and introduced a prominent “Show ML-edit” button (see the figure above) next to the reviewer comment, leading to an overall preview rate of ~40% at launch. We additionally found that suggested edits in the code review tool are often not applicable due to conflicting changes that the author did during the review process. We addressed this with a button in the code review tool that opens the IDE in a merge view for the suggested edit. We now observe that more than 70% of these are applied in the code review tool and fewer than 30% are applied in the IDE. All these changes allowed us to increase the overall fraction of reviewer comments that are addressed with an ML-suggested edit by a factor of 2 from beta to the full internal launch. At Google scale, these results help automate the resolution of hundreds of thousands of comments each year.

Suggestions filtering funnel.

We see ML-suggested edits addressing a wide range of reviewer comments in production. This includes simple localized refactorings and refactorings that are spread within the code, as shown in the examples throughout the blog post above. The feature also addresses longer and less formally worded comments that require code generation, refactorings and imports.

Example of a suggestion for a longer and less formally worded comment that requires code generation, refactorings and imports.

The model can also respond to complex comments and produce extensive code edits (shown below). The generated test case follows the existing unit test pattern while changing the details as described in the comment. Additionally, the edit suggests a descriptive name for the test that reflects the test semantics.

Example of the model's ability to respond to complex comments and produce extensive code edits.

Conclusion and future work

In this post, we introduced an ML-assistance feature to reduce the time spent on code review related changes. At the moment, a substantial fraction of all actionable code review comments on supported languages are addressed with applied ML-suggested edits at Google. A 12-week A/B experiment across all Google developers will further measure the impact of the feature on overall developer productivity.

We are working on improvements throughout the whole stack. This includes increasing the quality and recall of the model and building a more streamlined experience for the developer with improved discoverability throughout the review process. As part of this, we are investigating the option of showing suggested edits to the reviewer while they draft comments and expanding the feature into the IDE to enable code-change authors to get suggested code edits for natural-language commands.


Acknowledgements

This is the work of many people in Google Core Systems & Experiences team, Google Research, and DeepMind. We'd like to specifically thank Peter Choy for bringing the collaboration together, and all of our team members for their key contributions and useful advice, including Marcus Revaj, Gabriela Surita, Maxim Tabachnyk, Jacob Austin, Nimesh Ghelani, Dan Zheng, Peter Josling, Mariana Stariolo, Chris Gorgolewski, Sascha Varkevisser, Katja Grünwedel, Alberto Elizondo, Tobias Welp, Paige Bailey, Pierre-Antoine Manzagol, Pascal Lamblin, Chenjie Gu, Petros Maniatis, Henryk Michalewski, Sara Wiltberger, Ambar Murillo, Satish Chandra, Madhura Dudhgaonkar, Niranjan Tulpule, Zoubin Ghahramani, Juanjo Carin, Danny Tarlow, Kevin Villela, Stoyan Nikolov, David Tattersall, Boris Bokowski, Kathy Nix, Mehdi Ghissassi, Luis C. Cobo, Yujia Li, David Choi, Kristóf Molnár, Vahid Meimand, Amit Patel, Brett Wiltshire, Laurent Le Brun, Mingpan Guo, Hermann Loose, Jonas Mattes, Savinee Dancs.

Source: Google AI Blog


Using reinforcement learning for dynamic planning in open-ended conversations

As virtual assistants become ubiquitous, users increasingly interact with them to learn about new topics or obtain recommendations, and they expect capabilities beyond narrow dialogues of one or two turns. Dynamic planning, namely the capability to look ahead and replan based on the flow of the conversation, is an essential ingredient for engaging conversations with the deeper, open-ended interactions that users expect.

While large language models (LLMs) are now beating state-of-the-art approaches in many natural language processing benchmarks, they are typically trained to output the next best response, rather than planning ahead, which is required for multi-turn interactions. However, in the past few years, reinforcement learning (RL) has delivered incredible results addressing specific problems that involve dynamic planning, such as winning games and protein folding.

Today, we are sharing our recent advances in dynamic planning for human-to-assistant conversations, in which we enable an assistant to plan a multi-turn conversation towards a goal and adapt that plan in real-time by adopting an RL-based approach. Here we look at how to improve long interactions by applying RL to compose answers based on information extracted from reputable sources, rather than relying on content generated by a language model. We expect that future versions of this work could combine LLMs and RL in multi-turn dialogues. The deployment of RL “in the wild” in a large-scale dialogue system proved a formidable challenge due to the modeling complexity, tremendously large state and action spaces, and significant subtlety in designing reward functions.


What is dynamic planning?

Many types of conversations, from gathering information to offering recommendations, require a flexible approach and the ability to modify the original plan for the conversation based on its flow. This ability to shift gears in the middle of a conversation is known as dynamic planning, as opposed to static planning, which refers to a more fixed approach. In the conversation below, for example, the goal is to engage the user by sharing interesting facts about cool animals. To begin, the assistant steers the conversation to sharks via a sound quiz. Given the user's lack of interest in sharks, the assistant then develops an updated plan and pivots the conversation to sea lions, lions, and then cheetahs.

The assistant dynamically modifies its original plan to talk about sharks and shares facts about other animals.

Dynamic composition

To cope with the challenge of conversational exploration, we separate the generation of assistant responses into two parts: 1) content generation, which extracts relevant information from reputable sources, and 2) flexible composition of such content into assistant responses. We refer to this two-part approach as dynamic composition. Unlike LLM methods, this approach gives the assistant the ability to fully control the source, correctness, and quality of the content that it may offer. At the same time, it can achieve flexibility via a learned dialogue manager that selects and combines the most appropriate content.

In an earlier paper, “Dynamic Composition for Conversational Domain Exploration”, we describe a novel approach which consists of: (1) a collection of content providers, which offer candidates from different sources, such as news snippets, knowledge graph facts, and questions; (2) a dialogue manager; and (3) a sentence fusion module. Each assistant response is incrementally constructed by the dialogue manager, which selects candidates proposed by the content providers. The selected sequence of utterances is then fused into a cohesive response.
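The incremental construction described above can be sketched as a greedy loop. This is a simplification: the providers, the scoring function, the fusion step and the fixed `max_utterances` budget are hypothetical placeholders, and in the approach described in this post the scoring role is played by the learned dialogue manager rather than a hand-written function.

```python
from typing import Callable, List, Sequence

# Hypothetical interfaces: a provider maps the dialogue so far to candidate
# utterances; a scorer rates a candidate given the history and the partial
# response built so far (the role of the learned dialogue manager).
Provider = Callable[[List[str]], List[str]]
Scorer = Callable[[List[str], List[str], str], float]


def compose_response(
    history: List[str],
    providers: Sequence[Provider],
    score: Scorer,
    fuse: Callable[[List[str]], str],
    max_utterances: int = 3,
) -> str:
    """Greedily build one assistant turn from provider candidates, then fuse it."""
    selected: List[str] = []
    for _ in range(max_utterances):
        context = history + selected
        # Gather fresh candidates from every provider, skipping repeats.
        candidates = [c for provider in providers for c in provider(context)
                      if c not in selected]
        if not candidates:
            break
        # The dialogue manager picks the highest-scoring candidate.
        best = max(candidates, key=lambda c: score(history, selected, c))
        selected.append(best)
    # Sentence fusion turns the selected utterances into one response.
    return fuse(selected)
```

With a fact provider offering "Lions roar." and "Lions live in prides.", a quiz provider offering "Want to hear a lion sound?", a toy scorer that prefers statements over questions, and `" ".join` as the fusion step, the loop returns "Lions roar. Lions live in prides. Want to hear a lion sound?".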


Dynamic planning using RL

At the core of the assistant response composition loop is a dialogue manager trained using off-policy RL, namely an algorithm that evaluates and improves a policy that is different from the policy used to generate the training data (in our case, the latter is a supervised model). Applying RL to dialogue management presents several challenges, including a large state space (as the state represents the conversation state, which needs to account for the whole conversation history) and an effectively unbounded action space (that may include all existing words or sentences in natural language).

We address these challenges using a novel RL construction. First, we leverage powerful supervised models — specifically, recurrent neural networks (RNNs) and transformers — to provide a succinct and effective dialogue state representation. These state encoders are fed with the dialogue history, composed of a sequence of user and assistant turns, and output a representation of the dialogue state in the form of a latent vector.

Second, we use the fact that a relatively small set of reasonable candidate utterances or actions can be generated by content providers at each conversation turn, and limit the action space to these. In typical RL settings the action space is fixed, because all states share the same action space; ours is non-standard, since the candidate actions may differ with each state as content providers generate different actions depending on the dialogue context. This puts us in the realm of stochastic action sets, a framework that formalizes cases where the set of actions available in each state is governed by an exogenous stochastic process, which we address using Stochastic Action Q-Learning, a variant of the Q-learning approach. Q-learning is a popular off-policy RL algorithm that does not require a model of the environment to evaluate and improve the policy. We trained our model on a corpus of crowd-compute–rated conversations obtained using a supervised dialogue manager.
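To make the stochastic-action-set idea concrete, here is a tabular sketch of the Q-learning update run over logged transitions. The only change from standard Q-learning is in the bootstrap target, reward + gamma * max(Q(s′, a′) for a′ in A(s′)): the maximum is taken only over the actions that were actually available in the next state, as recorded in the log. The deployed system instead uses learned state encoders and function approximation; the state and action labels, hyperparameters and toy data here are hypothetical.

```python
from collections import defaultdict
from typing import Dict, Hashable, Iterable, Sequence, Tuple

# One logged step: (state, action, reward, next_state, actions available in
# next_state). An empty action set marks the end of the conversation.
Transition = Tuple[Hashable, Hashable, float, Hashable, Sequence[Hashable]]


def stochastic_action_q_learning(
    transitions: Iterable[Transition],
    alpha: float = 0.1,
    gamma: float = 0.9,
    epochs: int = 50,
) -> Dict[Tuple[Hashable, Hashable], float]:
    """Batch tabular Q-learning over logged data with per-state action sets."""
    q: Dict[Tuple[Hashable, Hashable], float] = defaultdict(float)
    data = list(transitions)
    for _ in range(epochs):
        for state, action, reward, next_state, next_actions in data:
            # Standard Q-learning maximizes over a fixed action space; here the
            # max runs only over the candidates that were available next.
            future = max((q[(next_state, a)] for a in next_actions), default=0.0)
            target = reward + gamma * future
            q[(state, action)] += alpha * (target - q[(state, action)])
    return dict(q)
```

For instance, on a three-transition log in which a "fact" action in a "greet" state leads to an "engaged" state offering only "quiz" (reward 1.0) and "sound" (reward 0.2), running enough epochs drives Q(engaged, quiz) toward 1.0 and Q(greet, fact) toward 0.9, that is, gamma times the best value among the actions actually available afterwards.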

Given the current dialogue history and a new user query, content providers generate candidates from which the assistant selects one. This process runs in a loop, and at the end the selected utterances are fused into a cohesive response.

Reinforcement learning model evaluation

We compared our RL dialogue manager with a launched supervised transformer model in an experiment using Google Assistant, which conversed with users about animals. A conversation starts when a user triggers the experience by asking an animal-related query (e.g., “How does a lion sound?”). The experiment was conducted using an A/B testing protocol, in which a small percentage of Assistant users were randomly sampled to interact with our RL-based assistant while other users interacted with the standard assistant.

We found that the RL dialogue manager conducts longer, more engaging conversations. It increases conversation length by 30% while improving user engagement metrics. We see an increase of 8% in cooperative responses to the assistant’s questions — e.g., “Tell me about lions,” in response to “Which animal do you want to hear about next?” Although there is also a large increase in nominally “non-cooperative” responses (e.g., “No,” as a reply to a question proposing additional content, such as “Do you want to hear more?”), this is expected as the RL agent takes more risks by asking pivoting questions. While a user may not be interested in the conversational direction proposed by the assistant (e.g., pivoting to another animal), the user will often continue to engage in a dialogue about animals.

From the non-cooperative user response in the 3rd turn (“No.”) and the query “Make a dog sound,” in the 5th turn, the assistant recognizes that the user is mostly interested in animal sounds and modifies its plan, providing sounds and sound quizzes.

In addition, some user queries contain explicit positive (e.g., “Thank you, Google,” or “I’m happy.”) or negative (e.g., “Shut up,” or “Stop.”) feedback. Although such queries are an order of magnitude rarer than other queries, they offer a direct measure of user (dis)satisfaction. The RL model increases explicit positive feedback by 32% and reduces negative feedback by 18%.


Learned dynamic planning characteristics and strategies

We observe several characteristics of the (latent) RL plan that improve user engagement while conducting longer conversations. First, the RL-based assistant ends 20% more turns in questions, prompting the user to choose additional content. It also better harnesses content diversity, including facts, sounds, quizzes, yes/no questions, open questions, etc. On average, the RL assistant uses 26% more distinct content providers per conversation than the supervised model.

Two observed RL planning strategies are related to the existence of sub-dialogues with different characteristics. Sub-dialogues about animal sounds are poorer in content and exhibit entity pivoting at every turn (i.e., after playing the sound of a given animal, we can either suggest the sound of a different animal or quiz the user about other animal sounds). In contrast, sub-dialogues involving animal facts typically contain richer content and have greater conversation depth. We observe that RL favors the richer experience of the latter, selecting 31% more fact-related content. Lastly, when restricting analysis to fact-related dialogues, the RL assistant exhibits 60% more focus-pivoting turns, that is, conversational turns that change the focus of the dialogue.

Below, we show two example conversations, one conducted by the supervised model (left) and the second by the RL model (right), in which the first three user turns are identical. With a supervised dialogue manager, after the user declined to hear about “today’s animal”, the assistant pivots back to animal sounds to maximize the immediate user satisfaction. While the conversation conducted by the RL model begins identically, it exhibits a different planning strategy to optimize the overall user engagement, introducing more diverse content, such as fun facts.

In the left conversation, conducted by the supervised model, the assistant maximizes the immediate user satisfaction. The right conversation, conducted by the RL model, shows different planning strategies to optimize the overall user engagement.

Future research and challenges

In the past few years, LLMs trained for language understanding and generation have demonstrated impressive results across multiple tasks, including dialogue. We are now exploring the use of an RL framework to give LLMs the capability of dynamic planning, so that they can plan ahead and delight users with a more engaging experience.


Acknowledgements

The work described is co-authored by: Moonkyung Ryu, Yinlam Chow, Orgad Keller, Ido Greenberg, Avinatan Hassidim, Michael Fink, Yossi Matias, Idan Szpektor and Gal Elidan. We would like to thank: Roee Aharoni, Moran Ambar, John Anderson, Ido Cohn, Mohammad Ghavamzadeh, Lotem Golany, Ziv Hodak, Adva Levin, Fernando Pereira, Shimi Salant, Shachar Shimoni, Ronit Slyper, Ariel Stolovich, Hagai Taitelbaum, Noam Velan, Avital Zipori and the CrowdCompute team led by Ashwin Kakarla. We thank Sophie Allweis for her feedback on this blogpost and Tom Small for the visualization.

Source: Google AI Blog