US-12619881-B2 - Method and system for training a neural network model using adversarial learning and knowledge distillation
Abstract
Method and system of training a student neural network using adversarial learning and knowledge distillation, including: training a generator to generate adversarial data samples for respective training data samples by masking parts of the training data samples with an objective of maximizing a divergence between output predictions generated by the student neural network and a teacher neural network model for the adversarial data samples; and training the student neural network based on objectives of (i) minimizing a divergence between output predictions generated by the student neural network and the teacher neural network model for the adversarial data samples, and (ii) minimizing a divergence between output predictions generated by the student neural network and the teacher neural network model for the training data samples.
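The two competing objectives in the abstract can be expressed as loss functions over per-sample output distributions. This is a minimal, framework-free sketch, not the patented implementation. It rests on several assumptions: the divergence is taken as KL(teacher ‖ student) (the claims name a KL divergence but not its direction), predictions are lists of class probabilities, and the student's two terms are weighted equally. The function names are hypothetical.

```python
import math

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as equal-length lists."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)

def mean_kl(teacher_preds, student_preds):
    """Average KL(teacher || student) over a batch of samples."""
    kls = [kl_divergence(t, s) for t, s in zip(teacher_preds, student_preds)]
    return sum(kls) / len(kls)

def generator_loss(teacher_adv, student_adv):
    """The generator wants to MAXIMIZE teacher/student divergence on the
    adversarial samples, so its loss (to be minimized) is the negated divergence."""
    return -mean_kl(teacher_adv, student_adv)

def student_loss(teacher_adv, student_adv, teacher_orig, student_orig):
    """The student MINIMIZES divergence on (i) adversarial and (ii) original samples.
    Equal weighting of the two terms is an assumption of this sketch."""
    return mean_kl(teacher_adv, student_adv) + mean_kl(teacher_orig, student_orig)

# Usage: one sample with three output classes
teacher = [[0.7, 0.2, 0.1]]
student = [[0.4, 0.4, 0.2]]
print(generator_loss(teacher, student))  # negative: the generator is rewarded for disagreement
print(student_loss(teacher, student, teacher, teacher))  # only the adversarial term is non-zero
```

In a real training loop the two losses would be minimized alternately, each with respect to its own model's parameters.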
Inventors
- Vasileios LIOUTAS
- Ahmad RASHID
- Mehdi Rezagholizadeh
Assignees
- Vasileios LIOUTAS
- Ahmad RASHID
- Mehdi Rezagholizadeh
Dates
- Publication Date: 2026-05-05
- Application Date: 2023-03-08
Claims (20)
- 1 . A method of training a student neural network model using adversarial learning and knowledge distillation comprising: training a generator to generate a respective adversarial data sample for each of a plurality of input data samples by replacing selected parts of the input data samples, the training of the generator being based on an objective of maximizing divergence between output predictions generated by the student neural network and a teacher neural network model for the adversarial data samples; and training the student neural network model based on objectives of: (i) minimizing divergence between output predictions generated by the student neural network model and the teacher neural network model for the adversarial data samples generated by the generator, and (ii) minimizing divergence between output predictions generated by the student neural network model and the teacher neural network model for the input data samples.
- 2 . The method of claim 1 wherein training the generator comprises: for each of the input data samples, randomly selecting and masking the selected parts of the input data sample and generating, using a generator neural network model, the respective adversarial data sample with replacement data for the selected parts of the masked input data; obtaining output predictions for the respective adversarial data samples from the teacher neural network model and the student neural network model; computing a generator loss using a generator loss function that is minimized when divergence between the output predictions generated by the student neural network and a teacher neural network model for the respective adversarial data samples is maximized; and updating parameters of the generator neural network model using gradient descent based on the computed loss.
- 3 . The method of claim 2 wherein randomly selecting the selected parts of the input data sample is performed by randomly determining, based on a defined probability, for each part of an input data sample, if the part is to be selected as one of the selected parts.
- 4 . The method of claim 2 wherein generating the replacement data for the selected parts comprises sampling a Gumbel softmax distribution of logits generated by the generator neural network model.
- 5 . The method of claim 4 wherein the student neural network model is trained to perform a natural language processing task and each of the input data samples is a text data sample comprising a set of tokens that each correspond to a discrete text element of the text data sample, wherein the selected parts of the input data samples correspond to individual tokens.
- 6 . The method of claim 1 wherein the student neural network model is trained to perform an image processing task and each of the input data samples is an image data sample comprising a set of pixels, wherein the selected parts of the input data samples correspond to pixels.
- 7 . The method of claim 1 wherein divergence between the output predictions generated by the student neural network and a teacher neural network model for the respective adversarial data samples corresponds to a Kullback-Leibler (KL) divergence.
- 8 . The method of claim 1 wherein the divergence between the output predictions generated by the student neural network and a teacher neural network model for the input data samples corresponds to a Kullback-Leibler (KL) divergence.
- 9 . The method of claim 1 wherein training the student neural network model is also based on an objective of minimizing divergence between output predictions generated by the student neural network for the input data samples and ground truth labels for the input data samples.
- 10 . The method of claim 9 wherein the divergence between the output predictions generated by the student neural network and the ground truth labels for the input data samples corresponds to a Cross-Entropy loss.
- 11 . A system for training a student neural network model using adversarial learning and knowledge distillation, the system comprising one or more processors and a non-transitory storage medium storing software instructions that, when executed by the one or more processors, configure the system to perform a method comprising: training a generator to generate a respective adversarial data sample for each of a plurality of input data samples by replacing selected parts of the input data samples, the training of the generator being based on an objective of maximizing divergence between output predictions generated by the student neural network and a teacher neural network model for the adversarial data samples; and training the student neural network model based on objectives of: (i) minimizing divergence between output predictions generated by the student neural network model and the teacher neural network model for the adversarial data samples generated by the generator, and (ii) minimizing divergence between output predictions generated by the student neural network model and the teacher neural network model for the input data samples.
- 12 . The system of claim 11 wherein training the generator comprises: for each of the input data samples, randomly selecting and masking the selected parts of the input data sample and generating, using a generator neural network model, the respective adversarial data sample with replacement data for the selected parts of the masked input data; obtaining output predictions for the respective adversarial data samples from the teacher neural network model and the student neural network model; computing a generator loss using a generator loss function that is minimized when divergence between the output predictions generated by the student neural network and a teacher neural network model for the respective adversarial data samples is maximized; and updating parameters of the generator neural network model using gradient descent based on the computed loss.
- 13 . The system of claim 12 wherein randomly selecting the selected parts of the input data sample is performed by randomly determining, based on a defined probability, for each part of an input data sample, if the part is to be selected as one of the selected parts.
- 14 . The system of claim 12 wherein generating the replacement data for the selected parts comprises sampling a Gumbel softmax distribution of logits generated by the generator neural network model.
- 15 . The system of claim 14 wherein the student neural network model is trained to perform a natural language processing task and each of the input data samples is a text data sample comprising a set of tokens that each correspond to a discrete text element of the text data sample, wherein the selected parts of the input data samples correspond to individual tokens.
- 16 . The system of claim 11 wherein the student neural network model is trained to perform an image processing task and each of the input data samples is an image data sample comprising a set of pixels, wherein the selected parts of the input data samples correspond to pixels.
- 17 . The system of claim 11 wherein divergence between the output predictions generated by the student neural network and a teacher neural network model for the respective adversarial data samples corresponds to a Kullback-Leibler (KL) divergence.
- 18 . The system of claim 11 wherein the divergence between the output predictions generated by the student neural network and a teacher neural network model for the input data samples corresponds to a Kullback-Leibler (KL) divergence.
- 19 . The system of claim 11 wherein training the student neural network model is also based on an objective of minimizing divergence between output predictions generated by the student neural network for the input data samples and ground truth labels for the input data samples; and the divergence between the output predictions generated by the student neural network and the ground truth labels for the input data samples corresponds to a Cross-Entropy loss.
- 20 . A non-transitory computer readable medium storing software instructions that, when executed by one or more processors, configure the one or more processors to perform a method of training a student neural network model using adversarial learning and knowledge distillation, comprising: training a generator to generate a respective adversarial data sample for each of a plurality of input data samples by replacing selected parts of the input data samples, the training of the generator being based on an objective of maximizing divergence between output predictions generated by the student neural network and a teacher neural network model for the adversarial data samples; and training the student neural network model based on objectives of: (i) minimizing divergence between output predictions generated by the student neural network model and the teacher neural network model for the adversarial data samples generated by the generator, and (ii) minimizing divergence between output predictions generated by the student neural network model and the teacher neural network model for the input data samples.
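The generator step of claims 2 to 5 can be sketched for text data. Each token is independently selected with a defined probability `p` and masked, and each masked position is then filled by sampling a Gumbel-softmax distribution over the generator's logits. This is a minimal sketch under stated assumptions. The `[MASK]` token, the toy `generator_logits` callable, the vocabulary, and the values of `p` and the Gumbel-softmax temperature `tau` are hypothetical. The gradient-descent update of the generator, which needs an autodiff framework, is omitted. Choosing the most probable token of the relaxed sample is also an assumption, similar to the forward pass of a straight-through estimator.

```python
import math
import random

MASK = "[MASK]"  # hypothetical mask token

def select_and_mask(tokens, p, rng):
    """Independently select each token with probability p (claim 3) and mask it."""
    selected = [rng.random() < p for _ in tokens]
    masked = [MASK if sel else tok for tok, sel in zip(tokens, selected)]
    return masked, selected

def gumbel_softmax(logits, tau, rng):
    """Relaxed one-hot sample: softmax((logits + Gumbel(0, 1) noise) / tau)."""
    scaled = []
    for logit in logits:
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
        gumbel = -math.log(-math.log(u))
        scaled.append((logit + gumbel) / tau)
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def make_adversarial(tokens, vocab, generator_logits, p=0.3, tau=1.0, seed=0):
    """Mask randomly selected tokens, then replace each one with a token drawn
    via Gumbel-softmax from the (hypothetical) generator's logits."""
    rng = random.Random(seed)
    masked, selected = select_and_mask(tokens, p, rng)
    adversarial = list(masked)
    for i, sel in enumerate(selected):
        if sel:
            probs = gumbel_softmax(generator_logits(masked, i), tau, rng)
            # Assumption: take the most probable token of the relaxed sample.
            best = max(range(len(vocab)), key=lambda k: probs[k])
            adversarial[i] = vocab[best]
    return adversarial, selected

# Usage with a toy "generator" that returns uniform logits over the vocabulary
vocab = ["good", "bad", "fine", "awful"]
uniform = lambda masked, i: [0.0] * len(vocab)
print(make_adversarial(["the", "movie", "was", "great"], vocab, uniform, p=0.5, seed=1))
```

With a trained generator, the logits would depend on the masked context, so replacements are chosen to push the student's predictions away from the teacher's.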
Description
RELATED APPLICATION DATA
The present application is a continuation of International Patent Application No. PCT/CA2021/051249, filed Sep. 9, 2021, the content of which is incorporated herein by reference, which claims priority to, and the benefit of, provisional U.S. patent application No. 63/076,374, filed Sep. 9, 2020, the content of which is incorporated herein by reference.
FIELD
The present application relates to methods and systems for training machine learning models and, in particular, to methods and systems for training a neural network model using adversarial learning and knowledge distillation.
BACKGROUND
Deep learning based algorithms are machine learning methods used in many applications in the natural language processing (NLP) and computer vision (CV) fields. Deep learning consists of composing layers of non-linear parametric functions, or “neurons”, and training their parameters, or “weights”, typically with gradient-based optimization algorithms, to minimize a loss function. One key reason for the success of these methods is their ability to improve performance as the number of parameters and the amount of data increase. In NLP, this has led to deep learning architectures with billions of parameters (Brown et al., 2020). Research has also shown that large architectures, or “models”, are easier to optimize. Because such large models demand substantial memory and computation, model compression is imperative for practical applications such as deploying a trained machine learning model on a phone for a personal assistant. Knowledge distillation (KD) is a neural network compression technique whereby the generalizations of a complex neural network model are transferred to a less complex neural network model that can make inferences (i.e., predictions) comparable to those of the complex model at a lower computing resource cost and in less time.
Here, a complex neural network model refers to a neural network model that requires a relatively large amount of computing resources, such as GPU/CPU power and computer memory space, and/or that includes a relatively high number of hidden layers. For the purposes of KD, the complex neural network model is sometimes referred to as a teacher neural network model (T), or a teacher for short. A typical drawback of the teacher is that it may require significant computing resources that may not be available in consumer electronic devices, such as mobile communication devices or edge computing devices. Furthermore, because of its complexity, the teacher neural network model typically requires a significant amount of time to infer (i.e., predict) an output for a given input, and hence may not be suitable for deployment to a consumer computing device. Thus, KD techniques are applied to extract, or distill, the learned parameters, or knowledge, of a teacher neural network model and impart this knowledge to a less sophisticated neural network model that has a faster inference time and lower computing resource and memory costs, and that may be deployed with less effort on consumer computing devices, such as edge devices. The less complex neural network model is often referred to as the student neural network model (S), or a student for short. KD techniques train the student using not only the labeled training data samples of the training dataset but also the outputs generated by the teacher neural network model, known as logits.
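Training the student on both the labeled data and the teacher's logits is commonly written as a weighted sum of two cross-entropy terms. One term uses the hard labels. The other uses soft targets produced by a temperature-scaled softmax of the teacher's logits. The following is a minimal sketch under that assumption. The function names and the default values of `alpha` and `tau` are illustrative, and logits are plain Python lists for a single sample.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; higher temperatures give softer distributions."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, predicted, eps=1e-12):
    """H(target, predicted) = -sum_i target_i * log(predicted_i)."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, predicted))

def kd_loss(label, student_logits, teacher_logits, alpha=0.5, tau=2.0):
    """alpha * hard-label CE (T = 1) + (1 - alpha) * soft-target CE (T = tau)."""
    y = [1.0 if k == label else 0.0 for k in range(len(student_logits))]
    hard = cross_entropy(y, softmax(student_logits, 1.0))
    soft = cross_entropy(softmax(teacher_logits, tau), softmax(student_logits, tau))
    return alpha * hard + (1.0 - alpha) * soft

# Usage: a student that agrees with the teacher scores a lower loss
teacher = [3.0, 1.0, 0.2]
print(kd_loss(0, [2.5, 1.2, 0.1], teacher))
print(kd_loss(0, [0.1, 1.2, 2.5], teacher))
```

Only the teacher's outputs are needed, which is why this kind of loss is independent of the teacher's and student's internal architectures.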
An example of a KD loss function used for training a student neural network model is as follows:
$$\mathcal{L}_{KD} = \alpha \, H\big(y, \sigma(z_s; T=1)\big) + (1-\alpha) \, H\big(\sigma(z_t; T=\tau), \sigma(z_s; T=\tau)\big) \tag{1}$$
where $H$ is the cross-entropy loss function (other loss functions may also be used), $\sigma$ is the softmax function, $y$ is the ground-truth label, $T$ is a temperature parameter, $\alpha$ is a hyperparameter that controls the relative contributions of the cross-entropy loss and the KD loss, and $z_t$ and $z_s$ are the logits (i.e., the outputs of the neural network before the final softmax layer) of the teacher neural network model (T) and the student neural network model (S), respectively. KD techniques are widely used because they are agnostic to the architectures of the teacher and student neural network models and require access only to the outputs generated by the teacher neural network model in order to train the student neural network model to effectively imitate the behavior of the teacher. Still, for many applications there is a significant gap between the performance of the teacher neural network model and that of the student neural network model, and various KD techniques have been proposed to reduce this gap. For example, in the NLP field, KD techniques have been proposed whereby knowledge transfer can be effected by learning the parameters of the teacher and student neural network models in stages and freezing the parameters, or by defining intermediate neural networks which learn from the teacher neural network model and teach the student neural n