
US-12619886-B2 - Frozen model adaptation through soft prompt transfer


Abstract

Systems and methods for prompt tuning can use previously-learned prompts to initialize the tuning of prompts for other tasks, which may differ from the task associated with the previously-learned prompt. The prompt used for initialization can be a generic prompt and/or a prompt selected based on a determined similarity between two or more task embeddings.
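For the similarity-based case, the selection and initialization can be written compactly as below. This is an illustrative formalization, not notation taken from the disclosure:

- $e_t$ is the target task embedding.
- $(e_i, P_i)$ are the stored embedding/prompt pairs.
- $f_\theta$ is the pre-trained model with frozen parameters $\theta$.
- $[P; X_t]$ denotes the prompt prepended to the target inputs.
- $\mathcal{L}$ is a task loss.
- Cosine similarity is an assumed choice of similarity measure.

```latex
\begin{aligned}
i^\star &= \arg\max_{i} \; \frac{e_t^\top e_i}{\lVert e_t \rVert \, \lVert e_i \rVert}
  && \text{(nearest stored task embedding)} \\
P_t^{(0)} &= P_{i^\star}
  && \text{(initialize from the retrieved source prompt)} \\
P_t &= \arg\min_{P} \; \mathcal{L}\big(f_\theta([P; X_t]),\, Y_t\big)
  && \text{(optimize from } P_t^{(0)} \text{ with } \theta \text{ held fixed)}
\end{aligned}
```

The generic-prompt case corresponds to skipping the $\arg\max$ and setting $P_t^{(0)}$ to a single shared prompt.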

Inventors

  • Tu Thanh Vu
  • Daniel Matthew Cer
  • Noah Constant
  • Brian David Lester
  • Rami Al-Rfou

Assignees

  • GOOGLE LLC

Dates

Publication Date
2026-05-05
Application Date
2022-07-13

Claims (20)

  1. A computing system for soft prompt transfer-learning, the computing system comprising: one or more processors; and one or more non-transitory computer-readable media that collectively store instructions that, when executed by the one or more processors, cause the computing system to perform operations, the operations comprising: obtaining target task data, wherein the target task data is associated with a target task for a machine-learned model, wherein the target task data comprises one or more target training examples and one or more target training labels; processing the target task data to generate a target task embedding; obtaining, based on an embedding-based search and for initializing target prompt generation, a source prompt from a prompt database based on the target task embedding based on a learned distribution associated with an embedding space, wherein the source prompt is a general source prompt associated with a group of previously-learned tasks, and wherein the source prompt is associated with a source embedding, wherein the source prompt comprises a set of parameters trained for conditioning a pre-trained machine-learned model for the group of previously-learned tasks, and wherein the prompt database comprises a plurality of different pre-trained soft prompts associated with a plurality of different respective task embeddings, and wherein the source prompt is obtained from the prompt database based on determining the source embedding is associated with the target task embedding; processing the source prompt and the target task data with the pre-trained machine-learned model to generate one or more outputs; generating a target prompt for the target task based on the source prompt and the one or more outputs, wherein generating the target prompt comprises adjusting one or more parameters of the set of parameters of the source prompt based on an evaluation of the one or more outputs, wherein the one or more parameters of the set of parameters are adjusted without adjusting parameters of the pre-trained machine-learned model, and wherein the one or more parameters are adjusted to train the set of parameters for conditioning the pre-trained machine-learned model to perform the target task; and storing the target task embedding and the target prompt in the prompt database.
  2. The system of claim 1, wherein generating the target prompt comprises: evaluating a loss function based on the one or more outputs; and adjusting the one or more parameters of the source prompt based on the loss function.
  3. The system of claim 1, wherein the operations further comprise: obtaining input data; and processing the input data and the target prompt with the pre-trained machine-learned model to generate a target task output, wherein the target task output is associated with the target task.
  4. The system of claim 1, wherein processing the target task data to generate the target task embedding comprises learning one or more embedding parameters based at least in part on the target task data.
  5. The system of claim 1, wherein obtaining the source prompt from the prompt database based on the target task embedding comprises: determining the target task embedding is associated with the source embedding; and obtaining the source prompt associated with the source embedding.
  6. The system of claim 1, wherein the source prompt was pre-trained on a plurality of different training datasets associated with a plurality of different tasks.
  7. The system of claim 1, wherein the operations further comprise: obtaining a first source task dataset, wherein the first source task dataset is associated with a first task; generating a first source embedding based on the first source task dataset by partially training a plurality of first source parameters; generating a first source prompt based on the first source task dataset by further training the plurality of first source parameters; and storing the first source prompt and the first source embedding in the prompt database.
  8. The system of claim 7, wherein the operations further comprise: obtaining a second source task dataset, wherein the second source task dataset is associated with a second task; generating a second source embedding based on the second source task dataset by partially training a plurality of second source parameters; generating a second source prompt based on the second source task dataset by further training the plurality of second source parameters; and storing the second source prompt and the second source embedding in the prompt database.
  9. The system of claim 8, wherein obtaining the source prompt from the prompt database based on the target task embedding comprises: determining the first source embedding is more similar to the target task embedding than the second source embedding; determining the first source prompt is the source prompt based on the first source embedding being more similar to the target task embedding than the second source embedding; and obtaining the first source prompt from the prompt database.
  10. The system of claim 1, wherein obtaining the source prompt from the prompt database based on the target task embedding comprises: determining a particular source task embedding associated with the target task embedding based on a learned distribution associated with an embedding space.
  11. A computer-implemented method for prompt tuning, the method comprising: obtaining, by a computing system comprising one or more processors, a first task dataset, wherein the first task dataset is associated with a group of first tasks; processing, by the computing system, the first task dataset to generate a first source task embedding; training, by the computing system, a first source prompt based on the first task dataset, wherein training the first source prompt comprises: processing, by the computing system, the first task dataset and a set of parameters with a pre-trained machine-learned model to generate a first task output; and adjusting, by the computing system, one or more parameters of the set of parameters based on the first task output, wherein the one or more parameters are adjusted to train the set of parameters for conditioning the pre-trained machine-learned model to perform a respective task of the group of first tasks, wherein the first source prompt is a general source prompt trained for conditioning the pre-trained machine-learned model for the group of first tasks; storing, by the computing system, the first source task embedding and the first source prompt in a prompt database; obtaining, by the computing system, a target task dataset; processing, by the computing system, the target task dataset to generate a target task embedding; obtaining, based on an embedding-based search and for initializing target prompt generation, a source prompt from a prompt database based on the target task embedding based on a learned distribution associated with an embedding space, wherein obtaining the source prompt comprises: determining, by the computing system, the target task embedding is associated with the first source task embedding; processing, by the computing system, the target task dataset and the first source prompt with the pre-trained machine-learned model to generate a target task output; adjusting, by the computing system, one or more parameters of the first source prompt based on the target task output to generate a target task prompt, wherein the one or more parameters of the set of parameters are adjusted without adjusting parameters of the pre-trained machine-learned model, and wherein the one or more parameters are adjusted to train the set of parameters for conditioning the pre-trained machine-learned model to perform the target task; and storing, by the computing system, the target task embedding and the target task prompt in the prompt database.
  12. The method of claim 11, wherein the pre-trained machine-learned model comprises a large frozen model, wherein a plurality of pre-trained parameters for the pre-trained machine-learned model are fixed during prompt tuning.
  13. The method of claim 11, wherein the respective task of the group of first tasks is descriptive of a text completion task.
  14. The method of claim 11, wherein the target task dataset is associated with a target task, and wherein the target task is descriptive of a sentiment classification task.
  15. The method of claim 11, wherein determining the target task embedding is associated with the first source task embedding comprises: generating a similarity score based on a similarity between the target task embedding and the first source task embedding.
  16. One or more non-transitory computer-readable media that collectively store instructions that, when executed by one or more computing devices, cause the one or more computing devices to perform operations, the operations comprising: obtaining target task data, wherein the target task data is associated with a target task for a machine-learned model; obtaining, based on an embedding-based search and for initializing target prompt generation, a source prompt from a prompt database based on the target task data based on a learned distribution associated with an embedding space, wherein the source prompt is a general source prompt associated with a group of previously-learned tasks, wherein the source prompt comprises a set of learned parameters representative of a group of source tasks, and wherein the source prompt is associated with a source embedding, wherein the source prompt comprises the set of learned parameters trained for conditioning a pre-trained machine-learned model for the group of previously-learned tasks, and wherein the prompt database comprises a plurality of different pre-trained soft prompts associated with a plurality of different respective task embeddings, and wherein the source prompt is obtained from the prompt database based on determining the source embedding is associated with a target task embedding; processing the source prompt and the target task data with a pre-trained machine-learned model to generate one or more outputs, wherein the pre-trained machine-learned model comprises a frozen language model; evaluating a loss function based on the one or more outputs; adjusting one or more parameters of the source prompt based on the loss function to generate a target prompt, wherein the one or more parameters of the set of parameters are adjusted without adjusting parameters of the pre-trained machine-learned model, and wherein the one or more parameters are adjusted to train the set of parameters for conditioning the pre-trained machine-learned model to perform the target task; and storing the target task embedding and the target prompt in the prompt database.
  17. The one or more non-transitory computer-readable media of claim 16, wherein obtaining the source prompt from the prompt database based on the target task data comprises: processing the target task data with an embedding model to generate the target task embedding; determining a nearest embedding neighbor for the target task embedding based on a plurality of embeddings stored in the prompt database; and determining the source prompt is associated with the nearest embedding neighbor.
  18. The one or more non-transitory computer-readable media of claim 16, wherein the source embedding and the source prompt were generated by training a plurality of source parameters based on a source task dataset associated with the source tasks.
  19. The one or more non-transitory computer-readable media of claim 16, wherein the target task comprises an image classification task, wherein the source tasks differ from the target task, and wherein the target prompt is configured to be processed with the pre-trained machine-learned model to perform a target task, and wherein one or more of the group of source tasks is configured to be processed with the pre-trained machine-learned model to perform the source task.
  20. The one or more non-transitory computer-readable media of claim 16, wherein the plurality of different respective task embeddings of the prompt database are associated with a semantic space of tasks configured to cluster similar tasks.
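Claims 9, 15, and 17 describe an embedding-based retrieval step. Each stored task embedding is scored against the target task embedding, the nearest neighbor is selected, and its prompt initializes tuning. The sketch below shows this under stated assumptions:

- `PromptDatabase`, `store`, `nearest`, and `cosine_similarity` are hypothetical names.
- Embeddings are plain Python lists.
- Cosine similarity stands in for whatever similarity score an implementation actually uses. The claims do not fix one, and they also mention a learned distribution over the embedding space, which this sketch does not model.

```python
import math
from dataclasses import dataclass, field


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Similarity score between two task embeddings (assumed measure: cosine)."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # treat a zero embedding as unrelated to every task
    return dot / (norm_a * norm_b)


@dataclass
class PromptDatabase:
    """Hypothetical store of (task name, task embedding, soft prompt) entries."""

    entries: list[tuple[str, list[float], list[list[float]]]] = field(default_factory=list)

    def store(self, task: str, embedding: list[float], prompt: list[list[float]]) -> None:
        # Copy inputs so later edits by the caller cannot alter stored entries.
        self.entries.append((task, list(embedding), [row[:] for row in prompt]))

    def nearest(self, target_embedding: list[float]) -> tuple[str, list[list[float]], float]:
        """Return the task name, a copy of its prompt, and the nearest-neighbor score."""
        if not self.entries:
            raise ValueError("prompt database is empty")
        scored = [(cosine_similarity(target_embedding, emb), task, prompt)
                  for task, emb, prompt in self.entries]
        score, task, prompt = max(scored, key=lambda item: item[0])
        # Return a copy so tuning it for the target task leaves the source prompt intact.
        return task, [row[:] for row in prompt], score


if __name__ == "__main__":
    db = PromptDatabase()
    db.store("source_task_a", [1.0, 0.0, 0.2], [[0.1, 0.2], [0.3, 0.4]])
    db.store("source_task_b", [0.0, 1.0, 0.0], [[0.5, 0.5], [0.5, 0.5]])
    task, init_prompt, score = db.nearest([0.9, 0.1, 0.1])
    print(task, round(score, 3))  # source_task_a 0.99
```

`nearest` returns a copy of the stored prompt rather than the prompt itself. This keeps the source entry intact for later targets. After tuning, the target embedding and target prompt can be added back with `store`, matching the final storing step of claims 1, 11, and 16.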

Description

FIELD

The present disclosure relates generally to prompt tuning initialized by a pre-trained soft prompt. More particularly, the present disclosure relates to transfer learning of a set of parameters for a target task based on a pre-trained set of parameters for a previously-learned task.

BACKGROUND

Large pre-trained models can provide realistic outputs (e.g., realistic natural language outputs). However, training and retraining large machine-learned models can be computationally expensive because the models can include billions of parameters. Additionally, alternative techniques for conditioning the model inputs can be tedious to apply and can yield lower-quality results. Large pre-trained models may be useful for many tasks if trained or conditioned for each particular task. However, training the parameters of the model may not be feasible on general consumer computing devices. Therefore, training and retraining large pre-trained models may depend on large computational resources that may not be readily accessible.

SUMMARY

Aspects and advantages of embodiments of the present disclosure will be set forth in part in the following description, or can be learned from the description, or can be learned through practice of the embodiments.

One example aspect of the present disclosure is directed to a computing system for soft prompt transfer-learning. The computing system can include one or more processors and one or more non-transitory computer-readable media that collectively store instructions that, when executed by the one or more processors, cause the computing system to perform operations. The operations can include obtaining target task data. The target task data can be associated with a target task for a machine-learned model. In some implementations, the target task data can include one or more target training examples and one or more target training labels.
The operations can include processing the target task data to generate a target task embedding. The operations can include obtaining a source prompt from a prompt database based on the target task embedding. The source prompt can be associated with a previously-learned task. In some implementations, the source prompt can be associated with a source embedding. The operations can include processing the source prompt and the target task data with a pre-trained machine-learned model to generate one or more outputs. The operations can include generating a target prompt for the target task based on the source prompt and the one or more outputs. In some implementations, generating the target prompt can include evaluating a loss function based on the one or more outputs and adjusting one or more parameters of the source prompt based on the loss function. The operations can include obtaining input data and processing the input data and the target prompt with the pre-trained machine-learned model to generate a target task output. In some implementations, the target task output can be associated with the target task. Processing the target task data to generate the target task embedding can include learning one or more embedding parameters based at least in part on the target task data. In some implementations, obtaining the source prompt from the prompt database based on the target task embedding can include determining the target task embedding is associated with the source embedding and obtaining the source prompt associated with the source embedding. In some implementations, the source prompt may have been pre-trained on a plurality of different training datasets associated with a plurality of different tasks. The operations can include obtaining a first source task dataset. The first source task dataset can be associated with a first task. 
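The tuning step summarized above has three parts: run the source prompt and the target data through a frozen model, evaluate a loss, and adjust only the prompt parameters. A toy example illustrates this.

It is a minimal sketch, not the disclosed implementation:

- The "frozen model" is an assumed stand-in for a large pre-trained model. It mean-pools the prompt and input tokens, then applies a fixed linear readout.
- The loss is mean squared error.
- The gradient with respect to the prompt is derived by hand.
- `frozen_model`, `tune_prompt`, and `mse` are hypothetical names.

```python
def frozen_model(weights, bias, prompt, tokens):
    """Toy frozen model: mean-pool [prompt; tokens], then a fixed linear readout."""
    sequence = prompt + tokens  # prepend the soft prompt vectors to the input vectors
    dim = len(weights)
    pooled = [sum(vec[i] for vec in sequence) / len(sequence) for i in range(dim)]
    return sum(w * p for w, p in zip(weights, pooled)) + bias


def mse(weights, bias, prompt, dataset):
    """Mean squared error of the frozen model conditioned on `prompt`."""
    return sum((frozen_model(weights, bias, prompt, t) - y) ** 2
               for t, y in dataset) / len(dataset)


def tune_prompt(weights, bias, init_prompt, dataset, lr=0.5, steps=200):
    """Gradient descent on the prompt vectors only; weights and bias are never written."""
    prompt = [row[:] for row in init_prompt]  # start from a copy of the source prompt
    for _ in range(steps):
        grads = [[0.0] * len(weights) for _ in prompt]
        for tokens, label in dataset:
            pred = frozen_model(weights, bias, prompt, tokens)
            # dL/dpred = 2 (pred - y) / N; dpred/dP_j[i] = w[i] / (len(prompt) + len(tokens))
            scale = 2.0 * (pred - label) / (len(prompt) + len(tokens)) / len(dataset)
            for g in grads:
                for i, w in enumerate(weights):
                    g[i] += scale * w
        for row, g in zip(prompt, grads):
            for i in range(len(row)):
                row[i] -= lr * g[i]
    return prompt


if __name__ == "__main__":
    weights, bias = (1.0, -0.5), 0.0            # frozen "pre-trained" parameters
    dataset = [([[1.0, 0.0]], 1.5), ([[0.0, 1.0]], 0.5)]  # (input vectors, label)
    source_prompt = [[0.0, 0.0]]                # e.g. a prompt retrieved from a database
    target_prompt = tune_prompt(weights, bias, source_prompt, dataset)
    print(mse(weights, bias, source_prompt, dataset), mse(weights, bias, target_prompt, dataset))
```

`weights` and `bias` are only ever read, so every update lands on the prompt vectors. This mirrors the requirement that the pre-trained model's parameters stay fixed while the prompt is adjusted.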
The operations can include generating a first source embedding based on the first source task dataset by partially training a plurality of first source parameters, generating a first source prompt based on the first source task dataset by further training the plurality of first source parameters, and storing the first source prompt and the first source embedding in the prompt database. In some implementations, the operations can include obtaining a second source task dataset. The second source task dataset can be associated with a second task. The operations can include generating a second source embedding based on the second source task dataset by partially training a plurality of second source parameters, generating a second source prompt based on the second source task dataset by further training the plurality of second source parameters, and storing the second source prompt and the second source embedding in the prompt database. In some implementations, obtaining the source prompt from the prompt database based on the target task embedding can include determining the first source embedding is more similar to the target task embedding than the second source embedding, determi