US-12619874-B2 - Stochastic gradient boosting for deep neural networks

Abstract

Aspects described herein may allow for the application of stochastic gradient boosting techniques to the training of deep neural networks by disallowing gradient back propagation from examples that are correctly classified by the neural network model while still keeping correctly classified examples in the gradient averaging. Removing the gradient contribution from correctly classified examples may regularize the deep neural network and prevent the model from overfitting. Further aspects described herein may provide for scheduled boosting during the training of the deep neural network model conditioned on a mini-batch accuracy and/or a number of training iterations. The model training process may start un-boosted, using maximum likelihood objectives or another first loss function. Once a threshold mini-batch accuracy and/or number of iterations is reached, the model training process may begin using boosting by disallowing gradient back propagation from correctly classified examples while continuing to average over all mini-batch examples.
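As a minimal sketch of the core mechanism described in the abstract, assuming a softmax classifier trained with negative log-likelihood and plain Python lists instead of a deep learning framework, the function below zeroes the logit gradients of correctly classified examples but still divides by the full mini-batch size. The names `softmax` and `boosted_logit_gradients` are hypothetical and are not taken from the patent.

```python
import math


def softmax(logits):
    # Subtract the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def boosted_logit_gradients(batch_logits, labels):
    """Per-example gradients of the boosted mini-batch NLL w.r.t. the logits.

    Correctly classified examples get weight 0 (no back-propagated gradient),
    but the average still divides by the full mini-batch size n.
    """
    n = len(batch_logits)
    grads = []
    for logits, y in zip(batch_logits, labels):
        probs = softmax(logits)
        predicted = max(range(len(probs)), key=probs.__getitem__)
        if predicted == y:
            grads.append([0.0] * len(probs))
        else:
            # d(-log p_y)/d(logit_k) = p_k - [k == y], divided by n rather
            # than by the number of misclassified examples.
            grads.append([(p - (1.0 if k == y else 0.0)) / n
                          for k, p in enumerate(probs)])
    return grads
```

For the batch `[[2.0, 0.0], [0.0, 1.0], [1.0, 3.0]]` with labels `[0, 0, 1]`, only the second example is misclassified, so only its row is nonzero (about `[-0.2437, 0.2437]`), and the divisor is 3 rather than 1. One way to read the design choice: because the denominator stays at n, the size of each update shrinks as the mini-batch accuracy rises.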

Inventors

  • Oluwatobi Olabiyi
  • Erik T. Mueller
  • Christopher Larson

Assignees

  • CAPITAL ONE SERVICES, LLC

Dates

Publication Date
2026-05-05
Application Date
2024-02-15

Claims (17)

  1. A computer-implemented method comprising: initializing a deep neural network model to include an input layer, an output layer, a plurality of hidden layers, and a plurality of model parameters; training, based on a training set that comprises a plurality of examples, the deep neural network model by: performing a first plurality of iterations, wherein each respective iteration of the first plurality of iterations updates the plurality of model parameters based on a first loss function for each example in a respective mini-batch from the plurality of examples; determining that a burn-in threshold is satisfied based on a first number of training iterations performed or based on an accuracy of the deep neural network model; based on determining that the burn-in threshold is satisfied, performing a second plurality of iterations, wherein each respective iteration of the second plurality of iterations updates the plurality of model parameters based on minimizing an average gradient of a second loss function for each example in a respective mini-batch from the plurality of examples, wherein the second loss function is a weighted negative log-likelihood of the form: $$\mathcal{L}_{DB}(\theta_t, \mathcal{D}_t) = \frac{1}{|\mathcal{D}_t|} \sum_{(x, y^*) \in \mathcal{D}_t} -\left(1 - \lambda(\hat{y}, y^*)\right) \log p_\theta(y^* \mid x)$$ wherein $\theta_t$ corresponds to the plurality of model parameters, $\mathcal{D}_t$ corresponds to the respective mini-batch, $x$ and $y^*$ are inputs and outputs of examples in $\mathcal{D}_t$, $p_\theta(y^* \mid x)$ is a conditional probability of output $y^*$ given $x$ based on the plurality of model parameters, and $\lambda(\hat{y}, y^*)$ corresponds to a weighting factor that is based on a similarity between predicted output $\hat{y}$ and ground truth $y^*$, and wherein the weighted negative log-likelihood is weighted such that correctly classified examples are given zero weight, and the average gradient of the weighted negative log-likelihood is determined based on a total size of the respective mini-batch used in a given iteration; after the training, determining that the model, based on the plurality of model parameters, satisfies one or more stopping criteria; and after determining that the model satisfies one or more stopping criteria, generating, based on an input data set, one or more predictions using the model.
  2. The computer-implemented method of claim 1, wherein the one or more stopping criteria comprises a maximum number of training iterations.
  3. The computer-implemented method of claim 1, wherein the one or more stopping criteria comprises a threshold accuracy of the deep neural network model for a validation training set.
  4. The method of claim 1, wherein the one or more stopping criteria comprises a threshold accuracy of the deep neural network model for a mini-batch.
  5. The computer-implemented method of claim 1, wherein the one or more stopping criteria is based on determining that a given iteration did not result in updated model parameters due to each example being correctly classified.
  6. The computer-implemented method of claim 1, wherein the deep neural network model is configured to generate predictions regarding speech recognition.
  7. A system comprising: a database configured to store a training set that comprises a plurality of examples, wherein each example, of the plurality of examples, comprises an input, a ground truth output, and is associated with a conditional probability of the ground truth output given the input; one or more processors; and memory storing instructions that, when executed by the one or more processors, cause the one or more processors to: initialize a deep neural network model to include an input layer, an output layer, a plurality of hidden layers, and a plurality of model parameters; train, based on a training set that comprises a plurality of examples, the deep neural network model by causing the one or more processors to: perform a first plurality of iterations, wherein each respective iteration of the first plurality of iterations updates the plurality of model parameters based on a first loss function for each example in a respective mini-batch from the plurality of examples; determine that a burn-in threshold is satisfied based on a first number of training iterations performed or based on an accuracy of the deep neural network model; based on determining that the burn-in threshold is satisfied, perform a second plurality of iterations, wherein each respective iteration of the second plurality of iterations updates the plurality of model parameters based on minimizing an average gradient of a second loss function for each example in a respective mini-batch from the plurality of examples, wherein the second loss function is a weighted negative log-likelihood that is based on the conditional probability of an output given an input based on the plurality of model parameters and a weighting factor, and wherein the weighting factor is based on a similarity between predicted output and ground truth, the weighted negative log-likelihood is weighted such that correctly classified examples are given zero weight, and the average gradient of the weighted negative log-likelihood is determined based on a total size of the respective mini-batch used in a given iteration; after the training, determine that the model, based on the plurality of model parameters, satisfies one or more stopping criteria; and after determining that the model satisfies one or more stopping criteria, generate, based on an input data set, one or more predictions using the model.
  8. The system of claim 7, wherein the one or more stopping criteria comprises a maximum number of training iterations.
  9. The system of claim 7, wherein the one or more stopping criteria comprises a threshold accuracy of the deep neural network model for a validation training set.
  10. The system of claim 7, wherein the one or more stopping criteria comprises a threshold accuracy of the deep neural network model for a mini-batch.
  11. The system of claim 7, wherein the one or more stopping criteria is based on determining that a given iteration did not result in updated model parameters due to each example being correctly classified.
  12. The system of claim 7, wherein the deep neural network model is configured to generate predictions regarding speech recognition.
  13. One or more non-transitory computer-readable media storing instructions that, when executed by one or more processors, cause the one or more processors to perform steps comprising: initializing a deep neural network model to include an input layer, an output layer, a plurality of hidden layers, and a plurality of model parameters; training, based on a training set that comprises a plurality of examples, the deep neural network model by: performing a first plurality of iterations, wherein each respective iteration of the first plurality of iterations updates the plurality of model parameters based on a first loss function for each example in a respective mini-batch from the plurality of examples; determining that a burn-in threshold is satisfied based on a first number of training iterations performed or based on an accuracy of the deep neural network model; based on determining that the burn-in threshold is satisfied, performing a second plurality of iterations, wherein each respective iteration of the second plurality of iterations updates the plurality of model parameters based on minimizing an average gradient of a second loss function for each example in a respective mini-batch from the plurality of examples, wherein the second loss function is a weighted negative log-likelihood that is based on the conditional probability of an output given an input based on the plurality of model parameters and a weighting factor, and wherein the weighting factor is based on a similarity between predicted output and ground truth, the weighted negative log-likelihood is weighted such that correctly classified examples are given zero weight, and the average gradient of the weighted negative log-likelihood is determined based on a total size of the respective mini-batch used in a given iteration; after the training, determining that the model, based on the plurality of model parameters, satisfies one or more stopping criteria; and after determining that the model satisfies one or more stopping criteria, generating, based on an input data set, one or more predictions using the model.
  14. The one or more non-transitory computer-readable media of claim 13, wherein the one or more stopping criteria comprises a maximum number of training iterations.
  15. The one or more non-transitory computer-readable media of claim 13, wherein the one or more stopping criteria comprises a threshold accuracy of the deep neural network model for a validation training set or for a mini-batch.
  16. The one or more non-transitory computer-readable media of claim 13, wherein the one or more stopping criteria is based on determining that a given iteration did not result in updated model parameters due to each example being correctly classified.
  17. The one or more non-transitory computer-readable media of claim 13, wherein the deep neural network model is configured to generate predictions regarding speech recognition.
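The two-phase training recited in claims 1, 7, and 13 can be illustrated with a toy sketch. It assumes a binary logistic-regression model as a stand-in for a deep network (to keep it short and dependency-free), plain NLL as the unspecified "first loss function", and an exact-match similarity for λ, so that the boosted weight 1 - λ is 0 for correct predictions. `train_two_phase`, `exact_match`, and all hyperparameters are hypothetical. It covers the burn-in switch (iteration count or mini-batch accuracy), the full-mini-batch averaging, and two of the stopping criteria: a maximum iteration count (claim 2) and an iteration that yields no update because every example was correct (claim 5).

```python
import math
import random


def exact_match(y_hat, y_star):
    # Hypothetical similarity lambda(y_hat, y*): 1.0 if correct, else 0.0.
    return 1.0 if y_hat == y_star else 0.0


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train_two_phase(data, burn_in_iters=None, burn_in_accuracy=None,
                    max_iterations=200, stop_on_no_update=False,
                    batch_size=8, lr=0.5, seed=0, similarity=exact_match):
    """Toy binary logistic regression trained in two phases.

    Phase 1 uses plain NLL. Once the burn-in threshold is met, each example's
    NLL is weighted by (1 - similarity), and gradients are always averaged
    over the full mini-batch size |D_t|.
    Returns (w, b, boosted_iterations, iterations_run).
    """
    rng = random.Random(seed)
    dim = len(data[0][0])
    w, b = [0.0] * dim, 0.0
    boosting, boosted = False, 0
    for it in range(max_iterations):  # max_iterations: claim-2-style criterion
        batch = rng.sample(data, batch_size)
        probs = [sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) for x, _ in batch]
        preds = [1 if p >= 0.5 else 0 for p in probs]
        acc = sum(p == y for p, (_, y) in zip(preds, batch)) / batch_size
        # Burn-in check on iteration count or mini-batch accuracy; latches on.
        if not boosting and ((burn_in_iters is not None and it >= burn_in_iters)
                             or (burn_in_accuracy is not None and acc >= burn_in_accuracy)):
            boosting = True
        boosted += int(boosting)
        weights = [1.0 - similarity(y_hat, y) if boosting else 1.0
                   for y_hat, (_, y) in zip(preds, batch)]
        # Claim-5-style criterion: every example correct, so no update occurs.
        if stop_on_no_update and boosting and not any(weights):
            return w, b, boosted, it + 1
        gw, gb = [0.0] * dim, 0.0
        for p, weight, (x, y) in zip(probs, weights, batch):
            # d(-log p_theta(y|x))/dz = p - y for the logistic model.
            for i in range(dim):
                gw[i] += weight * (p - y) * x[i]
            gb += weight * (p - y)
        # Divide by the whole mini-batch size, including zero-weight examples.
        w = [wi - lr * g / batch_size for wi, g in zip(w, gw)]
        b -= lr * gb / batch_size
    return w, b, boosted, max_iterations
```

On a small linearly separable data set with `burn_in_iters=50`, the first 50 iterations use plain NLL and the remaining 150 use the boosted loss; with `stop_on_no_update=True`, training ends at the first boosted iteration in which every sampled example is already correct.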

Description

CROSS-REFERENCE TO RELATED APPLICATIONS

This application is a continuation of U.S. application Ser. No. 17/232,968, filed on Apr. 16, 2021, which is a continuation of U.S. application Ser. No. 16/293,047, filed on Mar. 5, 2019 (now U.S. Pat. No. 10,990,878), which is a continuation of prior U.S. application Ser. No. 16/276,306, filed on Feb. 14, 2019 (now U.S. Pat. No. 10,510,002). Each of the above-mentioned applications is hereby incorporated herein by reference in its entirety.

A portion of the disclosure of this patent document contains material which is subject to copyright protection. The copyright owner has no objection to the facsimile reproduction by anyone of the patent document or the patent disclosure, as it appears in the Patent and Trademark Office patent file or records, but otherwise reserves all copyright rights whatsoever.

FIELD OF USE

Aspects of the disclosure relate generally to machine learning. More specifically, aspects of the disclosure may provide for enhanced training of models that use a deep neural network architecture based on features similar to stochastic gradient boosting.

BACKGROUND

Deep neural network models may contain millions of parameters that extract hierarchies of features from data, enabling them to learn from far larger amounts of data than earlier shallow networks could. However, because of their large capacity, deep neural networks often overfit and generalize poorly. This may result from what the model learns at different stages of the training process. Due to the nature of deep neural networks, models may learn by (i) connecting input and output labels by extracting predictive features from the input; (ii) learning statistics associated with the output labels (e.g., the likelihood of the output itself); and (iii) connecting non-predictive features in the input to output labels. It is desirable that models focus on the predictive features of (i) and avoid learning from the non-predictive aspects of (ii) and (iii).
Structuring model training processes so that the model learns in this way has proven difficult, because deep neural networks typically maximize the conditional probability P(y|x) of the output (y) given the input features (x), instead of maximizing the mutual information between output and input, which is based on the ratio P(y|x)/P(y).

Stochastic gradient boosting has been used in machine learning to combine the capacity of multiple shallow or weak learners into a deep or strong learner. A data set may be split among multiple weak learners, so that each weak model specializes in a fraction of the data set. Application of stochastic gradient boosting to an ensemble of decision trees is described in J. Friedman, “Greedy Function Approximation: A Gradient Boosting Machine,” The Annals of Statistics, Vol. 29, No. 5, 2001, which is incorporated herein by reference. However, stochastic gradient boosting has been considered infeasible for training deep neural networks. It has been observed that applying Friedman's stochastic gradient boosting to deep neural network training often led to training instability. See, e.g., Philip M. Long, et al., “Random Classification Noise Defeats All Convex Potential Boosters,” in Proceedings of the 25th International Conference on Machine Learning, Helsinki, Finland, 2008. Because deep neural networks are strong learners by design, their gradients are generally not boosted during training, since doing so has been regarded as computationally prohibitive. Other gradient-descent-based boosting algorithms suffer from a label-noise problem that hinders model training.

Aspects described herein may address these and other problems, and generally improve the quality, efficiency, and speed of machine learning systems by regularizing model training, improving network generalization, and abating the deleterious effect of class imbalance on model performance.
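The contrast drawn above between maximizing P(y|x) and maximizing mutual information can be made concrete with a standard identity (an illustrative decomposition, not taken from the patent): the per-example log-likelihood splits into a pointwise mutual-information term and a label-prior term, and the mutual information is the expectation of the first term.

```latex
\log P(y \mid x)
  = \underbrace{\log \frac{P(y \mid x)}{P(y)}}_{\text{pointwise mutual information}}
  + \underbrace{\log P(y)}_{\text{label prior}},
\qquad
I(X;Y) = \mathbb{E}_{(x,y)}\left[\log \frac{P(y \mid x)}{P(y)}\right]
```

Because the second term depends only on the label, a likelihood objective can reward assigning high probability to frequent labels regardless of x, which corresponds to learning mode (ii) described in the background.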
SUMMARY

The following presents a simplified summary of various aspects described herein. This summary is not an extensive overview, and is not intended to identify key or critical elements or to delineate the scope of the claims. The following summary merely presents some concepts in a simplified form as an introductory prelude to the more detailed description provided below.

Aspects described herein may allow for the application of stochastic gradient boosting techniques to the training of deep neural networks. This may have the effect of regularizing model training, improving network generalization, and abating the deleterious effect of class imbalance on model performance. According to some aspects, these and other benefits may be achieved by disallowing gradient back propagation from examples that are correctly classified by the neural network model while still keeping correctly classified examples in the gradient averaging. In implementation, this may be effected by multiplying the contribution of correctly classified examples to a loss function by a weighting factor of 0 while still av