US-12620212-B2 - Locked-model multimodal contrastive tuning
Abstract
A method may include obtaining a pretrained image encoder and a training sample comprising a training image and a training text string corresponding to the training image. The method may also include initializing a text encoder in an untrained state, determining, using the pretrained image encoder and based on the training image, a first latent representation of the training image, and determining, using the text encoder and based on the training text string, a second latent representation of the training text string. The method may further include determining a loss value based on the first latent representation and the second latent representation, updating, based on the loss value, one or more parameters of the text encoder while holding the parameters of the pretrained image encoder fixed, and outputting the text encoder in a trained state.
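The training procedure summarized above can be illustrated with a minimal Python sketch. It is not the patented implementation, and it rests on several assumptions:

- The frozen image encoder is a toy fixed linear map, hypothetically named `IMAGE_PROJ`.
- The text encoder is reduced to one trainable projection matrix applied to hypothetical precomputed text feature vectors.
- The loss is a one-directional (text-to-image) softmax contrastive loss whose gradient is derived by hand.

Image latents are computed once and reused at every step, and only the text weights are ever updated.

```python
import math
import random

# Toy stand-in for a pretrained image encoder (an assumption for illustration):
# a fixed linear map from 4-dim image features to a 3-dim latent space.
IMAGE_PROJ = [[1.0, 0.0, 0.0, 0.5],
              [0.0, 1.0, 0.0, -0.5],
              [0.0, 0.0, 1.0, 0.5]]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def image_encoder(image):
    """Locked encoder: IMAGE_PROJ is read but never modified."""
    return matvec(IMAGE_PROJ, image)


def text_loss_and_grad(w, text_feats, img_latents, tau=1.0):
    """Text-to-image softmax contrastive loss and its gradient w.r.t. w.

    Text i is scored against every image latent in the batch; image i is
    its matched (positive) partner and the rest act as unmatched negatives.
    """
    n = len(text_feats)
    grad = [[0.0] * len(w[0]) for _ in w]
    loss = 0.0
    for i, x in enumerate(text_feats):
        t = matvec(w, x)  # latent representation of text i
        logits = [dot(t, v) / tau for v in img_latents]
        top = max(logits)
        exps = [math.exp(s - top) for s in logits]
        total = sum(exps)
        probs = [e / total for e in exps]
        loss -= math.log(probs[i]) / n
        # dL/dt = (1 / (n * tau)) * sum_j (p_j - [j == i]) * v_j
        dt = [sum((probs[j] - (1.0 if j == i else 0.0)) * img_latents[j][k]
                  for j in range(n)) / (n * tau)
              for k in range(len(t))]
        # t = w @ x, so dL/dw[k][c] accumulates dt[k] * x[c]
        for k in range(len(w)):
            for c in range(len(x)):
                grad[k][c] += dt[k] * x[c]
    return loss, grad


def train_text_encoder(images, text_feats, steps=200, lr=0.5, seed=0):
    rng = random.Random(seed)
    # Output size matches the image latent size so both share one space.
    dim_out, dim_in = len(IMAGE_PROJ), len(text_feats[0])
    # Untrained state: small random weights.
    w = [[rng.uniform(-0.1, 0.1) for _ in range(dim_in)] for _ in range(dim_out)]
    # Image latents are computed once and reused at every step.
    img_latents = [image_encoder(im) for im in images]
    history = []
    for _ in range(steps):
        loss, grad = text_loss_and_grad(w, text_feats, img_latents)
        history.append(loss)
        # Gradient step on the text weights only; IMAGE_PROJ is untouched.
        w = [[wv - lr * gv for wv, gv in zip(w_row, g_row)]
             for w_row, g_row in zip(w, grad)]
    return w, history


if __name__ == "__main__":
    images = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
    texts = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    w, history = train_text_encoder(images, texts)
    print(f"loss: {history[0]:.3f} -> {history[-1]:.3f}")
```

Because the image side never changes, its latents can be cached up front. The training loop then never calls the image encoder again, which is the practical payoff of locking it.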
Inventors
- Daniel Keysers
- Xiaohua Zhai
- Xiao Wang
- Lucas Beyer
- Basil Mustafa
- Andreas Steiner
- Alexander Kolesnikov
Assignees
- GOOGLE LLC
Dates
- Publication Date
- 20260505
- Application Date
- 20221031
Claims (20)
- 1. A computer-implemented method comprising: obtaining (i) a pretrained image encoder and (ii) a training sample comprising a training image and a training text string corresponding to the training image; initializing a text encoder in an untrained state; determining, by the pretrained image encoder and based on the training image, a first latent representation of the training image; determining, by the text encoder and based on the training text string, a second latent representation of the training text string; determining a loss value based on the first latent representation as determined by the pretrained image encoder and the second latent representation as determined by the text encoder; updating, based on the loss value, one or more parameters of the text encoder while holding fixed all parameters of the pretrained image encoder throughout training of the text encoder; and outputting the text encoder in a trained state.
- 2. The computer-implemented method of claim 1, wherein the pretrained image encoder and the text encoder form a multimodal contrastive learning pair, and wherein updating the one or more parameters of the text encoder is configured to train the text encoder to determine latent representations that, for a given training sample, converge to latent representations determined by the pretrained image encoder.
- 3. The computer-implemented method of claim 1, wherein initializing the text encoder comprises: initializing parameters of the text encoder using substantially randomly selected values.
- 4. The computer-implemented method of claim 1, wherein a size of the first latent representation is equal to a size of the second latent representation.
- 5. The computer-implemented method of claim 4, wherein an output layer of the pretrained image encoder has a first size, and wherein the text encoder comprises a final projection layer configured to project an output of a penultimate layer of the text encoder to the first size.
- 6. The computer-implemented method of claim 1, wherein determining the loss value comprises: determining the loss value using a contrastive loss function configured to determine a similarity between the first latent representation as determined by the pretrained image encoder and the second latent representation as determined by the text encoder.
- 7. The computer-implemented method of claim 6, wherein updating the one or more parameters of the text encoder based on the loss value determined by the contrastive loss function is configured to train the text encoder to determine latent representations that (i), for training samples comprising matched image-text pairs, converge to latent representations determined by the pretrained image encoder and (ii), for training samples comprising unmatched image-text pairs, diverge from latent representations determined by the pretrained image encoder.
- 8. The computer-implemented method of claim 1, further comprising: obtaining a second training sample comprising the training image and a second training text string that does not correspond to the training image; determining, using the text encoder and based on the second training text string, a third latent representation of the second training text string; determining a second loss value based on the first latent representation and the third latent representation; and updating, based on the second loss value, one or more additional parameters of the text encoder while holding fixed all the parameters of the pretrained image encoder.
- 9. The computer-implemented method of claim 1, wherein obtaining the pretrained image encoder comprises: initializing an image encoder in a second untrained state; and training the image encoder using a training image data set and independently of the text encoder.
- 10. The computer-implemented method of claim 1, further comprising: obtaining a text query after updating the one or more parameters of the text encoder; generating, using the text encoder and based on the text query, a third latent representation of the text query; and retrieving one or more images, wherein each respective image of the one or more images is associated with a corresponding latent representation that (i) has been generated by the pretrained image encoder and (ii) has at least a threshold extent of similarity to the third latent representation.
- 11. The computer-implemented method of claim 1, further comprising: obtaining an image query; generating, using the pretrained image encoder and based on the image query, a third latent representation of the image query; and retrieving one or more text strings, wherein each respective text string of the one or more text strings is associated with a corresponding latent representation that (i) has been generated by the text encoder after updating the one or more parameters of the text encoder and (ii) has at least a threshold extent of similarity to the third latent representation.
- 12. The computer-implemented method of claim 1, further comprising: obtaining an image, a first text string, and a second text string; generating, using the pretrained image encoder and based on the image, a third latent representation of the image; generating, using the text encoder, (i) a fourth latent representation of the first text string based on the first text string and (ii) a fifth latent representation of the second text string based on the second text string; determining (i) a first similarity between the third latent representation and the fourth latent representation and (ii) a second similarity between the third latent representation and the fifth latent representation; determining that the first similarity exceeds the second similarity; and based on determining that the first similarity exceeds the second similarity, determining that the image belongs to a class corresponding to the first text string.
- 13. The computer-implemented method of claim 1, wherein determining the first latent representation of the training image comprises: precomputing the first latent representation by the pretrained image encoder prior to training of the text encoder.
- 14. The computer-implemented method of claim 13, wherein determining the loss value comprises: reusing, throughout the training of the text encoder, the first latent representation as precomputed by the pretrained image encoder.
- 15. The computer-implemented method of claim 14, wherein: obtaining the training sample comprises obtaining a plurality of training samples, wherein each respective training sample of the plurality of training samples comprises a respective training image and a respective training text string corresponding to the respective training image; determining the first latent representation of the training image comprises precomputing, for each respective training sample, by the pretrained image encoder, and based on the respective training image, a corresponding latent representation of the respective training image prior to training of the text encoder; and determining the loss value comprises reusing, throughout the training of the text encoder, the corresponding latent representation as precomputed by the pretrained image encoder for the respective training sample.
- 16. A system comprising: a processor; and a non-transitory computer-readable medium having stored thereon instructions that, when executed by the processor, cause the processor to perform operations comprising: obtaining (i) a pretrained image encoder and (ii) a training sample comprising a training image and a training text string corresponding to the training image; initializing a text encoder in an untrained state; determining, by the pretrained image encoder and based on the training image, a first latent representation of the training image; determining, by the text encoder and based on the training text string, a second latent representation of the training text string; determining a loss value based on the first latent representation as determined by the pretrained image encoder and the second latent representation as determined by the text encoder; updating, based on the loss value, one or more parameters of the text encoder while holding fixed all parameters of the pretrained image encoder throughout training of the text encoder; and outputting the text encoder in a trained state.
- 17. The system of claim 16, wherein the pretrained image encoder and the text encoder form a multimodal contrastive learning pair, and wherein updating the one or more parameters of the text encoder is configured to train the text encoder to determine latent representations that, for a given training sample, converge to latent representations determined by the pretrained image encoder.
- 18. The system of claim 16, wherein determining the first latent representation of the training image comprises: precomputing the first latent representation by the pretrained image encoder prior to training of the text encoder.
- 19. A computer-implemented method comprising: obtaining an image, a text string, a pretrained image encoder, and a text encoder, wherein the text encoder has been trained by a training process comprising: obtaining (i) the pretrained image encoder and (ii) a training sample comprising a training image and a training text string corresponding to the training image; initializing the text encoder in an untrained state; determining, by the pretrained image encoder and based on the training image, a first latent representation of the training image; determining, by the text encoder and based on the training text string, a second latent representation of the training text string; determining a loss value based on the first latent representation as determined by the pretrained image encoder and the second latent representation as determined by the text encoder; and updating, based on the loss value, one or more parameters of the text encoder while holding fixed all parameters of the pretrained image encoder throughout training of the text encoder; determining, by the pretrained image encoder and based on the image, a third latent representation of the image; determining, by the text encoder and based on the text string, a fourth latent representation of the text string; determining a similarity between the third latent representation and the fourth latent representation; and generating an output based on the similarity.
- 20. A non-transitory computer-readable medium having stored thereon instructions that, when executed by a processor, cause the processor to perform operations comprising: obtaining an image, a text string, a pretrained image encoder, and a text encoder, wherein the text encoder has been trained by a training process comprising: obtaining (i) the pretrained image encoder and (ii) a training sample comprising a training image and a training text string corresponding to the training image; initializing the text encoder in an untrained state; determining, by the pretrained image encoder and based on the training image, a first latent representation of the training image; determining, by the text encoder and based on the training text string, a second latent representation of the training text string; determining a loss value based on the first latent representation as determined by the pretrained image encoder and the second latent representation as determined by the text encoder; and updating, based on the loss value, one or more parameters of the text encoder while holding fixed all parameters of the pretrained image encoder throughout training of the text encoder; determining, by the pretrained image encoder and based on the image, a third latent representation of the image; determining, by the text encoder and based on the text string, a fourth latent representation of the text string; determining a similarity between the third latent representation and the fourth latent representation; and generating an output based on the similarity.
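Several of the uses recited in the claims above reduce to comparing latent vectors:

- zero-shot classification, where the closest class prompt wins (claim 12);
- thresholded retrieval (claims 10 and 11).

The following is a hedged sketch, not the patent's method. It assumes cosine similarity as the similarity measure and uses hypothetical, hand-written latent vectors in place of real encoder outputs. The function names are illustrative.

```python
import math
from typing import Dict, List, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    num = sum(x * y for x, y in zip(a, b))
    den = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return num / den if den else 0.0


def zero_shot_classify(image_latent: Vector, class_latents: Dict[str, Vector]) -> str:
    """Pick the class whose text latent is most similar to the image latent."""
    return max(class_latents, key=lambda name: cosine(image_latent, class_latents[name]))


def retrieve(query_latent: Vector, index: Dict[str, Vector], threshold: float) -> List[str]:
    """Return IDs of stored latents with similarity >= threshold, most similar first."""
    scored = sorted(((cosine(query_latent, lat), item_id) for item_id, lat in index.items()),
                    reverse=True)
    return [item_id for score, item_id in scored if score >= threshold]


if __name__ == "__main__":
    # Hypothetical text-encoder outputs for two class prompts.
    classes = {"cat": [1.0, 0.1, 0.0], "dog": [0.0, 1.0, 0.2]}
    print(zero_shot_classify([0.9, 0.2, 0.0], classes))  # cat
    # Hypothetical image-encoder outputs, precomputed for a small collection.
    images = {"img_a": [1.0, 0.0, 0.0], "img_b": [0.0, 1.0, 0.0], "img_c": [0.7, 0.7, 0.0]}
    print(retrieve([1.0, 0.0, 0.0], images, threshold=0.5))  # ['img_a', 'img_c']
```

The same `retrieve` function covers both retrieval directions:

- Text-to-image retrieval (claim 10): pass a text-query latent and an index of image latents.
- Image-to-text retrieval (claim 11): pass an image-query latent and an index of text latents.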
Description
BACKGROUND
Machine learning models may be used to process various types of data, including images, video, time series, text, and/or point clouds, among other possibilities. Improvements in the machine learning models may allow the models to carry out the processing of data faster and/or utilize fewer computing resources for the processing. Improvements in the machine learning models may also allow the models to generate outputs that are relatively more accurate, precise, and/or otherwise improved.
SUMMARY
A first machine learning model and a second machine learning model, each configured to process a different type of data, may be trained using a contrastive learning process. This process trains the models to generate similar latent representations for matched pairs of input samples of the different types of data and dissimilar latent representations for unmatched pairs of input samples of the different types of data. For example, the first machine learning model may be configured to generate latent representations of images, while the second machine learning model may be configured to generate latent representations of text strings. When a text string is descriptive of an image, the respective latent representations of the text string and the image may be similar. When the text string is not descriptive of the image, the respective latent representations of the text string and the image may be dissimilar. The contrastive learning process may be improved by pretraining the first machine learning model before contrastive training and then, during the contrastive training, holding its parameters fixed while adjusting parameters of the second machine learning model. Thus, the second machine learning model may be trained to match latent representations generated by the first machine learning model, and the first machine learning model might not need to relearn to generate useful latent representations. This simplifies the contrastive learning process and improves the resulting models.
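The matched/unmatched objective described in the summary is often implemented as a symmetric softmax contrastive loss over a batch. In this formulation, each image's own caption is the positive and the other captions in the batch serve as negatives. The sketch below assumes that form, L2-normalized latents, and a fixed temperature of 0.1. The summary does not specify these details, so treat them as illustrative choices rather than the patent's exact loss.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def log_softmax(row: Sequence[float]) -> List[float]:
    top = max(row)  # subtract the max for numerical stability
    lse = top + math.log(sum(math.exp(x - top) for x in row))
    return [x - lse for x in row]


def contrastive_loss(img_latents, txt_latents, temperature: float = 0.1) -> float:
    """Symmetric softmax contrastive loss for a batch of N image-text pairs.

    (image i, text i) is a matched pair; (image i, text j) with j != i is
    treated as unmatched.
    """
    imgs = [l2_normalize(v) for v in img_latents]
    txts = [l2_normalize(t) for t in txt_latents]
    n = len(imgs)
    # logits[i][j] = cosine(image i, text j) / temperature
    logits = [[sum(a * b for a, b in zip(im, tx)) / temperature for tx in txts]
              for im in imgs]
    # Image-to-text: each row should put its probability mass on the diagonal.
    i2t = -sum(log_softmax(logits[i])[i] for i in range(n)) / n
    # Text-to-image: each column should put its probability mass on the diagonal.
    t2i = -sum(log_softmax([logits[j][i] for j in range(n)])[i] for i in range(n)) / n
    return 0.5 * (i2t + t2i)


if __name__ == "__main__":
    images = [[1.0, 0.0], [0.0, 1.0]]
    matched = [[0.9, 0.1], [0.1, 0.9]]     # text i describes image i
    mismatched = [[0.1, 0.9], [0.9, 0.1]]  # captions swapped
    print(contrastive_loss(images, matched) < contrastive_loss(images, mismatched))  # True
```

The loss itself does not know which side is locked. In the locked setting described here, the image latents would come from the frozen pretrained encoder, so the loss gradients would only be used to update the text encoder.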
In a first example embodiment, a method may include obtaining (i) a pretrained image encoder and (ii) a training sample that includes a training image and a training text string corresponding to the training image. The method may also include initializing a text encoder in an untrained state. The method may additionally include determining, using the pretrained image encoder and based on the training image, a first latent representation of the training image, and determining, using the text encoder and based on the training text string, a second latent representation of the training text string. The method may further include determining a loss value based on the first latent representation and the second latent representation. The method may yet further include updating, based on the loss value, one or more parameters of the text encoder while holding the parameters of the pretrained image encoder fixed, and outputting the text encoder in a trained state.
In a second example embodiment, a system may include a processor and a non-transitory computer-readable medium having stored thereon instructions that, when executed by the processor, cause the processor to perform operations in accordance with the first example embodiment.
In a third example embodiment, a non-transitory computer-readable medium may have stored thereon instructions that, when executed by a computing device, cause the computing device to perform operations in accordance with the first example embodiment.
In a fourth example embodiment, a system may include various means for carrying out each of the operations of the first example embodiment.
In a fifth example embodiment, a method may include obtaining an image, a text string, a pretrained image encoder, and a text encoder. The text encoder may be trained by a training process that includes obtaining (i) the pretrained image encoder and (ii) a training sample including a training image and a training text string corresponding to the training image.
The training process may also include initializing the text encoder in an untrained state. The training process may additionally include determining, using the pretrained image encoder and based on the training image, a first latent representation of the training image, and determining, using the text encoder and based on the training text string, a second latent representation of the training text string. The training process may further include determining a loss value based on the first latent representation and the second latent representation. The training process may yet further include updating, based on the loss value, one or more parameters of the text encoder while holding the parameters of the pretrained image encoder fixed. The method may also include determining, using the pretrained image encoder and based on the image, a third latent representation of the image, and determining, using the text encoder and based on the text string, a fourth latent representation of the text string. The method