
US-12626691-B1 - Language model hallucination mitigation using contrastive decoding

US 12626691 B1

Abstract

Devices and techniques are generally described to mitigate hallucination for language models (LMs) using contrastive decoding. In various examples, first context data and first adversarial data may be determined based on a natural language input. A first prompt including the natural language input, a second prompt including the natural language input and the first context data, and a third prompt including the natural language input and the first adversarial data may be generated. An LM may generate a first vector of logits for the first prompt, a second vector of logits for the second prompt, and a third vector of logits for the third prompt. A decoder may generate first output data based on a combination of the first vector of logits, the second vector of logits, and the third vector of logits.

Inventors

  • Emilio Fabio Monti
  • Jens Lehmann
  • Hitham Ahmed Assem Aly Salama
  • Zheng Zhao

Assignees

  • AMAZON TECHNOLOGIES, INC.

Dates

Publication Date
2026-05-12
Application Date
2023-12-14

Claims (20)

  1. A computer-implemented method comprising: receiving a first natural language input; determining, using a first context retrieval component, first context data determined to be relevant to the first natural language input; generating first adversarial data determined to be irrelevant to the first natural language input; generating first prompt data comprising the first natural language input; generating second prompt data comprising the first natural language input and the first context data; generating third prompt data comprising the first natural language input and the first adversarial data; generating, by a first language model (LM) using the first prompt data, a first vector of logits representing a first output of the first LM generated using parametric knowledge of the first LM learned during training; generating, by the first LM using the second prompt data, a second vector of logits representing a second output of the first LM; generating, by the first LM using the third prompt data, a third vector of logits representing a third output of the first LM; predicting, by a contrastive decoder, a first output token by combining the first vector of logits, the second vector of logits, and the third vector of logits in a weighted combination; and generating first output data comprising the first output token as a natural language output.
  2. The computer-implemented method of claim 1, further comprising: determining a difference vector representing a difference between the second vector of logits and the third vector of logits; generating a weighted difference by multiplying the difference vector by a scalar hyperparameter α; generating a combined vector of logits by adding the first vector of logits to the weighted difference; and determining, by the contrastive decoder, the first output token using the combined vector of logits.
  3. The computer-implemented method of claim 1, further comprising: determining a first confidence value of the first context retrieval component associated with the second vector of logits; and determining a scalar hyperparameter α used to generate the weighted combination that is proportional to the first confidence value.
  4. The computer-implemented method of claim 1, further comprising: determining, by the contrastive decoder, a second output token using the first vector of logits; generating second output data comprising the second output token; determining, by the contrastive decoder, a third output token using the second vector of logits; generating third output data comprising the third output token; receiving first evaluation data indicating that the third output data represents a more relevant response to the first natural language input relative to the second output data; and updating parameters of the first LM based at least in part on the first evaluation data.
  5. A method comprising: receiving a first natural language input; generating, by a first language model (LM), a first vector of logits based on first prompt data comprising the first natural language input; generating, by the first LM, a second vector of logits based on second prompt data comprising the first natural language input and first context data associated with the first natural language input; generating, by the first LM, a third vector of logits based on third prompt data comprising the first natural language input and first adversarial data; and generating, by a contrastive decoder, first output data based on a combination of the first vector of logits, the second vector of logits, and the third vector of logits, wherein the first output data is a natural language output.
  6. The method of claim 5, further comprising: generating, by the first LM, a first application programming interface (API) call to a first retrieval component, wherein the first API call comprises a representation of the first natural language input; and receiving, based on the first API call, the first context data.
  7. The method of claim 5, further comprising: generating first embedding data representing the first context data; and determining second embedding data representing the first adversarial data based at least in part on a first distance in an embedding space between the first embedding data and the second embedding data.
  8. The method of claim 5, further comprising sampling the first adversarial data from among text data that is irrelevant to the first natural language input.
  9. The method of claim 5, further comprising: determining difference data representing a difference between the second vector of logits and the third vector of logits; and determining a first output token based at least in part on a combination of the difference data and the first vector of logits.
  10. The method of claim 9, further comprising: determining a first weight value for the decoder; and determining the difference data by multiplying the first weight value by the difference between the second vector of logits and the third vector of logits.
  11. The method of claim 5, further comprising: determining a first score for the first context data, the first score representing a predicted relevance of the first context data to the first natural language input; determining a first weight parameter based at least in part on the first score for the first context data; and determining the combination of the first vector of logits, the second vector of logits, and the third vector of logits using the first weight parameter.
  12. The method of claim 5, further comprising: generating, by the decoder, second output data using the first vector of logits; generating, by the decoder, third output data using the second vector of logits; receiving first evaluation data indicating that the third output data represents a more relevant response to the first natural language input relative to the second output data; and updating parameters of the decoder based at least in part on the first evaluation data.
  13. A system comprising: at least one processor; and non-transitory computer-readable memory storing instructions that, when executed by the at least one processor, are effective to: receive a first natural language input; generate, by a first language model (LM), a first vector of logits based on first prompt data comprising the first natural language input; generate, by the first LM, a second vector of logits based on second prompt data comprising the first natural language input and first context data associated with the first natural language input; generate, by the first LM, a third vector of logits based on third prompt data comprising the first natural language input and first adversarial data; and generate, by a contrastive decoder, first output data based on a combination of the first vector of logits, the second vector of logits, and the third vector of logits, wherein the first output data is a natural language output.
  14. The system of claim 13, the non-transitory computer-readable memory storing further instructions that, when executed by the at least one processor, are further effective to: generate, by the first LM, a first application programming interface (API) call to a first retrieval component, wherein the first API call comprises a representation of the first natural language input; and receive, based on the first API call, the first context data.
  15. The system of claim 13, the non-transitory computer-readable memory storing further instructions that, when executed by the at least one processor, are further effective to: generate first embedding data representing the first context data; and determine second embedding data representing the first adversarial data based at least in part on a first distance in an embedding space between the first embedding data and the second embedding data.
  16. The system of claim 13, the non-transitory computer-readable memory storing further instructions that, when executed by the at least one processor, are further effective to: sample the first adversarial data from among text data that is irrelevant to the first natural language input.
  17. The system of claim 13, the non-transitory computer-readable memory storing further instructions that, when executed by the at least one processor, are further effective to: determine difference data representing a difference between the second vector of logits and the third vector of logits; and determine a first output token based at least in part on a combination of the difference data and the first vector of logits.
  18. The system of claim 17, the non-transitory computer-readable memory storing further instructions that, when executed by the at least one processor, are further effective to: determine a first weight value for the decoder; and determine the difference data by multiplying the first weight value by a difference between the second vector of logits and the third vector of logits.
  19. The system of claim 13, the non-transitory computer-readable memory storing further instructions that, when executed by the at least one processor, are further effective to: determine a first score for the first context data, the first score representing a predicted relevance of the first context data to the first natural language input; determine a first weight parameter based at least in part on the first score for the first context data; and determine the combination of the first vector of logits, the second vector of logits, and the third vector of logits using the first weight parameter.
  20. The system of claim 13, the non-transitory computer-readable memory storing further instructions that, when executed by the at least one processor, are further effective to: generate, by the decoder, second output data using the first vector of logits; generate, by the decoder, third output data using the second vector of logits; receive first evaluation data indicating that the third output data represents a more relevant response to the first natural language input relative to the second output data; and update parameters of the decoder based at least in part on the first evaluation data.
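The weighted combination recited in claims 2, 9, and 10 can be sketched in plain Python as follows. This is an illustrative toy example, not an implementation from the patent: the function name, the logit values, and the greedy (argmax) token choice are all assumptions.

```python
def contrastive_decode_step(logits_base, logits_ctx, logits_adv, alpha=1.0):
    # Difference vector between the context-conditioned and adversarial
    # logits (claim 9), scaled by the weight alpha (claims 2 and 10) and
    # added to the base, parametric-knowledge logits (claim 2).
    combined = [b + alpha * (c - a)
                for b, c, a in zip(logits_base, logits_ctx, logits_adv)]
    # Greedy choice: return the index (token id) of the largest combined logit.
    return max(range(len(combined)), key=combined.__getitem__)

# Toy 4-token vocabulary; values are illustrative only.
base = [2.0, 1.0, 0.5, 0.2]   # first prompt: parametric knowledge alone
ctx  = [0.5, 3.0, 0.5, 0.2]   # second prompt: with retrieved context
adv  = [0.5, 0.5, 0.5, 0.2]   # third prompt: with irrelevant context

print(contrastive_decode_step(base, ctx, adv, alpha=1.0))  # prints 1
```

With alpha = 1.0 the combined logits become [2.0, 3.5, 0.5, 0.2], so the token favored by the retrieved context (index 1) wins over the model's parametric preference (index 0); with alpha = 0.0 the decoder falls back to the base logits alone. Claims 3 and 11 describe making this weight proportional to a retrieval-confidence score rather than fixing it.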

Description

BACKGROUND

People can interact with computing devices using spoken commands. In some systems, a “wakeword” is used to activate functionality. Natural language processing is used to transform the spoken requests that follow into a computer directive for performing a task.

SUMMARY

Devices and techniques are generally described to mitigate hallucination for language models (LMs) using contrastive decoding. In various examples, first context data and first adversarial data may be determined based on a natural language input. A first prompt including the natural language input, a second prompt including the natural language input and the first context data, and a third prompt including the natural language input and the first adversarial data may be generated. An LM may generate a first vector of logits for the first prompt, a second vector of logits for the second prompt, and a third vector of logits for the third prompt. A decoder may generate first output data based on a combination of the first vector of logits, the second vector of logits, and the third vector of logits.

BRIEF DESCRIPTION OF DRAWINGS

  • FIG. 1 is a flow diagram illustrating an example system for contrastive decoding for a language model (LM), in accordance with various aspects of the present disclosure.
  • FIG. 2 depicts an example LLM-based natural language processing flow, in accordance with various aspects of the present disclosure.
  • FIG. 3A depicts an example of dynamic determination of a hyperparameter that may be used to modulate LM attention to parametric knowledge, in accordance with various aspects of the present disclosure.
  • FIG. 3B depicts an example technique for training a contrastive decoder to dynamically modulate attention to parametric knowledge for a given input, in accordance with various aspects of the present disclosure.
  • FIG. 4 is a block diagram showing an example architecture of a network-connected device that may be used in accordance with various embodiments described herein.
  • FIG. 5 is a block diagram showing an example architecture of a computing device that may be used in accordance with various embodiments described herein.
  • FIG. 6 is a flow chart illustrating an example process for LM inference using contrastive decoding, in accordance with embodiments of the present disclosure.
  • FIG. 7 is a conceptual diagram illustrating components that may be included in a device, according to embodiments of the present disclosure.

DETAILED DESCRIPTION

In the following description, reference is made to the accompanying drawings that illustrate several examples of the present invention. It is understood that other examples may be utilized and various operational changes may be made without departing from the scope of the present disclosure. The following detailed description is not to be taken in a limiting sense, and the scope of the embodiments of the present invention is defined only by the claims of the issued patent.

Devices with integrated processing capabilities are often configured with network communication capability and/or other computing functions allowing the devices to send data to and/or receive data from other devices. In some examples, such devices may include voice-enabled personal assistants and/or other natural language processing interfaces that may be used to control the devices, answer questions, communicate with other people/devices, and/or otherwise interact with the devices and/or other devices. As such devices become more and more prevalent in the home, office, public spaces, quasi-public spaces (e.g., hotels, offices, retail spaces), and elsewhere, and as the technology matures, new services and features are being developed. For instance, in some cases devices may be paired or otherwise grouped together with one another to enable certain functionality. For example, a device that includes voice-based personal assistant functionality may be paired with a device including a display so that spoken commands may be used to control content output by the display device. In another example, content may be transferred from one device to another device in response to user requests and/or other triggering events (e.g., predefined user routines of actions, presence information, etc.).

Some natural language processing flows may employ one or more language models (LMs, such as large language models (LLMs)) in order to process natural language requests. An LLM is an artificial intelligence (AI) model that may be capable of processing and generating human-like text based on the latent information it has learned from vast amounts of training data. The term “large” refers to the size of these models in terms of the number of parameters or weights, which are the values that the model learns during training to make predictions and generate text. LLMs may have millions, billions (or even more) of parameters, which enable such models to capture complex patterns and nuances in language that, in turn, allow the models to understand a