CN-116824288-B - Model training method, category prediction method and device

Abstract

The application discloses a model training method, a class prediction method and device, a computing device, and a computer-readable storage medium, which address the problem that the class prediction accuracy of a model is not high enough. The training method comprises: training a first class prediction model using an original sample object set; using the trained first model to determine the matching degree between the label of each sample object and its predicted class; selecting, from the original sample object set, the sample objects that meet a preset condition; and training a second class prediction model using the selected sample objects. Because the sample objects used to train the second class prediction model are those whose labels were annotated with relatively high accuracy, the class prediction model obtained by the method provided by the embodiment of the application is more accurate than one trained directly on a dirty data set.

Inventors

  • FENG YIJUN
  • ZHU HANGCHENG
  • MA GUOJUN

Assignees

  • 北京字跳网络技术有限公司

Dates

Publication Date
20260505
Application Date
20220317

Claims (15)

  1. A method of model training, comprising: training a class prediction first model by using sample objects in an original sample object set to obtain a trained first model, wherein the sample objects comprise multimedia files which are used as training samples after being labeled with tags, and the tags represent the classes of the sample objects; performing category prediction on the sample object by using the trained first model to determine the matching degree between the label of the sample object and the predicted category of the sample object; according to the matching degree, selecting a sample object meeting a preset condition from the original sample object set as a first sample object; training the class prediction second model by using the first sample object to obtain a trained second model; training a student model by using the trained second model as a teacher model and adopting a knowledge distillation technology to obtain a trained student model, wherein in the training process of the student model, the teacher model and the student model conduct category prediction on the same batch of sample objects selected from the recombined sample object set; the recombined sample object set consists of the first sample object and a second sample object which does not accord with the preset condition, and the second sample object is marked with a preset label.
  2. The method of claim 1, wherein training the student model using knowledge distillation techniques with the trained second model as a teacher model using sample objects in the recombined sample object set comprises: using the trained second model as the teacher model, and performing category prediction on sample objects selected from the recombined sample object set to obtain the category distribution predicted by the teacher model; discarding the category distribution of the second sample object in the category distribution predicted by the student model to obtain the remaining category distribution of the first sample object as the target category distribution predicted by the student model; determining a loss function of the student model based on the class distribution predicted by the teacher model, the class distribution predicted by the student model, and the target class distribution; and updating the student model until convergence according to the loss function.
  3. The method of claim 2, wherein determining a loss function for the student model based on the teacher model predicted class distribution, the student model predicted class distribution, and the target class distribution comprises: acquiring a first loss function based on the target class distribution; acquiring a second loss function based on the class distribution predicted by the teacher model and the class distribution predicted by the student model; and acquiring the loss function of the student model based on the first loss function and the second loss function.
  4. The method of claim 3, wherein the obtaining a first loss function based on the target class distribution, obtaining a second loss function based on the class distribution predicted by the teacher model and the class distribution predicted by the student model, and obtaining the loss function of the student model based on the first loss function and the second loss function, comprises: constructing a cross entropy function between the target class distribution and the label of the corresponding first sample object as a first loss function according to the target class distribution; constructing a KL divergence loss function of the class distribution predicted by the teacher model and the class distribution predicted by the student model as a second loss function according to the class distribution predicted by the teacher model and the class distribution predicted by the student model; and determining the sum of the first loss function and the second loss function as the loss function of the student model.
  5. The method of claim 1, wherein performing class prediction on the sample object using the trained first model to determine a degree of matching between a tag of the sample object and a predicted class of the sample object comprises: carrying out category prediction on the sample object by using the trained first model to obtain a category confidence value of the sample object, wherein the category confidence value characterizes the matching degree and is positively correlated with the matching degree.
  6. The method of claim 5, wherein selecting, from the original sample object set according to the matching degree, a sample object whose matching degree meets a predetermined matching degree requirement comprises: according to the class confidence values of the positive sample objects in the original sample object set, selecting positive sample objects in descending order of value until their number matches a first proportion; and according to the class confidence values of the negative sample objects in the original sample object set, selecting negative sample objects in ascending order of value until their number matches a second proportion.
  7. A method of model training, comprising: obtaining a recombined sample object set, wherein sample objects in the recombined sample object set comprise a first sample object conforming to a preset condition and a second sample object not conforming to the preset condition, and the second sample object is marked with a preset label; using a classification model trained based on the first sample object as a teacher model, and performing class prediction on sample objects in the recombined sample object set to obtain the class distribution predicted by the teacher model; discarding the class distribution of the second sample object in the class distribution predicted by the student model to obtain the residual class distribution of the first sample object as the target class distribution predicted by the student model; determining a loss function of the student model according to the class distribution predicted by the teacher model, the class distribution predicted by the student model and the target class distribution; and updating the student model until convergence according to the loss function to obtain a trained student model.
  8. The method of claim 7, wherein determining a loss function for the student model based on the teacher model predicted class distribution, the student model predicted class distribution, and the target class distribution comprises: acquiring a first loss function based on the target class distribution; acquiring a second loss function based on the class distribution predicted by the teacher model and the class distribution predicted by the student model; and acquiring the loss function of the student model based on the first loss function and the second loss function.
  9. The method of claim 8, wherein the obtaining a first loss function based on the target class distribution, obtaining a second loss function based on the class distribution predicted by the teacher model and the class distribution predicted by the student model, and obtaining the loss function of the student model based on the first loss function and the second loss function, comprises: constructing a cross entropy function between the target class distribution and the label of the corresponding first sample object as a first loss function according to the target class distribution; constructing a KL divergence loss function of the class distribution predicted by the teacher model and the class distribution predicted by the student model as a second loss function according to the class distribution predicted by the teacher model and the class distribution predicted by the student model; and determining the sum of the first loss function and the second loss function as the loss function of the student model.
  10. A class prediction method, comprising: acquiring an object to be subjected to category prediction; inputting the object into a category prediction model to obtain a category prediction result of the object output by the category prediction model; wherein the class prediction model is trained by the model training method according to any one of claims 1 to 9.
  11. A model training device, comprising: a first training unit, used for training a class prediction first model by utilizing sample objects in an original sample object set to obtain a trained first model; a prediction unit, used for carrying out category prediction on the sample object by using the trained first model obtained by the first training unit, so as to determine the matching degree between the label of the sample object and the predicted category of the sample object; a selecting unit, used for selecting a sample object meeting preset conditions from the original sample object set as a first sample object according to the matching degree determined by the prediction unit; and a training unit, used for training the class prediction second model by using the first sample object to obtain a trained second model, and for training the student model, with the trained second model as a teacher model, by using the sample objects in the recombined sample object set and adopting a knowledge distillation technology to obtain a trained student model, wherein in the training process of the student model, the teacher model and the student model conduct class prediction on the same batch of sample objects selected from the recombined sample object set; the recombined sample object set consists of the first sample object and a second sample object which does not accord with the preset condition, and the second sample object is marked with a preset label.
  12. A model training device, comprising: a sample acquisition unit, used for acquiring a recombined sample object set, wherein the sample objects in the recombined sample object set comprise a first sample object conforming to a preset condition and a second sample object not conforming to the preset condition, and the second sample object is marked with a preset label; a class distribution prediction unit, used for predicting the class of the sample objects in the recombined sample object set by using a class prediction model trained based on the first sample object as a teacher model, so as to obtain the class distribution predicted by the teacher model, and for discarding the class distribution of the second sample object in the class distribution predicted by the student model to obtain the residual class distribution of the first sample object as the target class distribution predicted by the student model; a loss function determining unit, configured to determine a loss function of the student model according to the class distribution predicted by the teacher model, the class distribution predicted by the student model, and the target class distribution; and a model training unit, used for updating the student model until convergence according to the loss function determined by the loss function determining unit, so as to obtain a trained student model.
  13. A class prediction device, comprising: an object acquisition unit, used for acquiring an object to be subjected to category prediction; and a class prediction unit, used for inputting the object acquired by the object acquisition unit into a class prediction model so as to acquire a class prediction result of the object output by the class prediction model; wherein the class prediction model is trained by the model training method according to any one of claims 1 to 9.
  14. A computing device, comprising a memory and a processor, wherein the memory is used for storing programs; and the processor, coupled to the memory, is configured to execute the program stored in the memory, for performing the method of any one of claims 1-9, or for performing the method of claim 10.
  15. A computer readable storage medium storing a computer program which, when executed by a computer, is capable of carrying out the method of any one of claims 1 to 9 or of carrying out the method of claim 10.
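
The ranking-based selection described in claims 5 and 6 can be sketched as follows. This is a minimal illustration under stated assumptions, not the patented implementation: each sample is assumed to already carry a class confidence value produced by the trained first model, a boolean flag is assumed to mark positive samples, and the number kept in each group is assumed to be the floor of the ratio times the group size. The function name `select_first_samples` and the parameters `pos_ratio` and `neg_ratio` are hypothetical.

```python
from typing import List, Sequence


def select_first_samples(
    confidences: Sequence[float],
    is_positive: Sequence[bool],
    pos_ratio: float,
    neg_ratio: float,
) -> List[int]:
    """Return sorted indices of the samples kept as first sample objects."""
    pos = [i for i, flag in enumerate(is_positive) if flag]
    neg = [i for i, flag in enumerate(is_positive) if not flag]

    # Positive samples: rank from the largest confidence value to the smallest.
    pos.sort(key=lambda i: confidences[i], reverse=True)
    # Negative samples: rank from the smallest confidence value to the largest.
    neg.sort(key=lambda i: confidences[i])

    # Assumption: keep floor(ratio * group size) samples from each group.
    keep_pos = pos[: int(len(pos) * pos_ratio)]
    keep_neg = neg[: int(len(neg) * neg_ratio)]
    return sorted(keep_pos + keep_neg)


# Four positive samples followed by four negative samples.
confidences = [0.95, 0.40, 0.80, 0.60, 0.10, 0.70, 0.05, 0.30]
is_positive = [True, True, True, True, False, False, False, False]
selected = select_first_samples(confidences, is_positive, 0.5, 0.5)
print(selected)  # [0, 2, 4, 6]
```

Samples whose indices are not returned would be the second sample objects, which the claims mark with a preset label before the recombined set is assembled.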

Description

Model training method, category prediction method and device

Technical Field

The present application relates to the field of artificial intelligence, and in particular to a model training method, a class prediction method and device, a computing device, and a computer-readable storage medium.

Background

Video understanding refers to identifying the content of a video. Currently, artificial intelligence (AI) models can be used for video understanding. Compared with purely manual video understanding, this approach has a clear advantage in processing efficiency, so it is widely used in the industry.

Using an AI model for video understanding requires training the model on labeled sample videos. When workers label samples, the labeling criteria are easily influenced by subjective factors and keep shifting, so the same video may be given different labels. If such sample videos are used directly for AI model training, a high-accuracy model cannot be obtained.

Disclosure of Invention

The embodiment of the application provides a model training method, which addresses the problem that training a model on sample objects with inaccurate labels yields a model whose class prediction accuracy is not high enough.

The embodiment of the application also provides a model training device, a category prediction method, a category prediction device, a computing device and a computer storage medium.

The embodiment of the application adopts the following technical scheme:

A model training method, comprising: training a class prediction first model by using sample objects in an original sample object set to obtain a trained first model, wherein the sample objects are objects used as training samples after labeling; performing category prediction on the sample object by using the trained first model to determine the matching degree between the label of the sample object and the predicted category of the sample object; according to the matching degree, selecting a sample object meeting a preset condition from the original sample object set as a first sample object; and training a category prediction second model by using the first sample object.

A model training method, comprising: obtaining a recombined sample object set, wherein sample objects in the recombined sample object set comprise a first sample object conforming to a preset condition and a second sample object not conforming to the preset condition, and the second sample object is marked with a preset label; using a classification model trained based on the first sample object as a teacher model, and performing class prediction on sample objects in the recombined sample object set to obtain the class distribution predicted by the teacher model; discarding the class distribution of the second sample object from the class distribution predicted by the student model, to obtain the remaining class distribution of the first sample object as the target class distribution predicted by the student model; determining a loss function of the student model according to the class distribution predicted by the teacher model, the class distribution predicted by the student model and the target class distribution; and updating the student model until convergence according to the loss function to obtain a trained student model.

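
The student loss of the second training method above can be sketched as follows. This is a minimal sketch under stated assumptions rather than the applicant's implementation. Following claims 4 and 9, the first loss is a cross entropy computed only on the first sample objects, the second loss is a KL divergence between the teacher and student distributions, and the two are summed. The KL direction (teacher relative to student), the per-sample averaging, the use of `-1` as the preset label, and all function names are assumptions made for illustration; the source does not specify them.

```python
import math
from typing import Sequence

PRESET_LABEL = -1  # hypothetical preset label marking second sample objects


def cross_entropy(dist: Sequence[float], label: int) -> float:
    """Cross entropy between a predicted distribution and a hard label."""
    return -math.log(dist[label])


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q); q is assumed strictly positive (e.g. a softmax output)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def student_loss(
    teacher_dists: Sequence[Sequence[float]],
    student_dists: Sequence[Sequence[float]],
    labels: Sequence[int],
) -> float:
    # Target class distribution: the student's predictions with the
    # second sample objects (those carrying the preset label) discarded.
    kept = [(d, y) for d, y in zip(student_dists, labels) if y != PRESET_LABEL]
    first = sum(cross_entropy(d, y) for d, y in kept) / len(kept) if kept else 0.0

    # Second loss: teacher-vs-student KL divergence over the whole batch.
    # Averaging over the batch is an assumption.
    second = sum(
        kl_divergence(t, s) for t, s in zip(teacher_dists, student_dists)
    ) / len(student_dists)

    # The claims take the plain sum of the two losses.
    return first + second


# One first sample object (label 0) and one second sample object.
teacher = [[0.9, 0.1], [0.5, 0.5]]
student = [[0.9, 0.1], [0.5, 0.5]]
labels = [0, PRESET_LABEL]
loss = student_loss(teacher, student, labels)
```

In this example the student matches the teacher exactly, so the KL term is zero and the loss reduces to the cross entropy of the single first sample object. The second sample object contributes only through the KL term, which is how samples with unreliable labels can still inform the student without their labels being trusted.
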
A class prediction method, comprising: acquiring an object to be subjected to category prediction; inputting the object into a category prediction model to obtain a category prediction result of the object output by the category prediction model; wherein the category prediction model is obtained by training with any of the model training methods above.

A model training apparatus, comprising: a first training unit, used for training the class prediction first model by utilizing sample objects in the original sample object set to obtain a trained first model, wherein the sample objects are objects used as training samples after labeling; a prediction unit, used for carrying out category prediction on the sample object by using the trained first model obtained by the first training unit, so as to determine the matching degree between the label of the sample object and the predicted category of the sample object; a selecting unit, used for selecting a sample object meeting preset conditions from the original sample object set as a first sample object according to the matching degree determined by the prediction unit; and a second training unit, used for training the category prediction second model by using the first sample object selected by the selecting unit.

A model training apparatus, comprising: a sample acquisition unit, used for acquiring a recombined sample object set, wherein the sample objects in the recombined sample object set comprise a first sample obje