
CN-116976401-B - Model training method, device, equipment, storage medium and program product


Abstract

The application discloses a model training method, apparatus, device, storage medium and program product, and belongs to the technical field of machine learning. The method comprises: generating a pseudo sample based on a first generation model in a generative adversarial network; determining a prediction label of the pseudo sample through a first discriminant model in the generative adversarial network; iteratively training the weights of a plurality of loss functions of the first generation model based on the prediction label and the actual label of the pseudo sample until the weights of the plurality of loss functions meet a first convergence condition, at which point the iterative training of the weights is determined to be finished; and determining a first target model based on the weights of the plurality of loss functions at the end of the iterative training. The method can learn the most suitable weight for each loss function during iterative training and ensures the optimality of each weight, thereby improving the performance of the resulting target model.
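As a rough illustration of how the prediction labels and actual labels of pseudo samples might be compared, the sketch below computes a prediction accuracy (the quantity that claim 1 feeds into the weight training). It assumes that the first discriminant model outputs the probability that a sample is real, that labels are encoded as 1 (true) and 0 (false), and that a 0.5 threshold converts probabilities into labels; none of these details are specified in the patent, and `prediction_accuracy` is a hypothetical name.

```python
from typing import Sequence


def prediction_accuracy(
    fake_probs: Sequence[float],
    actual_labels: Sequence[int],
    threshold: float = 0.5,
) -> float:
    """Fraction of pseudo samples whose predicted label equals the actual label.

    Assumptions (not from the patent): the discriminator returns P(real),
    and labels are 1 for "true" and 0 for "false".
    """
    if len(fake_probs) != len(actual_labels) or not fake_probs:
        raise ValueError("need equally sized, non-empty inputs")
    # Turn discriminator probabilities into hard prediction labels.
    predicted = [1 if p >= threshold else 0 for p in fake_probs]
    hits = sum(int(p == a) for p, a in zip(predicted, actual_labels))
    return hits / len(actual_labels)
```

With the labelling of claim 5 (pseudo samples marked false, i.e. 0), this accuracy measures how often the discriminator recognises the generated samples as fake.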

Inventors

  • LI WENJIN

Assignees

  • OPPO广东移动通信有限公司 (Guangdong OPPO Mobile Telecommunications Corp., Ltd.)

Dates

Publication Date
2026-05-05
Application Date
2022-04-19

Claims (13)

  1. A model training method, the method comprising: generating a second training sample based on a first generation model in a generative adversarial network, the second training sample being a pseudo sample; determining a prediction label of the second training sample through a first discriminant model in the generative adversarial network; determining the prediction accuracy of the second training sample based on the prediction label and the actual label of the second training sample; determining a first loss value of a plurality of loss functions based on a first training sample, the second training sample, and current weights of the plurality of loss functions of the first generation model, the first training sample being a real sample; iteratively training the weights of the plurality of loss functions based on the prediction accuracy and the first loss value until the weights of the plurality of loss functions meet a first convergence condition, and determining that the iterative training of the weights of the plurality of loss functions is finished, wherein the plurality of loss functions of the first generation model comprise loss functions respectively corresponding to a plurality of tasks of the first generation model, and the first convergence condition is that the difference between the weights of the plurality of loss functions in two consecutive iterations is smaller than a fourth preset threshold, or the difference between the loss values of the first generation model in two consecutive iterations is smaller than a fifth preset threshold, or the number of iterations of the first generation model reaches a first preset number of iterations; and determining a first target model based on the first generation model corresponding to the weights of the plurality of loss functions at the end of the iterative training, wherein the first target model is an image processing model or is used for speech recognition.
  2. The method of claim 1, wherein iteratively training the weights of the plurality of loss functions based on the prediction accuracy and the first loss value comprises: acquiring a weight of the first loss value; determining a total loss value based on the first loss value, the weight of the first loss value, and the prediction accuracy; and iteratively training the weights of the plurality of loss functions based on the total loss value.
  3. The method of claim 1, wherein determining the first loss value of the plurality of loss functions based on the first training sample, the second training sample, and the current weights of the plurality of loss functions of the first generation model comprises: separating at least one loss function from the plurality of loss functions based on the complexity of the plurality of loss functions to obtain a target loss function; and determining a first loss value of the target loss function based on the first training sample, the second training sample, and the current weight of the target loss function; and iteratively training the weights of the plurality of loss functions based on the prediction accuracy and the first loss value comprises: determining a second loss value of the at least one loss function based on the first training sample, the second training sample, and the current weight of the at least one loss function; and iteratively training the weight of the target loss function based on the first loss value, the prediction accuracy, and the second loss value.
  4. The method of claim 1, wherein determining the prediction label of the second training sample through the first discriminant model in the generative adversarial network comprises: iteratively training the first discriminant model based on a first training sample and the second training sample until the first discriminant model meets a second convergence condition; and determining the prediction label of the second training sample based on the trained first discriminant model.
  5. The method of claim 4, wherein iteratively training the first discriminant model based on the first training sample and the second training sample comprises: marking the actual label of the first training sample as a true label and the actual label of the second training sample as a false label; and iteratively training the first discriminant model based on the actual label of the first training sample and the actual label of the second training sample.
  6. The method of claim 4, wherein iteratively training the first discriminant model based on the first training sample and the second training sample comprises: marking the actual label of the first training sample as a true label, a first actual label of the second training sample as a false label, and a second actual label of the second training sample as a true label; and iteratively training the first discriminant model based on the actual label of the first training sample and on the first actual label and the second actual label of the second training sample.
  7. The method of claim 1, wherein before generating the second training sample based on the first generation model in the generative adversarial network, the method further comprises: iteratively training the network parameters of the first generation model based on the first training sample until the network parameters of the first generation model meet a third convergence condition, and then executing the step of generating the second training sample based on the first generation model in the generative adversarial network.
  8. The method of claim 1, wherein determining the first target model based on the first generation model corresponding to the weights of the plurality of loss functions at the end of the iterative training comprises: determining a second target model based on the first generation model corresponding to the weights of the plurality of loss functions at the end of the iterative training; and iteratively training the network parameters of the second target model based on a third training sample until the network parameters of the second target model meet a fourth convergence condition, to obtain the first target model, wherein the third training sample is a training sample for training the second target model.
  9. The method of claim 8, wherein iteratively training the network parameters of the second target model based on the third training sample comprises: iteratively training the current network parameters of the second target model based on the third training sample; or initializing the network parameters of the second target model, and iteratively training the initialized network parameters of the second target model based on the third training sample.
  10. A model training apparatus, the apparatus comprising: a generation module, configured to generate a second training sample based on a first generation model in a generative adversarial network, the second training sample being a pseudo sample; a first determining module, configured to determine a prediction label of the second training sample through a first discriminant model in the generative adversarial network; a first training module, configured to determine the prediction accuracy of the second training sample based on the prediction label and the actual label of the second training sample, determine a first loss value of a plurality of loss functions based on a first training sample, the second training sample, and current weights of the plurality of loss functions of the first generation model, the first training sample being a real sample, and iteratively train the weights of the plurality of loss functions based on the prediction accuracy and the first loss value until the weights of the plurality of loss functions meet a first convergence condition, and determine that the iterative training of the weights of the plurality of loss functions is finished, wherein the plurality of loss functions of the first generation model comprise loss functions respectively corresponding to a plurality of tasks of the first generation model, and the first convergence condition is that the difference between the weights of the plurality of loss functions in two consecutive iterations is smaller than a fourth preset threshold, or the difference between the loss values of the first generation model in two consecutive iterations is smaller than a fifth preset threshold, or the number of iterations of the first generation model reaches a first preset number of iterations; and a second determining module, configured to determine a first target model based on the first generation model corresponding to the weights of the plurality of loss functions at the end of the iterative training, wherein the first target model is an image processing model or is used for speech recognition.
  11. An electronic device comprising a processor and a memory, wherein the memory stores at least one program code that is loaded and executed by the processor to implement the model training method of any one of claims 1 to 9.
  12. A computer-readable storage medium storing at least one program code, the at least one program code being loaded and executed by a processor to implement the model training method of any one of claims 1 to 9.
  13. A computer program product storing at least one program code, the at least one program code being loaded and executed by a processor to implement the model training method of any one of claims 1 to 9.
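The first convergence condition in claims 1 and 10 is a disjunction of three tests. A minimal sketch of that check follows; the function and parameter names are hypothetical, and measuring "the difference between the weights" as the largest per-weight absolute change is an assumption, since the claims do not say which norm is used.

```python
from typing import Sequence


def first_convergence_reached(
    prev_weights: Sequence[float],
    curr_weights: Sequence[float],
    prev_loss: float,
    curr_loss: float,
    iteration: int,
    weight_tol: float,  # plays the role of the "fourth preset threshold"
    loss_tol: float,  # plays the role of the "fifth preset threshold"
    max_iters: int,  # plays the role of the "first preset number of iterations"
) -> bool:
    """Return True if any one of the three conditions of claim 1 holds."""
    # Assumption: "difference of the weights" = largest per-weight change.
    weight_diff = max(abs(c - p) for p, c in zip(prev_weights, curr_weights))
    return (
        weight_diff < weight_tol
        or abs(curr_loss - prev_loss) < loss_tol
        or iteration >= max_iters
    )
```

Because the conditions are joined by "or", the iteration cap guarantees termination even when neither the weights nor the loss settle.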

Description

Technical Field

The present application relates to the field of machine learning technologies, and in particular, to a model training method, apparatus, device, storage medium, and program product.

Background

Model training is essentially a process of minimizing a loss function. A model may perform a plurality of tasks; each task corresponds to a loss function, and each loss function has a weight. Because the magnitude of these weights directly affects the performance of the model, determining the weight of each loss function is a problem that urgently needs to be solved.

In the related art, personnel first set the initial weights of the loss functions based on experience, train the model for a period of time, adjust the weights empirically according to the training effect, and then continue training until the training effect meets the requirements, yielding the final trained model. However, because the weights are adjusted empirically, this approach cannot guarantee that the determined weights are optimal, so the finally trained model performs poorly.

Disclosure of Invention

Embodiments of the application provide a model training method, apparatus, device, storage medium, and program product, which can improve the performance of a model.
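The weighted multi-task objective described in the Background can be written compactly. The notation below is ours, not the patent's: $\theta$ denotes the model parameters, $K$ the number of tasks, $\mathcal{L}_i$ the loss of task $i$, and $w_i$ its weight. The conventional weighted-sum form is an assumption, since the application does not state how the per-task losses are combined.

```latex
\mathcal{L}_{\text{total}}(\theta;\mathbf{w}) = \sum_{i=1}^{K} w_i\,\mathcal{L}_i(\theta),
\qquad
\theta^{\ast}(\mathbf{w}) = \arg\min_{\theta}\ \mathcal{L}_{\text{total}}(\theta;\mathbf{w})
```

The second expression makes explicit why the weights matter: the trained parameters $\theta^{\ast}$ depend on $\mathbf{w}$, so hand-tuning $\mathbf{w}$ (the related-art approach) fixes which trade-off between tasks the final model learns.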
The technical solution is as follows. In one aspect, a model training method is provided, the method comprising: generating a second training sample based on a first generation model in a generative adversarial network, the second training sample being a pseudo sample; determining a prediction label of the second training sample through a first discriminant model in the generative adversarial network; iteratively training the weights of a plurality of loss functions of the first generation model based on the prediction label and the actual label of the second training sample until the weights of the plurality of loss functions meet a first convergence condition, and determining that the iterative training of the weights of the plurality of loss functions is finished; and determining a first target model based on the weights of the plurality of loss functions at the end of the iterative training.

In another aspect, a model training apparatus is provided, the apparatus comprising: a generation module, configured to generate a second training sample based on a first generation model in the generative adversarial network, the second training sample being a pseudo sample; a first determining module, configured to determine a prediction label of the second training sample through a first discriminant model in the generative adversarial network; a first training module, configured to iteratively train the weights of the plurality of loss functions of the first generation model based on the prediction label and the actual label of the second training sample until the weights of the plurality of loss functions meet a first convergence condition, and to determine that the iterative training of the weights of the plurality of loss functions is finished; and a second determining module, configured to determine the first target model based on the weights of the plurality of loss functions at the end of the iterative training.
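The flow through the four modules above can be sketched as a single loop. This is a minimal sketch under stated assumptions, not the applicant's implementation: the generator, the accuracy computation, the per-task losses, and especially `update_weights` are hypothetical callables supplied by the caller, because the application does not specify how the weights are updated from the loss value and the prediction accuracy. The stopping tests mirror the three alternatives of the first convergence condition, with the weight difference taken as the largest per-weight change (an assumption).

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical types: the patent does not fix a sample representation.
Sample = float
LossFn = Callable[[Sequence[Sample], Sequence[Sample]], float]


def train_loss_weights(
    real: Sequence[Sample],
    generate: Callable[[], List[Sample]],
    accuracy_of: Callable[[List[Sample]], float],
    loss_fns: Sequence[LossFn],
    update_weights: Callable[[List[float], float, float], List[float]],
    weights: List[float],
    weight_tol: float = 1e-4,
    loss_tol: float = 1e-6,
    max_iters: int = 1000,
) -> List[float]:
    """Learn loss-function weights until a convergence condition holds."""
    prev_loss: Optional[float] = None
    for _ in range(max_iters):  # stop after a preset number of iterations
        fake = generate()  # generation module: pseudo samples
        acc = accuracy_of(fake)  # first determining module: discriminator accuracy
        # First training module: weighted sum of the per-task losses ...
        loss = sum(w * fn(real, fake) for w, fn in zip(weights, loss_fns))
        # ... followed by a (hypothetical) update driven by loss and accuracy.
        new_weights = update_weights(weights, loss, acc)
        weight_change = max(abs(n - o) for n, o in zip(new_weights, weights))
        weights = new_weights
        if weight_change < weight_tol:
            break
        if prev_loss is not None and abs(loss - prev_loss) < loss_tol:
            break
        prev_loss = loss
    # Second determining module: the caller builds the target model from these.
    return weights
```

Injecting `update_weights` keeps the loop agnostic to the update rule; for instance, claim 2 suggests forming a total loss from the first loss value, its weight, and the prediction accuracy inside that callable.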
In another aspect, an electronic device is provided, including a processor and a memory, the memory storing at least one program code that is loaded and executed by the processor to implement the model training method described above.

In another aspect, a computer-readable storage medium is provided, storing at least one program code that is loaded and executed by a processor to implement the model training method described above.

In another aspect, a computer program product is provided, storing at least one program code that is loaded and executed by a processor to implement the model training method described above.

The technical solution provided by the embodiments of the application has the following beneficial effects: in the model training method, a pseudo sample is generated based on a first generation model, a prediction label of the pseudo sample is determined through a first discriminant model, and the weights of a plurality of loss functions of the first generation model are iteratively trained based on the prediction label and the actual label of the pseudo sample, so that a suitable weight is learned for each loss function during iterative training, which improves the performance of the resulting model.

It is to be understood that both the foregoing general description and the following detailed description are exemplary and explanatory only and are not restrictive of the disclosure.

Drawings

FIG. 1 is a schematic diagram of an implementation environment of a mod