
US-20260127509-A1 - TRAINING DISTILLED MACHINE LEARNING MODELS


Abstract

Methods, systems, and apparatus, including computer programs encoded on computer storage media, for training a distilled machine learning model. One of the methods includes training a cumbersome machine learning model, wherein the cumbersome machine learning model is configured to receive an input and generate a respective score for each of a plurality of classes; and training a distilled machine learning model on a plurality of training inputs, wherein the distilled machine learning model is also configured to receive inputs and generate scores for the plurality of classes, comprising: processing each training input using the cumbersome machine learning model to generate a cumbersome target soft output for the training input; and training the distilled machine learning model to, for each of the training inputs, generate a soft output that matches the cumbersome target soft output for the training input.
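The matching step can be illustrated with a minimal sketch in plain Python. It rests on assumptions that the abstract does not state. A "soft output" is taken to be a softmax over the model's raw class scores (logits) divided by a temperature T, and "matches" is taken to mean minimizing the cross-entropy against the cumbersome model's soft target. For one training input, the gradient of that loss with respect to the student's logits is (q - p) / T, where q is the student's soft output and p is the target soft output. To keep the example small, the "student" is only a free vector of logits for a single input. The helper `fit_logits_to_teacher` and its hyperparameters are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Scores -> probabilities; a higher temperature gives a softer distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_target_loss(student_logits, teacher_logits, temperature=2.0):
    """Cross-entropy of the student's soft output against the teacher's soft target.

    The temperature-softmax form of the soft output is an assumption of this sketch.
    """
    target = softmax(teacher_logits, temperature)
    student = softmax(student_logits, temperature)
    return -sum(p * math.log(q) for p, q in zip(target, student))


def soft_target_grad(student_logits, teacher_logits, temperature=2.0):
    """Gradient of soft_target_loss with respect to the student logits: (q - p) / T."""
    target = softmax(teacher_logits, temperature)
    student = softmax(student_logits, temperature)
    return [(q - p) / temperature for q, p in zip(student, target)]


def fit_logits_to_teacher(teacher_logits, temperature=2.0, learning_rate=1.0, steps=2000):
    """Hypothetical toy student: free logits fitted by gradient descent to the soft target."""
    student = [0.0] * len(teacher_logits)
    for _ in range(steps):
        grad = soft_target_grad(student, teacher_logits, temperature)
        student = [z - learning_rate * g for z, g in zip(student, grad)]
    return student
```

With the defaults, fitting to teacher logits `[4.0, 1.0, 0.5]` drives `softmax(student, 2.0)` to agree with `softmax(teacher, 2.0)`. A real distilled model would instead backpropagate the same per-input gradient into its own parameters across all the training inputs.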

Inventors

  • Oriol Vinyals
  • Jeffrey Adgate Dean
  • Geoffrey E. Hinton

Assignees

  • GOOGLE LLC

Dates

Publication Date
2026-05-07
Application Date
2025-12-31

Claims (20)

  1. A method performed by one or more computers, the method comprising: training a student machine learning model having a plurality of student model parameters on a set of multiple training inputs, wherein the student machine learning model is configured to process an input to generate an output; wherein the training comprises: processing each training input in the set of multiple training inputs using a teacher machine learning model that is one of a plurality of teacher machine learning models to generate a respective target output; processing each training input in the set of multiple training inputs using the student machine learning model to generate a respective student output; and training the student machine learning model to optimize an objective function which includes a term that, for each training input in the set of multiple training inputs, measures a discrepancy between: (i) the target output for the training input, and (ii) the student output for the training input.
  2. The method of claim 1, wherein the plurality of teacher machine learning models are trained separately.
  3. The method of claim 1, wherein each of the plurality of teacher machine learning models is trained to perform a respective task.
  4. The method of claim 3, wherein each of the plurality of teacher machine learning models is trained to perform a different respective task.
  5. The method of claim 3, wherein the respective task comprises generating a respective score for a subset of a plurality of classes.
  6. The method of claim 5, further comprising: generating one or more subsets of the plurality of classes by clustering the classes that are frequently predicted together by one or more full neural networks into the same subset; and assigning each subset to a respective teacher machine learning model.
  7. The method of claim 1, wherein processing each training input in the set of multiple training inputs using a teacher machine learning model that is one of a plurality of teacher machine learning models to generate a respective target output comprises, for each training input: processing the training input using each of one or more of the plurality of teacher machine learning models to generate a respective teacher output; and generating the respective target output based on the one or more respective teacher outputs.
  8. The method of claim 1, wherein the respective target output comprises a respective teacher score distribution over a plurality of classes.
  9. The method of claim 1, wherein the respective student output comprises a respective student score distribution over a plurality of classes.
  10. A system comprising: one or more computers; and one or more storage devices communicatively coupled to the one or more computers, wherein the one or more storage devices store instructions that, when executed by the one or more computers, cause the one or more computers to perform operations comprising: training a student machine learning model having a plurality of student model parameters on a set of multiple training inputs, wherein the student machine learning model is configured to process an input to generate an output; wherein the training comprises: processing each training input in the set of multiple training inputs using a teacher machine learning model that is one of a plurality of teacher machine learning models to generate a respective target output; processing each training input in the set of multiple training inputs using the student machine learning model to generate a respective student output; and training the student machine learning model to optimize an objective function which includes a term that, for each training input in the set of multiple training inputs, measures a discrepancy between: (i) the target output for the training input, and (ii) the student output for the training input.
  11. The system of claim 10, wherein the plurality of teacher machine learning models are trained separately.
  12. The system of claim 10, wherein each of the plurality of teacher machine learning models is trained to perform a respective task.
  13. The system of claim 12, wherein each of the plurality of teacher machine learning models is trained to perform a different respective task.
  14. The system of claim 12, wherein the respective task comprises generating a respective score for a subset of a plurality of classes.
  15. The system of claim 14, further comprising: generating one or more subsets of the plurality of classes by clustering the classes that are frequently predicted together by one or more full neural networks into the same subset; and assigning each subset to a respective teacher machine learning model.
  16. The system of claim 10, wherein processing each training input in the set of multiple training inputs using a teacher machine learning model that is one of a plurality of teacher machine learning models to generate a respective target output comprises, for each training input: processing the training input using each of one or more of the plurality of teacher machine learning models to generate a respective teacher output; and generating the respective target output based on the one or more respective teacher outputs.
  17. The system of claim 10, wherein the respective target output comprises a respective teacher score distribution over a plurality of classes.
  18. The system of claim 10, wherein the respective student output comprises a respective student score distribution over a plurality of classes.
  19. One or more non-transitory computer storage media storing instructions that when executed by one or more computers cause the one or more computers to perform operations comprising: training a student machine learning model having a plurality of student model parameters on a set of multiple training inputs, wherein the student machine learning model is configured to process an input to generate an output; wherein the training comprises: processing each training input in the set of multiple training inputs using a teacher machine learning model that is one of a plurality of teacher machine learning models to generate a respective target output; processing each training input in the set of multiple training inputs using the student machine learning model to generate a respective student output; and training the student machine learning model to optimize an objective function which includes a term that, for each training input in the set of multiple training inputs, measures a discrepancy between: (i) the target output for the training input, and (ii) the student output for the training input.
  20. The non-transitory computer storage media of claim 19, wherein the plurality of teacher machine learning models are trained separately.
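Claims 1, 6, and 7 leave three things open: the clustering procedure, the rule for combining teacher outputs, and the discrepancy measure. The plain-Python sketch below fills each gap with one possible choice, and every choice is an illustrative assumption rather than the claimed method. `cluster_cooccurring_classes` treats "frequently predicted together" as a pair of classes appearing together in a full model's top-k predictions at least `min_count` times, and merges such pairs with union-find. `average_teacher_outputs` builds the target by averaging the teacher distributions. `distillation_objective` sums a KL-divergence discrepancy over the training inputs. All function names and the `min_count` threshold are hypothetical.

```python
import math
from collections import Counter
from itertools import combinations


def cluster_cooccurring_classes(top_k_predictions, num_classes, min_count):
    """Group classes that full models often predict together (cf. claim 6).

    top_k_predictions: one list of class indices per input, e.g. a full
    model's top-k classes. Any pair of classes seen together at least
    `min_count` times is merged into the same subset via union-find.
    The top-k format and the `min_count` threshold are assumptions.
    """
    parent = list(range(num_classes))

    def find(c):
        while parent[c] != c:
            parent[c] = parent[parent[c]]  # path halving keeps trees shallow
            c = parent[c]
        return c

    pair_counts = Counter()
    for preds in top_k_predictions:
        for a, b in combinations(sorted(set(preds)), 2):
            pair_counts[(a, b)] += 1
    for (a, b), count in pair_counts.items():
        if count >= min_count:
            parent[find(a)] = find(b)

    groups = {}
    for c in range(num_classes):
        groups.setdefault(find(c), []).append(c)
    # Only groups of two or more classes become specialist subsets;
    # each returned subset would then be assigned to its own teacher model.
    return sorted(g for g in groups.values() if len(g) > 1)


def average_teacher_outputs(teacher_outputs):
    """Target output from one or more teacher distributions (cf. claim 7).

    Averaging is one assumed way to generate a target "based on" the teacher outputs.
    """
    n = len(teacher_outputs)
    return [sum(column) / n for column in zip(*teacher_outputs)]


def kl_divergence(target, student, eps=1e-12):
    """KL(target || student), an assumed choice of discrepancy measure."""
    return sum(t * math.log((t + eps) / (s + eps))
               for t, s in zip(target, student) if t > 0.0)


def distillation_objective(teacher_outputs_per_input, student_outputs):
    """Sum over training inputs of the target/student discrepancy (cf. claim 1)."""
    return sum(kl_divergence(average_teacher_outputs(teachers), student)
               for teachers, student in zip(teacher_outputs_per_input, student_outputs))
```

For example, take five classes, `min_count=2`, and top-k predictions `[[0, 1], [0, 1], [1, 0], [2, 3]]`. Classes 0 and 1 appear together three times, so they form the single specialist subset `[[0, 1]]`. Classes 2 and 3 appear together only once and stay unassigned.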

Description

CROSS-REFERENCE TO RELATED APPLICATION

This is a continuation application of, and claims priority to, U.S. patent application Ser. No. 18/399,358, filed on Dec. 28, 2023, which is a continuation of U.S. patent application Ser. No. 17/863,733, filed on Jul. 13, 2022, now U.S. Pat. No. 11,900,232, which is a continuation of U.S. patent application Ser. No. 16/841,859, titled "TRAINING DISTILLED MACHINE LEARNING MODELS," filed on Apr. 7, 2020, now U.S. Pat. No. 11,423,337, which is a continuation application of, and claims priority to, U.S. patent application Ser. No. 16/368,526, titled "TRAINING DISTILLED MACHINE LEARNING MODELS," filed on Mar. 28, 2019, now U.S. Pat. No. 10,650,328, which is a continuation application of, and claims priority to, U.S. patent application Ser. No. 14/731,349, titled "TRAINING DISTILLED MACHINE LEARNING MODELS," filed on Jun. 4, 2015, now U.S. Pat. No. 10,289,962, which claims the benefit of priority to U.S. Provisional Application No. 62/008,998, filed on Jun. 6, 2014. The disclosures of the prior applications are considered part of and are incorporated by reference in the disclosure of this application.

BACKGROUND

This specification relates to training machine learning models.

A machine learning model receives input and generates an output based on the received input and on values of the parameters of the model. For example, a machine learning model may receive an image and generate a score for each of a set of classes, with the score for a given class representing a probability that the image contains an object that belongs to the class.

The machine learning model may be composed of, e.g., a single level of linear or non-linear operations or may be a deep network, i.e., a machine learning model that is composed of multiple levels, one or more of which may be layers of non-linear operations. An example of a deep network is a neural network with one or more hidden layers.
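To make the background concrete, here is a minimal plain-Python sketch of the two model shapes the paragraph mentions. The first is a single level of linear operations followed by a softmax. The second is a neural network with one hidden layer. Both map an input vector, standing in for image features, to one probability-like score per class. The function names, the ReLU non-linearity, and the random initialization are illustrative assumptions, not details from the specification.

```python
import math
import random


def softmax(logits):
    """Turn raw per-class scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def affine(x, weights, biases):
    """One level of linear operations: one weight row and one bias per output unit."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]


def init_layer(n_in, n_out, rng):
    """Small random weights and zero biases (an illustrative initialization)."""
    weights = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def shallow_model(x, w, b):
    """Single level: a linear map followed by a softmax over the classes."""
    return softmax(affine(x, w, b))


def deep_model(x, w1, b1, w2, b2):
    """Neural network with one hidden layer; the ReLU non-linearity is an assumption."""
    hidden = [max(0.0, h) for h in affine(x, w1, b1)]
    return softmax(affine(hidden, w2, b2))
```

For example, with `rng = random.Random(0)`, layers built by `init_layer(4, 3, rng)` for the shallow model, or `init_layer(4, 8, rng)` and `init_layer(8, 3, rng)` for the deep one, each return three class scores for a four-feature input. In both models, the scores are determined by the input and the parameter values.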
SUMMARY

In general, this specification describes techniques for training a distilled machine learning model using a cumbersome machine learning model.

Particular embodiments of the subject matter described in this specification can be implemented so as to realize one or more of the following advantages. A distilled machine learning model that is easier to deploy than a cumbersome machine learning model, i.e., because it requires less computation, memory, or both, to generate outputs at run time than the cumbersome machine learning model, can effectively be trained using a cumbersome neural network that has already been trained. Once trained using the cumbersome machine learning model, the distilled machine learning model can generate outputs that are not significantly less accurate than outputs generated by the cumbersome machine learning model despite being easier to deploy or using fewer computational resources than the cumbersome machine learning model.

An ensemble model that includes one or more full machine learning models and one or more specialist machine learning models can more accurately generate scores to classify a received input. In particular, by including specialist machine learning models in the ensemble model, the scores for classes that are frequently predicted together or confused by the full machine learning models can be more accurately generated.

The details of one or more embodiments of the subject matter of this specification are set forth in the accompanying drawings and the description below. Other features, aspects, and advantages of the subject matter will become apparent from the description, the drawings, and the claims.

BRIEF DESCRIPTION OF THE DRAWINGS

FIG. 1 shows an example distilled machine learning model training system.

FIG. 2 is a flow diagram of an example process for training a distilled machine learning model using a trained cumbersome machine learning model.

FIG. 3 shows an example machine learning model system.

FIG. 4 is a flow diagram of an example process for processing an input using an ensemble machine learning model that includes one or more full machine learning models and one or more specialist machine learning models.

Like reference numbers and designations in the various drawings indicate like elements.

DETAILED DESCRIPTION

FIG. 1 is a block diagram of an example distilled machine learning model training system 100 for training a distilled machine learning model 120. The distilled machine learning model training system 100 is an example of a system implemented as computer programs on one or more computers in one or more locations, in which the systems, components, and techniques described below are implemented.

The distilled machine learning model training system 100 trains the distilled machine learning model 120 using a trained cumbersome machine learning model 110.

Generally, a machine learning model receives input and generates an output based on the received input and on values of the parameters of the model. In particular, both the distilled mach