EP-4736069-A1 - CONDITIONAL DIFFUSION NEURAL NETWORK DISTILLATION TRAINING

Abstract

Methods, systems, and apparatus, including computer programs encoded on computer storage media, for performing distillation training of a conditional diffusion neural network.

Inventors

  • MEI, Kangfu
  • DELBRACIO, MAURICIO
  • TALEBI, Hossein
  • TU, Zhengzhong
  • MILANFAR, PEYMAN

Assignees

  • Google LLC

Dates

Publication Date
2026-05-06
Application Date
2024-09-23

Claims (20)

  1. A method performed by one or more computers and for training a conditional diffusion neural network having a plurality of network parameters and configured to receive a diffusion input that comprises (i) a latent representation of an image, (ii) data identifying a current time step, and (iii) a conditioning input and to process the diffusion input to generate a denoising output that characterizes predicted noise in the latent representation, the method comprising: initializing online values and target values of the network parameters; and repeatedly performing operations to update the online values, the operations comprising: receiving a set of one or more training examples, each training example comprising a training image and a training conditioning input that characterizes the training image; for each training example: selecting a first time step from a plurality of time steps that are each associated with a respective noise level; selecting, based on the first time step, a second time step that is associated with a lower noise level than the first time step; combining a representation of the training image in the training example with sampled noise in accordance with the noise level associated with the first time step to generate a latent representation for the first time step; processing a first denoising input comprising (i) the latent representation for the first time step, (ii) data identifying the first time step, and (iii) the training conditioning input in the training example using the conditional diffusion neural network and in accordance with the online values of the network parameters to generate a first denoising output; determining, from the first denoising output and in accordance with the respective noise level for the first time step, an estimate of the representation of the training image for the first time step; determining, from the first denoising output and in accordance with the respective noise level for the first time step, an estimate of the sampled noise for the first time step; combining the estimate of the representation of the training image for the first time step with the sampled noise in accordance with the noise level associated with the second time step to generate a latent representation for the second time step; processing a second denoising input comprising (i) the latent representation for the second time step, (ii) data identifying the second time step, and (iii) the training conditioning input in the training example using the conditional diffusion neural network and in accordance with the target values of the network parameters to generate a second denoising output; and determining, from the second denoising output and in accordance with the respective noise level for the second time step, an estimate of the sampled noise for the second time step; and training the conditional diffusion neural network to update the online values of the network parameters based on gradients of a loss function that comprises a first term that measures, for each of the one or more training examples, a difference between (i) the estimate of the sampled noise for the second time step for the training example and (ii) the estimate of the sampled noise for the first time step for the training example.
  2. The method of claim 1, further comprising: after the training: receiving a new conditioning input; and generating a new image conditioned on the new conditioning input using the trained conditional diffusion neural network and in accordance with the online values of the network parameters.
  3. The method of claim 1 or claim 2, wherein the loss function further comprises a second term that measures, for each of the one or more training examples, a difference between (i) the representation of the training image in the training example and (ii) the estimate of the representation of the training image in the training example for the first time step.
  4. The method of claim 3, wherein the difference between (i) the representation of the training image in the training example and (ii) the estimate of the representation of the training image in the training example for the first time step is an L2 distance between (i) the representation of the training image in the training example and (ii) the estimate of the representation of the training image in the training example for the first time step.
  5. The method of any one of claims 1-4, further comprising: after updating the online values, updating the target values of the network parameters based on the online values.
  6. The method of claim 5, wherein the target values are maintained as an exponential moving average of the online values and wherein updating the target values comprises updating the exponential moving average using the online values.
  7. The method of any preceding claim, wherein the same first and second time steps are selected for each of the one or more training examples.
  8. The method of any preceding claim, wherein the first time step is sampled from the plurality of time steps.
  9. The method of any preceding claim, wherein the second time step has a predetermined offset within the plurality of time steps relative to the first time step.
  10. The method of any preceding claim, wherein the denoising output is a velocity model output that is a prediction of a combination, in accordance with a noise level associated with the time step, of a noise component of the latent representation and a representation of the image.
  11. The method of any preceding claim, wherein initializing the online values and the target values of the network parameters comprises initializing the online values to be equal to the target values.
  12. The method of any preceding claim, wherein initializing the online values and the target values of the network parameters comprises: initializing online values of a first subset of the network parameters to be equal to pre-trained values of corresponding network parameters of a pre-trained diffusion neural network.
  13. The method of claim 12, wherein the conditioning input is of a first modality, and wherein the pre-trained diffusion neural network is not configured to receive conditioning inputs of the first modality.
  14. The method of claim 12 or 13, wherein the network parameters comprise the first subset of network parameters, a second subset of network parameters that do not have a corresponding network parameter in the pre-trained diffusion neural network, and a scalar weight value that is assigned to features generated using the second subset of the network parameters in a combination of the features generated using the second subset of the network parameters and features generated using at least some of the first subset of network parameters, and wherein initializing the online values and the target values of the network parameters comprises: setting the scalar weight value to zero.
  15. The method of any one of claims 12-14, wherein training the conditional diffusion neural network to update the online values of the network parameters comprises holding values of the network parameters in the first subset fixed at the pre-trained values while updating online values of the network parameters not in the first subset.
  16. The method of any preceding claim, wherein combining a representation of the training image in the training example with sampled noise in accordance with the noise level associated with the first time step to generate a latent representation for the first time step comprises: determining a first weight based on the noise level associated with the first time step; determining a second weight based on the noise level associated with the first time step; and determining a sum of the representation of the training image in the training example weighted by the first weight and the sampled noise weighted by the second weight to generate a latent representation for the first time step.
  17. The method of claim 1, when dependent on claim 10, wherein determining, from the first denoising output and in accordance with the respective noise level for the first time step, an estimate of the representation of the training image for the first time step comprises: determining a sum of the latent representation for the first time step weighted by the first weight and the first denoising output weighted by a negative of the second weight to generate the estimate of the representation of the training image for the first time step.
  18. The method of claim 16 or claim 17, when dependent on claim 10, wherein determining, from the first denoising output and in accordance with the respective noise level for the first time step, an estimate of the sampled noise for the first time step comprises: determining a sum of the first denoising output for the first time step weighted by the first weight and the latent representation for the first time step weighted by a negative of the second weight to generate the estimate of the sampled noise for the first time step.
  19. The method of any preceding claim, wherein combining the estimate of the representation of the training image for the first time step with the sampled noise in accordance with the noise level associated with the second time step to generate a latent representation for the second time step comprises: determining a third weight based on the noise level associated with the second time step; determining a fourth weight based on the noise level associated with the second time step; and determining a sum of the estimate of the representation of the training image for the first time step weighted by the third weight and the sampled noise weighted by the fourth weight to generate a latent representation for the second time step.
  20. The method of claim 19, when dependent on claim 10, wherein determining, from the second denoising output and in accordance with the respective noise level for the second time step, an estimate of the sampled noise for the second time step comprises: determining a sum of the second denoising output for the second time step weighted by the third weight and the latent representation for the second time step weighted by a negative of the fourth weight to generate the estimate of the sampled noise for the second time step.
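Claims 1, 3-6, 8-11, and 16-20 together describe one training update. The following is a minimal pure-Python sketch of the per-example loss and the target-value update. It rests on assumptions that the claims do not state: a cosine variance-preserving schedule (alpha_t = cos(pi*t/2), sigma_t = sin(pi*t/2), with time normalized to [0, 1]); the common velocity convention v = alpha_t*eps - sigma_t*x, under which x_hat = alpha_t*z_t - sigma_t*v and eps_hat = sigma_t*z_t + alpha_t*v (sign conventions vary between implementations); mean squared error for both loss terms; and a hypothetical toy linear model standing in for the network. The names `toy_velocity_model`, `select_time_steps`, `distillation_loss`, and `ema_update` are illustrative only. The noise is passed in rather than sampled inside the loss so that the function is deterministic.

```python
import math
import random
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]
Params = Dict[str, float]
# (params, latent z, time step in [0, 1], conditioning) -> predicted velocity v
Model = Callable[[Params, Sequence[float], float, Sequence[float]], Vector]


def alpha_sigma(t: float) -> Tuple[float, float]:
    """Assumed cosine variance-preserving schedule: alpha**2 + sigma**2 == 1."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)


def add_noise(x: Sequence[float], eps: Sequence[float], t: float) -> Vector:
    """z_t = alpha_t * x + sigma_t * eps (the two weights of claim 16)."""
    a, sig = alpha_sigma(t)
    return [a * xi + sig * ei for xi, ei in zip(x, eps)]


def x_from_v(z: Sequence[float], v: Sequence[float], t: float) -> Vector:
    """Image estimate under v = alpha * eps - sigma * x."""
    a, sig = alpha_sigma(t)
    return [a * zi - sig * vi for zi, vi in zip(z, v)]


def eps_from_v(z: Sequence[float], v: Sequence[float], t: float) -> Vector:
    """Noise estimate under the same velocity convention."""
    a, sig = alpha_sigma(t)
    return [sig * zi + a * vi for zi, vi in zip(z, v)]


def select_time_steps(num_steps: int, offset: int, rng: random.Random) -> Tuple[float, float]:
    """Hypothetical helper: sample t on a grid, then s = t - offset (less noisy)."""
    assert 1 <= offset <= num_steps
    i = rng.randint(offset, num_steps)  # inclusive on both ends
    return i / num_steps, (i - offset) / num_steps


def toy_velocity_model(params: Params, z: Sequence[float], t: float,
                       cond: Sequence[float]) -> Vector:
    """Hypothetical stand-in for the conditional diffusion network."""
    return [params["w"] * zi + params["u"] * ci + params["b"] * t
            for zi, ci in zip(z, cond)]


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) / len(a)


def distillation_loss(online: Params, target: Params, x: Sequence[float],
                      cond: Sequence[float], t: float, s: float,
                      eps: Sequence[float], model: Model = toy_velocity_model,
                      x_weight: float = 1.0) -> float:
    """Noise-consistency term plus an image-reconstruction term (assumed MSE)."""
    assert 0.0 <= s < t <= 1.0, "s must be less noisy than t"
    z_t = add_noise(x, eps, t)
    v_t = model(online, z_t, t, cond)        # online values of the parameters
    x_hat_t = x_from_v(z_t, v_t, t)
    eps_hat_t = eps_from_v(z_t, v_t, t)
    z_s = add_noise(x_hat_t, eps, s)         # re-noise the estimate with the same eps
    v_s = model(target, z_s, s, cond)        # target values; stop-gradient in practice
    eps_hat_s = eps_from_v(z_s, v_s, s)
    return mse(eps_hat_s, eps_hat_t) + x_weight * mse(x, x_hat_t)


def ema_update(target: Params, online: Params, decay: float = 0.999) -> Params:
    """Target values maintained as an exponential moving average of online values."""
    return {k: decay * target[k] + (1.0 - decay) * online[k] for k in target}


if __name__ == "__main__":
    rng = random.Random(0)
    t, s = select_time_steps(num_steps=1000, offset=20, rng=rng)
    x, cond = [0.5, -1.0, 2.0], [1.0, 0.0, 1.0]
    eps = [rng.gauss(0.0, 1.0) for _ in x]
    online = {"w": 0.1, "u": 0.2, "b": 0.3}
    target = dict(online)  # target values initialized equal to online values
    print(distillation_loss(online, target, x, cond, t, s, eps))
    target = ema_update(target, online)  # applied after a gradient step on `online`
```

In an autodiff framework, the target branch would be wrapped in a stop-gradient, gradients of the loss would update only the online values, and the exponential-moving-average step would then move the target values toward the updated online values.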

Description

CONDITIONAL DIFFUSION NEURAL NETWORK DISTILLATION TRAINING

CROSS-REFERENCE TO RELATED APPLICATION

This application claims priority to U.S. Provisional Application No. 63/584,852, filed on September 22, 2023. The disclosure of the prior application is considered part of and is incorporated by reference in the disclosure of this application.

BACKGROUND

This specification relates to generating images using machine learning models. As one example, neural networks are machine learning models that employ one or more layers of nonlinear units to predict an output for a received input. Some neural networks include one or more hidden layers in addition to an output layer. The output of each hidden layer is used as input to another layer in the network, e.g., the next hidden layer or the output layer. Each layer of the network generates an output from a received input in accordance with current values of a respective set of weights.

SUMMARY

This specification describes a system implemented as one or more computer programs on one or more computers that trains a conditional diffusion neural network that can be used to generate an output image conditioned on a conditioning input.

Particular embodiments of the subject matter described in this specification can be implemented so as to realize one or more of the following advantages. Generative diffusion models provide strong priors for image generation, e.g., text-to-image generation, and serve as a foundation for conditional generation tasks such as image editing, restoration, super-resolution, and compositing. However, one major limitation of diffusion models is their slow sampling time. That is, during the reverse diffusion process, each sampling iteration requires generating a denoising output using the diffusion model, which is time-consuming and computationally intensive. Moreover, generating high-quality outputs using existing techniques can require a large number of sampling iterations.
To address this challenge, this specification describes a conditional distillation method for training a conditional diffusion neural network that, after training, allows for conditional sampling with very few steps. In other words, after being trained using the described techniques, the conditional diffusion neural network can be used to generate high-quality images in a small number of sampling iterations, significantly improving the computational efficiency of the generation process relative to other approaches. For example, the techniques can be used to directly distill a pre-trained diffusion neural network in a single stage through joint learning, greatly simplifying previous two-stage procedures that perform distillation and conditional fine-tuning separately. As a result, the described "single-stage" distillation approach, which both (i) decreases the number of sampling iterations required and (ii) adapts the neural network to a new type of conditioning input, is significantly more computationally efficient than approaches that separately "distill" and then "adapt" (or vice versa).

Furthermore, this specification describes a new parameter-efficient distillation mechanism that distills an additional conditional image generation task with only a small number of additional parameters combined with a shared, frozen unconditional backbone. The described approach outperforms existing distillation approaches across multiple tasks, including super-resolution, image editing, and conditional generation, in terms of output image quality at the same sampling time. Additionally, the described techniques result in trained models that can match or exceed the performance of much slower fine-tuned conditional diffusion models.

The details of one or more embodiments of the subject matter of this specification are set forth in the accompanying drawings and the description below.
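The parameter-efficient mechanism described above (see also claims 12-15) can be pictured as a layer that adds features from a few new, conditioning-specific parameters to features from frozen pretrained parameters, scaled by a scalar weight that starts at zero. The following is a minimal sketch assuming an elementwise toy layer; `GatedAdapterLayer` and its fields are hypothetical names, not taken from the source. Because the gate is initialized to zero, the layer initially reproduces the frozen backbone exactly, and only the adapter weights and the gate are exposed for training.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple, Union


@dataclass
class GatedAdapterLayer:
    """Hypothetical toy layer: frozen pretrained weights plus a small gated adapter."""

    frozen_w: Tuple[float, ...]  # pretrained values, held fixed during training
    adapter_w: List[float]       # new parameters with no pretrained counterpart
    gate: float = 0.0            # scalar weight on adapter features, initialized to zero

    def __call__(self, h: Sequence[float], cond: Sequence[float]) -> List[float]:
        base = [w * hi for w, hi in zip(self.frozen_w, h)]       # backbone features
        extra = [w * ci for w, ci in zip(self.adapter_w, cond)]  # conditioning features
        # Weighted combination: with gate == 0 the output equals the backbone output.
        return [b + self.gate * e for b, e in zip(base, extra)]

    def trainable_parameters(self) -> Dict[str, Union[float, List[float]]]:
        """Only the adapter weights and the gate are updated; frozen_w is excluded."""
        return {"adapter_w": self.adapter_w, "gate": self.gate}


if __name__ == "__main__":
    layer = GatedAdapterLayer(frozen_w=(1.0, 2.0), adapter_w=[0.5, -0.5])
    print(layer([3.0, 4.0], [1.0, 1.0]))  # prints [3.0, 8.0], same as the frozen backbone
```

Starting the gate at zero means training begins from the pretrained model's behavior, and the conditional branch is blended in only as the gate is learned.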
Other features, aspects, and advantages of the subject matter will become apparent from the description, the drawings, and the claims.

BRIEF DESCRIPTION OF THE DRAWINGS

FIG. 1 is a diagram of an example training system.
FIG. 2 is a flow diagram of an example process for training the diffusion neural network.
FIG. 3 is a flow diagram of an example process for performing a training iteration.
FIG. 4 shows an example of the operation of the system.
FIG. 5 shows an example of the performance of the described techniques.

Like reference numbers and designations in the various drawings indicate like elements.

DETAILED DESCRIPTION

FIG. 1 is a diagram of an example training system 100. The training system 100 is an example of a system implemented as computer programs on one or more computers in one or more locations, in which the systems, components, and techniques described below can be implemented. The system 100 trains a conditional diffusion neural network 110 that can be used to generate an output image 112 conditioned on a conditioning input 102. That is, the system 100 can generate an output image using the condi