US-12619922-B2 - Training embedding models using a stale embedding cache for negative sampling
Abstract
Provided are systems and methods which more efficiently train embedding models through the use of a cache of item embeddings for candidate items over a number of training iterations. The cached item embeddings can be “stale” embeddings that were generated by a previous version of the model at a previous training iteration. Specifically, at each iteration, the (potentially stale) item embeddings included in the cache can be used when generating similarity scores that are the basis for sampling a number of items to use as negatives in the current training iteration. For example, a Gumbel-Max sampling approach can be used to sample negative items that will enable an approximation of the true gradient. New embeddings can be generated for the sampled negative items and can be used to train the model at the current iteration.
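As an illustration of the sampling step described in the abstract, the following is a minimal sketch, not the patented implementation. The function name `gumbel_max_sample`, the embedding dimension, the catalog size, and the number of negatives are all illustrative assumptions; the scores are computed against a stale cache as the abstract describes.

```python
import numpy as np

def gumbel_max_sample(scores: np.ndarray, num_samples: int) -> np.ndarray:
    """Gumbel-Max trick: adding i.i.d. Gumbel(0, 1) noise to the scores and
    taking the single argmax draws an item with probability softmax(scores);
    taking the top-k generalizes this to k distinct items (Gumbel-top-k)."""
    perturbed = scores + np.random.gumbel(size=scores.shape)
    return np.argpartition(-perturbed, num_samples)[:num_samples]

# Score the current query embedding against the (potentially stale) cached
# item embeddings, then draw hard negatives for this training iteration.
query_embedding = np.random.randn(64)                  # current model output
cached_item_embeddings = np.random.randn(10_000, 64)   # stale cache
scores = cached_item_embeddings @ query_embedding
negative_ids = gumbel_max_sample(scores, num_samples=8)
```

Because only the noise addition and a top-k selection are needed, the sampling step can run over the entire cache without materializing a softmax distribution.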
Inventors
- Erik Michael Lindgren
- Sashank Jakkam Reddi
- Ruiqi Guo
- Sanjiv Kumar
Assignees
- GOOGLE LLC
Dates
- Publication Date: 2026-05-05
- Application Date: 2022-11-08
Claims (20)
- 1. A computer-implemented method for training embedding models with improved efficiency, the method comprising:
  for a plurality of training iterations:
    processing, by a computing system, a query with an embedding model to generate a query embedding for the query;
    accessing, by the computing system, an embedding table that stores a plurality of item embeddings for at least a portion of a plurality of candidate items, wherein the item embedding for at least one of the plurality of candidate items was generated by a previous version of the embedding model in one or more previous iterations;
    generating, by the computing system, a plurality of similarity scores for the query embedding with respect to at least a portion of the plurality of item embeddings included in the embedding table;
    sampling, by the computing system, from the plurality of candidate items based at least in part on the plurality of similarity scores to select one or more sampled items;
    processing, by the computing system, the one or more sampled items with the embedding model to respectively generate one or more sampled item embeddings;
    generating, by the computing system, one or more similarity scores for the query embedding with respect to the one or more sampled item embeddings; and
    updating, by the computing system, one or more values of one or more parameters of the embedding model based at least in part on the similarity scores generated for the query embedding with respect to at least a portion of the sampled item embeddings;
  after the plurality of training iterations, providing, by the computing system, the embedding model as an output.
- 2. The computer-implemented method of claim 1, wherein sampling, by the computing system, from the plurality of candidate items based at least in part on the plurality of similarity scores comprises performing, by the computing system, a Gumbel-Max sampling technique to sample from the plurality of candidate items based at least in part on the plurality of similarity scores.
- 3. The computer-implemented method of claim 1, wherein the embedding table is stored in a memory portion of a hardware accelerator.
- 4. The computer-implemented method of claim 1, further comprising: prior to processing the query with the embedding model: obtaining, by the computing system, a training example from a training dataset, wherein the training example comprises the query and one or more positive items labeled as positive results for the query, the one or more positive items being a subset of the plurality of candidate items; and prior to generating the plurality of similarity scores: processing, by the computing system, the one or more positive items with the embedding model to respectively generate one or more positive embeddings for the one or more positive items; and updating, by the computing system, the embedding table to include the one or more positive embeddings for the one or more positive items of the plurality of candidate items.
- 5. The computer-implemented method of claim 4, wherein updating, by the computing system, the one or more values of the one or more parameters of the embedding model comprises updating, by the computing system, the one or more values of the one or more parameters of the embedding model based at least in part on the similarity scores generated for the query embedding with respect to at least a portion of the positive items and at least the portion of the sampled items.
- 6. The computer-implemented method of claim 1, wherein updating, by the computing system, the one or more values of the one or more parameters of the embedding model comprises: determining, by the computing system, an approximate gradient of a cross-entropy loss based on the similarity scores generated for the query embedding with respect to at least a portion of the sampled item embeddings; and updating, by the computing system, the one or more values of the one or more parameters of the embedding model based at least in part on the approximate gradient of the cross-entropy loss.
- 7. The computer-implemented method of claim 1, wherein the embedding table comprises a full document cache that stores item embeddings for all of the plurality of candidate items.
- 8. The computer-implemented method of claim 1, further comprising, for each iteration: removing, by the computing system, from the embedding table the item embeddings associated with a fraction of the candidate items; and replacing, by the computing system, the item embeddings that were removed from the embedding table with new item embeddings generated for the fraction of the candidate items.
- 9. The computer-implemented method of claim 1, wherein the embedding table comprises a streaming cache that stores item embeddings for fewer than all of the plurality of candidate items.
- 10. The computer-implemented method of claim 9, further comprising, for each iteration: removing, by the computing system, from the embedding table the item embeddings associated with a fraction of the candidate items; sampling, by the computing system, newly sampled items from the plurality of candidate items; and replacing, by the computing system, the item embeddings that were removed from the embedding table with new item embeddings generated for newly sampled items.
- 11. The computer-implemented method of claim 1, wherein the embedding model comprises a two-tower dual encoding model that comprises a query encoder and an item encoder.
- 12. The computer-implemented method of claim 1, wherein the plurality of candidate items comprise: images; textual documents; web documents; products; videos; or entities.
- 13. The computer-implemented method of claim 1, wherein the query comprises: a textual query; a voice query; or an image query.
- 14. A computing system for training embedding models with improved efficiency, the computing system comprising one or more processors and one or more non-transitory computer-readable media that store instructions that, when executed by the one or more processors, cause the computing system to perform operations, the operations comprising:
  for a plurality of training iterations:
    processing, by the computing system, a query with an embedding model to generate a query embedding for the query;
    accessing, by the computing system, an embedding table that stores a plurality of item embeddings for at least a portion of a plurality of candidate items, wherein the item embedding for at least one of the plurality of candidate items was generated by a previous version of the embedding model in one or more previous iterations;
    generating, by the computing system, a plurality of similarity scores for the query embedding with respect to at least a portion of the plurality of item embeddings included in the embedding table;
    sampling, by the computing system, from the plurality of candidate items based at least in part on the plurality of similarity scores to select one or more sampled items;
    processing, by the computing system, the one or more sampled items with the embedding model to respectively generate one or more sampled item embeddings;
    generating, by the computing system, one or more similarity scores for the query embedding with respect to the one or more sampled item embeddings; and
    updating, by the computing system, one or more values of one or more parameters of the embedding model based at least in part on the similarity scores generated for the query embedding with respect to at least a portion of the sampled item embeddings;
  after the plurality of training iterations, providing, by the computing system, the embedding model as an output.
- 15. The computing system of claim 14, wherein sampling, by the computing system, from the plurality of candidate items based at least in part on the plurality of similarity scores comprises performing, by the computing system, a Gumbel-Max sampling technique to sample from the plurality of candidate items based at least in part on the plurality of similarity scores.
- 16. The computing system of claim 14, wherein the embedding table is stored in a memory portion of a hardware accelerator.
- 17. The computing system of claim 14, wherein the operations further comprise: prior to processing the query with the embedding model: obtaining, by the computing system, a training example from a training dataset, wherein the training example comprises the query and one or more positive items labeled as positive results for the query, the one or more positive items being a subset of the plurality of candidate items; and prior to generating the plurality of similarity scores: processing, by the computing system, the one or more positive items with the embedding model to respectively generate one or more positive embeddings for the one or more positive items; and updating, by the computing system, the embedding table to include the one or more positive embeddings for the one or more positive items of the plurality of candidate items.
- 18. One or more non-transitory computer-readable media that store an embedding model that has been trained by performance of operations, the operations comprising:
  for a plurality of training iterations:
    processing, by a computing system, a query with an embedding model to generate a query embedding for the query;
    accessing, by the computing system, an embedding table that stores a plurality of item embeddings for at least a portion of a plurality of candidate items, wherein the item embedding for at least one of the plurality of candidate items was generated by a previous version of the embedding model in one or more previous iterations;
    generating, by the computing system, a plurality of similarity scores for the query embedding with respect to at least a portion of the plurality of item embeddings included in the embedding table;
    sampling, by the computing system, from the plurality of candidate items based at least in part on the plurality of similarity scores to select one or more sampled items;
    processing, by the computing system, the one or more sampled items with the embedding model to respectively generate one or more sampled item embeddings;
    generating, by the computing system, one or more similarity scores for the query embedding with respect to the one or more sampled item embeddings; and
    updating, by the computing system, one or more values of one or more parameters of the embedding model based at least in part on the similarity scores generated for the query embedding with respect to at least a portion of the sampled item embeddings.
- 19. The one or more non-transitory computer-readable media of claim 18, wherein the embedding table comprises a full document cache that stores item embeddings for all of the plurality of candidate items.
- 20. The one or more non-transitory computer-readable media of claim 18, wherein the embedding table comprises a streaming cache that stores item embeddings for fewer than all of the plurality of candidate items.
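Claims 7 through 10 (echoed in claims 19 and 20) recite two cache layouts: a full document cache, which is refreshed by re-embedding a fraction of the same items each iteration, and a streaming cache, which is refreshed by swapping in newly sampled items. The following is a minimal sketch of the two refresh rules, not the patented implementation; the function names, the `item_encoder` callable (mapping an array of item ids to an array of embeddings), and the ten-percent fraction are all illustrative assumptions.

```python
import numpy as np

def refresh_full_cache(item_ids, cache_embs, item_encoder, fraction=0.1):
    """Full document cache (claims 7-8): every candidate item stays in the
    table; a fraction of entries is re-embedded with the current model."""
    rng = np.random.default_rng()
    idx = rng.choice(len(item_ids), size=int(fraction * len(item_ids)),
                     replace=False)
    cache_embs[idx] = item_encoder(item_ids[idx])
    return cache_embs

def refresh_streaming_cache(cache_ids, cache_embs, all_item_ids,
                            item_encoder, fraction=0.1):
    """Streaming cache (claims 9-10): evicted slots are refilled with fresh
    embeddings for newly sampled candidate items."""
    rng = np.random.default_rng()
    idx = rng.choice(len(cache_ids), size=int(fraction * len(cache_ids)),
                     replace=False)
    cache_ids[idx] = rng.choice(all_item_ids, size=idx.size, replace=False)
    cache_embs[idx] = item_encoder(cache_ids[idx])
    return cache_ids, cache_embs
```

In both layouts, only the refreshed fraction is re-encoded each iteration, which is what keeps the per-iteration cost low while bounding how stale any cached embedding can become.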
Description
RELATED APPLICATIONS
This application claims priority to and the benefit of U.S. Provisional Patent Application No. 63/277,385, filed Nov. 9, 2021. U.S. Provisional Patent Application No. 63/277,385 is hereby incorporated by reference in its entirety.
FIELD
The present disclosure relates generally to machine learning. More particularly, the present disclosure relates to efficient training of embedding models (e.g., embedding-based retrieval models) using a negative cache.
BACKGROUND
Learning to represent objects as numerical vectors (e.g., dense vectors), often called “embeddings”, has proved to be crucial in large-scale information retrieval tasks in multiple domains including, among others, vision and natural language processing. A popular paradigm for such learning tasks involves training two separate neural networks (often called two-towers or dual-encoders), one representing the query and the other the document. Given positive and negative (query, document) pairs, the learning task trains the two networks by minimizing a loss function, usually softmax cross-entropy, to encourage positive pairs to have higher similarity scores and negative pairs to have lower scores.
While it is easy to sample positive pairs of examples through user feedback such as impressions, clicks, or other forms of inferred approval, it is more challenging to sample good negative pairs from a pool of potentially millions or even billions of documents. A large number of negative pairs is often required to ensure high quality of the final model, which makes the training process consume significant computational resources.
A number of strategies have been proposed in the literature to address the problem of sampling good negative pairs from a large corpus. The most common approach is to use in-batch negatives, which treats random, non-positive pairs in a minibatch as negatives. This approach is computationally efficient and works in a streaming setting, but the pool of negative examples is limited to the minibatch. Towards the later stages of training, the in-batch negatives become less informative (i.e., yield low gradients) because they are sampled randomly, without regard to which negatives are hard for a given query.
Another popular approach is to maintain an asynchronous retrieval index of the full dataset for negative sampling. Negatives from the full dataset can be extracted based on approximate retrieval techniques such as ScaNN, Faiss, or SPTAG. However, this approach requires coordinating with a separate process that re-indexes and re-builds the retrieval index at each learning iteration, which is not only computationally expensive and hard to maintain but also suffers from the problem of a stale index.
SUMMARY
Aspects and advantages of embodiments of the present disclosure will be set forth in part in the following description, or can be learned from the description, or can be learned through practice of the embodiments.
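For context on the baseline the Background describes, the following is a minimal sketch of the in-batch-negatives softmax cross-entropy loss. This illustrates the prior approach, not the claimed method; the function name and the use of NumPy arrays are illustrative assumptions.

```python
import numpy as np

def in_batch_softmax_cross_entropy(query_embs: np.ndarray,
                                   item_embs: np.ndarray) -> float:
    """In-batch-negatives baseline: row i of each array is an aligned
    (query, positive item) pair, so for query i every item j != i in
    the minibatch is treated as a negative."""
    logits = query_embs @ item_embs.T                # [batch, batch] scores
    logits -= logits.max(axis=1, keepdims=True)      # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))       # positives on diagonal
```

As the Background notes, this pool of negatives is limited to the minibatch, so late in training most of them score far below the positive and contribute little gradient.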
One example aspect is directed to a computer-implemented method for training embedding models with improved efficiency, the method comprising: for a plurality of training iterations: processing, by a computing system, a query with an embedding model to generate a query embedding for the query; accessing, by the computing system, an embedding table that stores a plurality of item embeddings for at least a portion of a plurality of candidate items, wherein the item embedding for at least one of the plurality of candidate items was generated by a previous version of the embedding model in one or more previous iterations; generating, by the computing system, a plurality of similarity scores for the query embedding with respect to at least a portion of the plurality of item embeddings included in the embedding table; sampling, by the computing system, from the plurality of candidate items based at least in part on the plurality of similarity scores to select one or more sampled items; processing, by the computing system, the one or more sampled items with the embedding model to respectively generate one or more sampled item embeddings; generating, by the computing system, one or more similarity scores for the query embedding with respect to the one or more sampled item embeddings; and updating, by the computing system, one or more values of one or more parameters of the embedding model based at least in part on the similarity scores generated for the query embedding with respect to at least a portion of the sampled item embeddings; and after the plurality of training iterations, providing, by the computing system, the embedding model as an output.
In some implementations, sampling, by the computing system, from the plurality of candidate items based at least in part on the plurality of similarity scores comprises performing, by the computing system, a Gumbel-Max sampling technique to sample from the plurality of candidate items based at least in part on the plurality of similarity scores.
In some implementations, the embedding table is stored in a memory portion of a hardware accelerator.
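To tie the summarized steps together, below is a minimal end-to-end sketch of the training loop under stated assumptions, not the patented implementation: the linear stand-in encoder, the random query embedding, the cache size, the number of negatives, and the ten-percent refresh fraction are all illustrative, and the actual parameter update is elided.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, num_items, cache_size, num_negatives = 64, 10_000, 2_048, 8

# Stand-in item encoder: a real system would run the item tower of the
# embedding model; a fixed linear map over random features is enough here.
item_features = rng.normal(size=(num_items, dim))
W_item = 0.05 * rng.normal(size=(dim, dim))

def embed_items(ids):
    return item_features[ids] @ W_item

# Streaming cache of (potentially stale) embeddings for a subset of items.
cache_ids = rng.choice(num_items, size=cache_size, replace=False)
cache_embs = embed_items(cache_ids)

for step in range(3):  # a few illustrative training iterations
    # (1) Stand-ins for the query-tower output and the embedding of the
    #     labeled positive item under the current model.
    query_emb = rng.normal(size=dim)
    pos_emb = embed_items(np.array([rng.integers(num_items)]))[0]

    # (2) Score the query against the stale cached item embeddings.
    scores = cache_embs @ query_emb

    # (3) Gumbel-Max sampling of hard negatives from the cache.
    perturbed = scores + rng.gumbel(size=scores.shape)
    neg_slots = np.argpartition(-perturbed, num_negatives)[:num_negatives]

    # (4) Re-embed the sampled negatives with the *current* model and form
    #     the fresh logits used by the sampled-softmax gradient step.
    fresh_negs = embed_items(cache_ids[neg_slots])
    logits = np.concatenate(([pos_emb @ query_emb], fresh_negs @ query_emb))
    # ... compute the cross-entropy gradient from `logits` and update the
    #     model parameters here (omitted for brevity).

    # (5) Refresh a fraction of the cache so entries do not grow too stale.
    evict = rng.choice(cache_size, size=cache_size // 10, replace=False)
    cache_ids[evict] = rng.choice(num_items, size=evict.size, replace=False)
    cache_embs[evict] = embed_items(cache_ids[evict])
```

Step (5) corresponds to the streaming-cache refresh of claims 9 and 10; a full document cache (claims 7 and 8) would instead re-embed a fraction of the items already in the table.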