US-12619880-B2 - Methods, devices and media for re-weighting to improve knowledge distillation

Abstract

Methods, devices and processor-readable media for re-weighting to improve knowledge distillation are described. A reweighting module may be used to determine relative weights to assign to a ground truth label and dark knowledge distilled from the teacher (i.e. the teacher output logits used as soft labels). A meta-reweighting method is described to optimize the weights for a given labeled data sample.
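As a rough illustration of the per-sample weighting described in the abstract, the sketch below computes a ground truth loss and a knowledge distillation loss for each sample and combines them with separate weights. It is a minimal sketch under stated assumptions, not the patented implementation: predictions are plain Python lists of logits, both losses are cross-entropies, the teacher logits are turned into soft labels with a softmax, and the function names (`softmax`, `ground_truth_loss`, `kd_loss`, `weighted_loss`) and the optional `temperature` parameter are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax; subtracting the max keeps exp() from overflowing."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def ground_truth_loss(student_logits: Sequence[float], label: int) -> float:
    """Cross-entropy between the student prediction and the hard ground truth label."""
    return -math.log(softmax(student_logits)[label])


def kd_loss(student_logits: Sequence[float], teacher_logits: Sequence[float],
            temperature: float = 1.0) -> float:
    """Cross-entropy between the teacher's soft labels and the student prediction,
    both softened with the same (hypothetical) temperature."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return -sum(pt * math.log(ps) for pt, ps in zip(p_teacher, p_student))


# One sample: (student logits, teacher logits, ground truth class index).
Sample = Tuple[Sequence[float], Sequence[float], int]


def weighted_loss(batch: Sequence[Sample], w_kd: Sequence[float],
                  w_gt: Sequence[float], temperature: float = 1.0) -> float:
    """Sum over the batch of w_kd[i] * KD loss + w_gt[i] * ground truth loss."""
    total = 0.0
    for i, (s_logits, t_logits, label) in enumerate(batch):
        total += w_kd[i] * kd_loss(s_logits, t_logits, temperature)
        total += w_gt[i] * ground_truth_loss(s_logits, label)
    return total
```

For example, `weighted_loss(batch, w_kd=[0.5, 0.5], w_gt=[0.5, 0.5])` gives both supervision signals equal weight for a two-sample batch, while setting every entry of `w_kd` to zero reduces the objective to ordinary supervised training. If a temperature above 1 is used, Hinton et al. suggest multiplying the soft-target loss by the square of the temperature so that its gradient magnitude stays comparable to that of the hard-target loss; that scaling is not applied here.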

Inventors

  • Peng Lu
  • Ahmad RASHID
  • Mehdi Rezagholizadeh
  • Abbas GHADDAR

Assignees

  • Peng Lu
  • Ahmad RASHID
  • Mehdi Rezagholizadeh
  • Abbas GHADDAR

Dates

Publication Date
2026-05-05
Application Date
2021-04-15

Claims (20)

  1. A method for knowledge distillation, comprising: obtaining a batch of training data comprising one or more labeled training data samples, each labeled training data sample having a respective ground truth label; processing the batch of training data, using a student model comprising a plurality of learnable parameters, to generate, for input data in each data sample in the batch of training data, a student prediction; for each labeled training data sample in the batch of training data, processing the student prediction and the ground truth label to compute a respective ground truth loss; processing the batch of training data, using a trained teacher model, to generate, for each labeled training data sample in the batch of training data, a teacher prediction; for each labeled data sample in the batch of training data, processing the student prediction and the teacher prediction to compute a respective knowledge distillation loss; determining a weighted loss that is a weighted sum of the knowledge distillation loss and ground truth loss for each labeled training data sample in the batch of training data, the determining including: determining a knowledge distillation weight for the labeled training data sample, the knowledge distillation weight being used to weight the knowledge distillation loss representing a difference between the student prediction and the teacher prediction for the labeled training data sample; determining a ground truth weight for the labeled training data sample, the ground truth weight being used to weight the ground truth loss representing a difference between the student prediction and the ground truth label for the labeled training data sample; and computing the weighted loss as the sum of: the knowledge distillation loss weighted by the knowledge distillation weight; and the ground truth loss weighted by the ground truth weight; performing gradient descent on the student model using the weighted loss to identify an adjusted set of values for the plurality of learnable parameters of the student; and adjusting the values of the plurality of learnable parameters of the student to the adjusted set of values.
  2. The method of claim 1, wherein the knowledge distillation weight and ground truth weight are determined based on user input.
  3. The method of claim 1, wherein the knowledge distillation weight and ground truth weight are determined by a meta reweighting process.
  4. The method of claim 3, wherein the meta reweighting process comprises: for each respective learnable parameter of the plurality of learnable parameters of the student, determining an optimized value of the respective learnable parameter as a function of a knowledge distillation perturbation variable and a ground truth perturbation variable with respect to the batch of training data; determining, for each labeled training data sample in the batch of training data, a respective estimated optimized value of the knowledge distillation perturbation variable and a respective estimated optimized value of the ground truth perturbation variable with respect to a batch of validation data; and for each labeled training data sample in the batch of training data, using the respective estimated optimized value of the knowledge distillation perturbation variable as the knowledge distillation weight and the respective estimated optimized value of the ground truth perturbation variable as the ground truth weight.
  5. The method of claim 4, wherein: determining the optimized value of each respective learnable parameter as a function of the knowledge distillation weight and the ground truth weight comprises: generating a meta student model having a plurality of learnable parameters with values equal to the values of the plurality of learnable parameters of the student model; processing the batch of training data, using the meta student model, to generate, for each labeled training data sample in the batch of training data, a meta student prediction; for each labeled training data sample in the batch of training data, processing the meta student prediction and the ground truth label to compute a respective meta ground truth loss; for each labeled training data sample in the batch of training data, processing the meta student prediction and the teacher prediction to compute a respective meta knowledge distillation loss; determining a perturbed loss as the sum of: the meta knowledge distillation loss weighted by the knowledge distillation perturbation variable; and the meta ground truth loss weighted by the ground truth perturbation variable; performing gradient descent on the meta student model using the perturbed loss to identify an optimal set of values of the plurality of learnable parameters of the meta student model, such that each value in the optimal set of values is defined as a function of the knowledge distillation perturbation variable and the ground truth perturbation variable; and adjusting the values of the plurality of learnable parameters of the meta student model to the optimal set of values based on a predetermined value of the knowledge distillation perturbation variable and a predetermined value of the ground truth perturbation variable, thereby generating an adjusted meta student model.
  6. The method of claim 5, wherein: the predetermined value of the knowledge distillation perturbation variable is zero; and the predetermined value of the ground truth perturbation variable is zero, and wherein adjusting the values of the plurality of learnable parameters of the meta student model comprises leaving the values of the plurality of learnable parameters of the meta student model unchanged.
  7. The method of claim 5, wherein: the batch of validation data comprises one or more labeled validation data samples, each labeled validation data sample having a respective ground truth label; and determining, for each data sample in the batch of training data, a respective estimated optimized value of the knowledge distillation perturbation variable and a respective estimated optimized value of the ground truth perturbation variable with respect to a second batch of data comprises: processing the batch of validation data, using the adjusted meta student model, to generate, for each labeled validation data sample in the batch of validation data, a meta student prediction; for each labeled validation data sample in the batch of validation data, processing the adjusted meta student prediction and the ground truth label to compute a respective meta ground truth loss; processing the batch of validation data, using the teacher model, to generate, for each labeled validation data sample in the batch of validation data, a teacher validation prediction; for each labeled validation data sample in the batch of validation data, processing the adjusted meta student prediction and the teacher validation prediction to compute a respective meta knowledge distillation loss; determining a validation loss based on the meta knowledge distillation loss and the meta ground truth loss for each labeled validation data sample in the batch of validation data; for each labeled training data sample in the batch of training data: computing a gradient of the validation loss with respect to the knowledge distillation perturbation variable and the ground truth perturbation variable by computing a gradient of the validation loss with respect to the optimal set of values of the plurality of learnable parameters of the meta student model, each learnable parameter value of the optimal set of values being defined as a function of the knowledge distillation perturbation variable and the ground truth perturbation variable; performing gradient descent to compute: the estimated optimized value of the knowledge distillation perturbation variable; and the estimated optimized value of the ground truth perturbation variable.
  8. The method of claim 7, further comprising: obtaining one or more additional batches of training data; obtaining, for each additional batch of training data, an additional batch of validation data; and repeating, for each additional batch of training data, generating the student predictions, computing the ground truth losses, generating the teacher predictions, computing the knowledge distillation losses, determining the weighted loss, identifying the adjusted set of values of the plurality of learnable parameters of the student, and adjusting the values of the plurality of learnable parameters of the student.
  9. The method of claim 1, wherein: the ground truth loss comprises a cross entropy loss; and the trained teacher model is trained to perform a natural language processing binary classification task on training data wherein each labeled training data sample comprises input data comprising a plurality of text tokens.
  10. A device, comprising: a processor; and a memory having stored thereon instructions which, when executed by the processor, cause the device to: obtain a batch of training data comprising one or more labeled training data samples, each labeled training data sample having a respective ground truth label; process the batch of training data, using a student model comprising a plurality of learnable parameters, to generate, for input data in each data sample in the batch of training data, a student prediction; for each labeled training data sample in the batch of training data, process the student prediction and the ground truth label to compute a respective ground truth loss; process the batch of training data, using a trained teacher model, to generate, for each labeled training data sample in the batch of training data, a teacher prediction; for each labeled data sample in the batch of training data, process the student prediction and the teacher prediction to compute a respective knowledge distillation loss; determine a weighted loss that is a weighted sum of the knowledge distillation loss and ground truth loss for each labeled training data sample in the batch of training data by: determining a knowledge distillation weight for the labeled training data sample, the knowledge distillation weight being used to weight the knowledge distillation loss representing a difference between the student prediction and the teacher prediction for the labeled training data sample; determining a ground truth weight for the labeled training data sample, the ground truth weight being used to weight the ground truth loss representing a difference between the student prediction and the ground truth label for the labeled training data sample; and computing the weighted loss as the sum of: the knowledge distillation loss weighted by the knowledge distillation weight; and the ground truth loss weighted by the ground truth weight; perform gradient descent on the student model using the weighted loss to identify an adjusted set of values for the plurality of learnable parameters of the student; and adjust the values of the plurality of learnable parameters of the student to the adjusted set of values.
  11. The device of claim 10, wherein the knowledge distillation weight and ground truth weight are determined by a meta reweighting process comprising: for each respective learnable parameter of the plurality of learnable parameters of the student, determining an optimized value of the respective learnable parameter as a function of a knowledge distillation perturbation variable and a ground truth perturbation variable with respect to the batch of training data; determining, for each labeled training data sample in the batch of training data, a respective estimated optimized value of the knowledge distillation perturbation variable and a respective estimated optimized value of the ground truth perturbation variable with respect to a batch of validation data; and for each labeled training data sample in the batch of training data, using the respective estimated optimized value of the knowledge distillation perturbation variable as the knowledge distillation weight and the respective estimated optimized value of the ground truth perturbation variable as the ground truth weight.
  12. The device of claim 11, wherein: determining the optimized value of each respective learnable parameter as a function of the knowledge distillation weight and the ground truth weight comprises: generating a meta student model having a plurality of learnable parameters with values equal to the values of the plurality of learnable parameters of the student model; processing the batch of training data, using the meta student model, to generate, for each labeled training data sample in the batch of training data, a meta student prediction; for each labeled training data sample in the batch of training data, processing the meta student prediction and the ground truth label to compute a respective meta ground truth loss; for each labeled training data sample in the batch of training data, processing the meta student prediction and the teacher prediction to compute a respective meta knowledge distillation loss; determining a perturbed loss as the sum of: the meta knowledge distillation loss weighted by the knowledge distillation perturbation variable; and the meta ground truth loss weighted by the ground truth perturbation variable; performing gradient descent on the meta student model using the perturbed loss to identify an optimal set of values of the plurality of learnable parameters of the meta student model, such that each value in the optimal set of values is defined as a function of the knowledge distillation perturbation variable and the ground truth perturbation variable; and adjusting the values of the plurality of learnable parameters of the meta student model to the optimal set of values based on a predetermined value of the knowledge distillation perturbation variable and a predetermined value of the ground truth perturbation variable, thereby generating an adjusted meta student model.
  13. The device of claim 12, wherein: the predetermined value of the knowledge distillation perturbation variable is zero; and the predetermined value of the ground truth perturbation variable is zero, and wherein adjusting the values of the plurality of learnable parameters of the meta student model comprises leaving the values of the plurality of learnable parameters of the meta student model unchanged.
  14. The device of claim 12, wherein: the batch of validation data comprises one or more labeled validation data samples, each labeled validation data sample having a respective ground truth label; and determining, for each data sample in the batch of training data, a respective estimated optimized value of the knowledge distillation perturbation variable and a respective estimated optimized value of the ground truth perturbation variable with respect to a second batch of data comprises: processing the batch of validation data, using the adjusted meta student model, to generate, for each labeled validation data sample in the batch of validation data, a meta student prediction; for each labeled validation data sample in the batch of validation data, processing the adjusted meta student prediction and the ground truth label to compute a respective meta ground truth loss; processing the batch of validation data, using the teacher model, to generate, for each labeled validation data sample in the batch of validation data, a teacher validation prediction; for each labeled validation data sample in the batch of validation data, processing the adjusted meta student prediction and the teacher validation prediction to compute a respective meta knowledge distillation loss; determining a validation loss based on the meta knowledge distillation loss and the meta ground truth loss for each labeled validation data sample in the batch of validation data; for each labeled training data sample in the batch of training data: computing a gradient of the validation loss with respect to the knowledge distillation perturbation variable and the ground truth perturbation variable by computing a gradient of the validation loss with respect to the optimal set of values of the plurality of learnable parameters of the meta student model, each learnable parameter value of the optimal set of values being defined as a function of the knowledge distillation perturbation variable and the ground truth perturbation variable; performing gradient descent to compute: the estimated optimized value of the knowledge distillation perturbation variable; and the estimated optimized value of the ground truth perturbation variable.
  15. The device of claim 14, wherein the instructions, when executed by the processor, further cause the device to: obtain one or more additional batches of training data; obtain, for each additional batch of training data, an additional batch of validation data; and repeat, for each additional batch of training data, generating the student predictions, computing the ground truth losses, generating the teacher predictions, computing the knowledge distillation losses, determining the weighted loss, identifying the adjusted set of values of the plurality of learnable parameters of the student, and adjusting the values of the plurality of learnable parameters of the student.
  16. The device of claim 10, wherein: the ground truth loss comprises a cross entropy loss; and the trained teacher model is trained to perform a natural language processing binary classification task on training data wherein each labeled training data sample comprises input data comprising a plurality of text tokens.
  17. A non-transitory processor-readable medium containing instructions which, when executed by a processor of a device, cause the device to: obtain a batch of training data comprising one or more labeled training data samples, each labeled training data sample having a respective ground truth label; process the batch of training data, using a student model comprising a plurality of learnable parameters, to generate, for input data in each data sample in the batch of training data, a student prediction; for each labeled training data sample in the batch of training data, process the student prediction and the ground truth label to compute a respective ground truth loss; process the batch of training data, using a trained teacher model, to generate, for each labeled training data sample in the batch of training data, a teacher prediction; for each labeled data sample in the batch of training data, process the student prediction and the teacher prediction to compute a respective knowledge distillation loss; determine a weighted loss that is a weighted sum of the knowledge distillation loss and ground truth loss for each labeled training data sample in the batch of training data by: determining a knowledge distillation weight for the labeled training data sample, the knowledge distillation weight being used to weight the knowledge distillation loss representing a difference between the student prediction and the teacher prediction for the labeled training data sample; determining a ground truth weight for the labeled training data sample, the ground truth weight being used to weight the ground truth loss representing a difference between the student prediction and the ground truth label for the labeled training data sample; and computing the weighted loss as the sum of: the knowledge distillation loss weighted by the knowledge distillation weight; and the ground truth loss weighted by the ground truth weight; perform gradient descent on the student model using the weighted loss to identify an adjusted set of values for the plurality of learnable parameters of the student; and adjust the values of the plurality of learnable parameters of the student to the adjusted set of values.
  18. The non-transitory processor-readable medium of claim 17, wherein the knowledge distillation weight and ground truth weight are determined based on user input.
  19. The non-transitory processor-readable medium of claim 17, wherein the knowledge distillation weight and ground truth weight are determined by a meta reweighting process.
  20. The non-transitory processor-readable medium of claim 19, wherein the meta reweighting process comprises: for each respective learnable parameter of the plurality of learnable parameters of the student, determining an optimized value of the respective learnable parameter as a function of a knowledge distillation perturbation variable and a ground truth perturbation variable with respect to the batch of training data; determining, for each labeled training data sample in the batch of training data, a respective estimated optimized value of the knowledge distillation perturbation variable and a respective estimated optimized value of the ground truth perturbation variable with respect to a batch of validation data; and for each labeled training data sample in the batch of training data, using the respective estimated optimized value of the knowledge distillation perturbation variable as the knowledge distillation weight and the respective estimated optimized value of the ground truth perturbation variable as the ground truth weight.
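The meta reweighting process recited in claims 4 to 7 can be approximated with a one-step lookahead, in the spirit of the learning-to-reweight approach of Ren et al. (2018). The sketch below is a hedged illustration that rests on several assumptions not found in the claims. The student is a binary logistic-regression model, so gradients can be written in closed form. The teacher supplies a probability for the positive class. The validation loss is the unweighted sum of the ground truth and knowledge distillation losses. The perturbation variables start at zero, so, as in claim 6, the meta student's parameters are initially unchanged. Finally, the resulting values are clipped at zero and normalized, a choice borrowed from learning-to-reweight rather than recited in the claims. All function names and the step sizes `alpha`, `beta` and `lr` are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

# A sample is (features, hard label in {0, 1}, teacher probability of class 1).
Sample = Tuple[Sequence[float], int, float]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def predict(theta: Sequence[float], x: Sequence[float]) -> float:
    """Student probability of class 1 (logistic regression)."""
    return sigmoid(dot(theta, x))


def bce_grad(theta: Sequence[float], x: Sequence[float], target: float) -> List[float]:
    """Gradient w.r.t. theta of binary cross-entropy against a hard or soft target.

    For logistic regression this is (p - target) * x.
    """
    p = predict(theta, x)
    return [(p - target) * xi for xi in x]


def meta_reweight(theta: Sequence[float], train: List[Sample], val: List[Sample],
                  alpha: float = 0.1, beta: float = 1.0):
    """Return per-sample (kd_weights, gt_weights) for one training batch."""
    # Per-sample gradients of the KD loss (soft target q) and GT loss (label y).
    g_kd = [bce_grad(theta, x, q) for x, _, q in train]
    g_gt = [bce_grad(theta, x, float(y)) for x, y, _ in train]

    # Validation gradient: GT loss + KD loss, summed over the validation batch.
    # With every perturbation variable at zero the lookahead step leaves theta
    # unchanged, so this gradient is evaluated at theta itself.
    g_val = [0.0] * len(theta)
    for x, y, q in val:
        for k, (a, b) in enumerate(zip(bce_grad(theta, x, float(y)),
                                       bce_grad(theta, x, q))):
            g_val[k] += a + b

    # Lookahead: theta_hat = theta - alpha * sum_i(eps_kd_i * g_kd_i + eps_gt_i * g_gt_i),
    # so dL_val/d eps_i at eps = 0 is -alpha * <g_val, g_i>. One gradient-descent
    # step of size beta from eps = 0 therefore gives eps_i = beta * alpha * <g_val, g_i>.
    eps_kd = [beta * alpha * dot(g_val, g) for g in g_kd]
    eps_gt = [beta * alpha * dot(g_val, g) for g in g_gt]

    # Assumption (learning-to-reweight style): clip negatives, then normalize.
    w_kd = [max(e, 0.0) for e in eps_kd]
    w_gt = [max(e, 0.0) for e in eps_gt]
    total = sum(w_kd) + sum(w_gt)
    if total > 0.0:
        w_kd = [w / total for w in w_kd]
        w_gt = [w / total for w in w_gt]
    return w_kd, w_gt


def train_step(theta: Sequence[float], train: List[Sample], val: List[Sample],
               lr: float = 0.5) -> List[float]:
    """One student update using the meta-reweighted loss for this batch."""
    w_kd, w_gt = meta_reweight(theta, train, val)
    grad = [0.0] * len(theta)
    for i, (x, y, q) in enumerate(train):
        for k, (a, b) in enumerate(zip(bce_grad(theta, x, q),
                                       bce_grad(theta, x, float(y)))):
            grad[k] += w_kd[i] * a + w_gt[i] * b
    return [t - lr * g for t, g in zip(theta, grad)]
```

Calling `train_step` once per batch, with a fresh validation batch each time, mirrors the repetition over additional batches recited in claim 8. A training sample whose loss gradient points in the same direction as the validation gradient receives a positive weight, while one whose gradient opposes it is clipped to zero.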

Description

TECHNICAL FIELD

The present disclosure generally relates to knowledge distillation, and in particular to methods, devices and processor-readable media for re-weighting to improve knowledge distillation.

BACKGROUND

Machine Learning (ML) is an artificial intelligence technique in which algorithms are used to build a model from sample data that is capable of being applied to input data to perform a specific inference task (i.e., making predictions or decisions based on new data) without being explicitly programmed to perform the specific inference task. Deep learning is one of the most successful and widely deployed machine learning techniques. In deep learning, artificial neural networks typically consist of layers of non-linear parametric functions or “neurons”. To train a neural network using supervised learning, data samples are received by an input layer of the network and are processed by the neurons of the network to generate an output, such as inference data, at an output layer of the network. This is called forward propagation. The output of the network is compared to semantic information associated with the data samples, such as semantic labels indicating a ground truth that can be compared to the inference data generated by the network. Training the neural network involves optimizing the learnable parameters of the neurons, typically using gradient-based optimization algorithms, to minimize a loss function. The gradients of the loss function with respect to the learnable parameters are computed by propagating the error backward through the network, a process called backpropagation.

A particular configuration or architecture of an artificial neural network, or simply neural network (NN), is commonly referred to as a neural network model. Large neural network models trained using supervised learning can have upwards of billions of parameters and are therefore cumbersome to deploy, because storing the parameters and running these neural network models on resource-constrained computing devices (i.e. computing devices having limited computing resources), such as mobile devices, embedded devices, or edge devices, is often infeasible. Therefore, a variety of model compression and acceleration techniques have been developed to reduce either the number of parameters of a neural network model, or the memory required to store each parameter, prior to deployment on computing devices with limited computing resources.

Knowledge Distillation (KD) is a compression technique used to transfer the knowledge of a large trained neural network model (i.e. a neural network model with many learned parameters) to a smaller neural network model (i.e. a neural network model with fewer learned parameters than the large trained neural network model). KD exploits the generalization ability of the larger trained neural network model (referred to as the “teacher model” or “teacher”) by using the inference data it outputs as “soft targets”, which serve as a supervision signal for training a smaller neural network model (called the “student model” or “student”). This technique stands in contrast to conventional supervised training of a neural network model, in which “hard targets” corresponding to the ground truth reflected in labelled training data of a training dataset are used as the sole supervision signal to train the neural network model. In KD, the student receives both soft targets and hard targets as supervision signals. This can allow the student to achieve better performance when trained on the same dataset as the teacher, because high-entropy soft targets carry more information per training sample than hard targets and produce less variance in the gradient between training samples, which tends to improve generalization. A generalized technique for knowledge distillation is described by Geoffrey Hinton, Oriol Vinyals and Jeff Dean in Distilling the Knowledge in a Neural Network, https://arxiv.org/abs/1503.02531.

FIG. 1 shows a typical configuration 10 for conventional KD. A teacher 20 is used to train a student 30.
The teacher 20 receives input data 22 from a dataset used to train the student 30 and generates teacher inference data 24 based on the input data 22. The teacher inference data 24 is used as a soft target for supervision of the student 30, and the student 30 may be trained at least in part using a knowledge distillation loss function based on a comparison of the teacher inference data 24 to the student inference data 34 generated from the same input data 22 provided to the teacher 20. The teacher 20 typically has high prediction accuracy for a given inference task, or scores highly on some other relevant metric, but is too computationally intensive for practical deployment to computing devices with limited computing resources. The student 30, on the other hand, may reasonably be deployed to a computing device with limited computing resources (e.g. memory and/or processing power). The teacher 20 is pre-trained, and the student 30 is generally trained using the same training dataset used to train the teacher 20; however, this pre-training of the teacher 20 is not always necessary. KD proposes training the studen