CN-122020303-A - Sample weight real-time self-adaptive statistical adjustment method for classification model
Abstract
The invention discloses a sample weight real-time self-adaptive statistical adjustment method for a classification model. The method comprises the following steps: obtaining sample data of a current training batch and inputting the sample data into a main classification model, which processes the sample data to generate prediction result data; receiving, by a real-time statistical tracking module, the prediction result data and the sample data, and updating sample level statistics and batch level statistics based on the prediction result data; receiving, by a weight decision network, the sample level statistics and the batch level statistics, and calculating a real-time weight value of each sample; updating the parameters of the main classification model by using a weighted loss function and a back propagation algorithm; and updating the parameters of the weight decision network according to update feedback from the main classification model. The method can thereby continuously optimize feature learning for minority-class samples and difficult samples, markedly improving classification balance among the categories while maintaining high overall accuracy.
Inventors
- Leng Zhengyan
- Jiang Yuhan
Assignees
- 南京农业大学
Dates
- Publication Date: 2026-05-12
- Application Date: 2026-01-29
Claims (10)
- 1. The sample weight real-time self-adaptive statistical adjustment method for the classification model is characterized by comprising the following steps: S1, acquiring sample data of a current training batch, inputting the sample data into a main classification model, and processing the sample data by the main classification model to generate prediction result data; S2, receiving, by a real-time statistical tracking module, the prediction result data and the sample data, and updating sample level statistics and batch level statistics based on the prediction result data; S3, receiving, by a weight decision network, the sample level statistics and the batch level statistics, and calculating a real-time weight value of each sample; S4, updating the parameters of the main classification model by using a weighted loss function, the update being performed through a back propagation algorithm; and S5, updating the parameters of the weight decision network according to the update feedback of the main classification model.
- 2. The method for real-time adaptive statistical adjustment of sample weights for classification models of claim 1, wherein in S2, the process of updating the sample level statistics and the batch level statistics by the real-time statistical tracking module comprises: S201, calculating a current loss value of each sample by adopting a cross entropy loss function, wherein the current loss value is obtained by comparing the prediction result data with the sample data; S202, updating a historical loss value sequence of each sample by adding the current loss value to the historical loss value sequence of the corresponding sample while maintaining a sequence window of fixed length; S203, calculating a prediction consistency index and a prediction uncertainty index of each sample, wherein the prediction consistency index is obtained by calculating the cosine similarity between the prediction probability distribution of the current iteration and the prediction probability distribution of the previous iteration; S204, calculating the sample number distribution of each category in the current batch, wherein the sample number distribution is obtained by counting and normalizing the number of samples of each category, calculating an average loss value of the current batch, and updating the batch level statistics; and S205, storing the updated sample level statistics and the updated batch level statistics in memory to provide real-time data input for S3.
- 3. The method for real-time adaptive statistical adjustment of sample weights for classification models according to claim 2, wherein in S3, the process of calculating real-time weight values by the weight decision network comprises: S301, constructing a multidimensional input feature vector comprising sample level features and batch level features; S302, inputting the multidimensional input feature vector into the weight decision network, wherein the weight decision network comprises an input layer, a plurality of hidden layers and an output layer, the hidden layers perform nonlinear transformation by using an activation function, and the output layer uses a Sigmoid function to map the output value into a bounded interval (the open interval (0, 1) for a standard Sigmoid), obtaining an initial weight value of each sample in the interval; S303, normalizing the initial weight values by dividing the initial weight value of each sample by the sum of the initial weight values of all samples and multiplying the quotient by the batch size; and S304, outputting the normalized real-time weight values to S4 for calculating the weighted loss function.
- 4. The method for real-time adaptive statistical adjustment of sample weights for classification models according to claim 3, wherein in S4, updating the parameters of the main classification model with the weighted loss function comprises: S401, calculating a loss value of each sample by adopting a cross entropy loss function, wherein the loss value is obtained by comparing the prediction result data with the sample data; S402, weighting the loss value of each sample by its real-time weight value to obtain a weighted loss value; S403, summing the weighted loss values of all the samples to obtain a weighted total loss; S404, calculating the gradient of the weighted total loss with respect to the main classification model parameters through a back propagation algorithm, and updating the main classification model parameters by using an optimization algorithm; and S405, transmitting the weighted total loss and the parameter gradient as update feedback to S5 for updating the parameters of the weight decision network.
- 5. The method for real-time adaptive statistical adjustment of sample weights for classification models as set forth in claim 4, wherein in S5, said updating the weight decision network parameters comprises: S501, receiving from S4 the update feedback comprising the change rate of the weighted total loss and the parameter gradient of the main classification model; S502, calculating a loss function of the weight decision network, wherein the loss function is based on the change rate of the weighted total loss and the norm of the parameter gradient of the main classification model, and takes as its loss value the product of the negative of the change rate of the weighted total loss and the parameter gradient norm; S503, calculating the gradient of the loss function with respect to the weight decision network parameters through a back propagation algorithm, and applying a conditional gradient stopping mechanism in the back propagation process, wherein the conditional gradient stopping mechanism dynamically decides, according to the smoothness of the performance improvement of the main classification model, whether to allow the gradient to continue back propagating to part of the bottom network layers of the main classification model; S504, updating the weight decision network parameters by adopting a stochastic gradient descent optimization algorithm with a smaller learning rate; and S505, storing the updated weight decision network parameters for weight calculation in the next training iteration.
- 6. The method for real-time adaptive statistical adjustment of sample weights for classification models of claim 1, wherein said adaptive target weight vector corresponds at least to a class balance target, a difficult sample mining target, and a noise robustness target; for the class balance target, the weight decision network dynamically adjusts weights based on the sample number distribution in the batch level statistics; for the difficult sample mining target, the weight decision network identifies samples with persistently high loss and high prediction consistency as difficult samples based on the historical loss value sequence and the prediction consistency index in the sample level statistics, and increases their weight values to strengthen classification boundary learning; for the noise robustness target, the weight decision network identifies samples with low prediction consistency and large loss fluctuation as noise samples based on the prediction consistency index in the sample level statistics and the fluctuation of the historical loss value sequence; and the multi-target weight optimization strategy learns the balance weights among the different targets through the hidden layers of the weight decision network, realizing self-adaptive weight distribution.
- 7. The sample weight real-time self-adaptive statistical adjustment method for the classification model according to claim 1, wherein the real-time statistical tracking module, the weight decision network and the main classification model adopt an alternate collaborative training mechanism, in which the parameters of the weight decision network are kept fixed while the main classification model is updated, and the parameters of the main classification model are kept fixed while the weight decision network is updated; the alternate collaborative training is realized through a double time scale updating rule, under which the main classification model adopts a conventional updating rhythm and the weight decision network is updated with a smaller learning rate.
- 8. The method for real-time adaptive statistical adjustment of sample weights for classification models of claim 1, wherein the update frequency of the sample level statistics and the batch level statistics is synchronized with training iterations, wherein the sample level statistics are updated after each training iteration using the prediction result and the sample data of the current batch, the batch level statistics are updated after each training iteration based on all sample data of the current batch, and the real-time statistical tracking module maintains historical data through a sliding window mechanism.
- 9. The method for real-time adaptive statistical adjustment of sample weights for classification models according to claim 1, wherein the weight decision network is a lightweight neural network and comprises an input layer, two hidden layers and an output layer, wherein the number of nodes of the input layer is equal to the dimension of a multidimensional input feature vector, the number of nodes of the hidden layers is 64 and 32 respectively, a ReLU activation function is used, the number of nodes of the output layer is 1, and a Sigmoid activation function is used.
- 10. The method for real-time adaptive statistical adjustment of sample weights for classification models according to claim 1, wherein the real-time statistical tracking module updates statistics using incremental computation, computation of the weight decision network is performed in parallel with forward propagation of the main classification model, and the weighted loss function computation employs vectorization operation, and GPU acceleration processing is utilized.
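The network shape in claim 9 and the normalization in step S303 of claim 3 can be illustrated with a short standard-library Python sketch. This is only one possible reading: the class name `WeightDecisionNet`, the helper names, the uniform initialization and the `seed` parameter are assumptions made for illustration and do not come from the source, and the training of the network (S5) is not shown.

```python
import math
import random


def relu(x):
    return x if x > 0.0 else 0.0


def sigmoid(x):
    # Numerically stable logistic function; output lies in (0, 1).
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def init_layer(n_in, n_out, rng):
    # Assumed initialization: small uniform weights, zero biases.
    scale = 1.0 / math.sqrt(n_in)
    weights = [[rng.uniform(-scale, scale) for _ in range(n_in)]
               for _ in range(n_out)]
    return weights, [0.0] * n_out


def dense(layer, x):
    weights, biases = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, biases)]


class WeightDecisionNet:
    """Input layer -> 64 ReLU -> 32 ReLU -> 1 Sigmoid, as in claim 9."""

    def __init__(self, n_features, seed=0):
        rng = random.Random(seed)
        self.h1 = init_layer(n_features, 64, rng)
        self.h2 = init_layer(64, 32, rng)
        self.out = init_layer(32, 1, rng)

    def raw_weight(self, features):
        # S302: forward pass producing an initial weight for one sample.
        a1 = [relu(v) for v in dense(self.h1, features)]
        a2 = [relu(v) for v in dense(self.h2, a1)]
        return sigmoid(dense(self.out, a2)[0])


def normalize_weights(raw):
    # S303: divide each weight by the batch sum, then multiply by batch size,
    # so the normalized weights sum to the batch size.
    total = sum(raw)
    return [r / total * len(raw) for r in raw]


if __name__ == "__main__":
    net = WeightDecisionNet(n_features=4, seed=1)
    batch = [[0.1, 0.9, 0.5, 0.2], [1.2, 0.1, 0.3, 0.7]]
    print(normalize_weights([net.raw_weight(f) for f in batch]))
```

Because the normalized weights sum to the batch size, the weighted total loss stays on the same scale as an unweighted sum over the batch, whatever the raw Sigmoid outputs are.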
Description
Sample weight real-time self-adaptive statistical adjustment method for classification model
Technical Field
The invention relates to a sample weight adjusting method, in particular to a sample weight real-time self-adaptive statistical adjustment method for a classification model.
Background
In machine learning classification tasks, the assignment of sample weights is a key factor affecting model performance. An ideal sample weighting strategy guides the model to focus on the more informative samples, thereby improving generalization capability. However, data in practice is often accompanied by complex problems such as class imbalance, labeling noise and dynamic changes in data distribution, which make it difficult for conventional classification algorithms, which assume that all samples contribute equally, to achieve ideal results.
Disclosure of Invention
The invention overcomes the defects of the prior art and provides a sample weight real-time self-adaptive statistical adjustment method for a classification model.
In order to achieve this purpose, the technical scheme adopted by the invention is a sample weight real-time self-adaptive statistical adjustment method for a classification model, comprising the following steps: S1, acquiring sample data of a current training batch, inputting the sample data into a main classification model, and processing the sample data by the main classification model to generate prediction result data; S2, receiving, by a real-time statistical tracking module, the prediction result data and the sample data, and updating sample level statistics and batch level statistics based on the prediction result data; S3, receiving, by a weight decision network, the sample level statistics and the batch level statistics, and calculating a real-time weight value of each sample; S4, updating the parameters of the main classification model by using a weighted loss function, the update being performed through a back propagation algorithm; and S5, updating the parameters of the weight decision network according to the update feedback of the main classification model.
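The flow of S1 and S4 can be sketched for a very small case in standard-library Python. This is a sketch under stated assumptions rather than the patented implementation: the main classification model is replaced by a linear softmax classifier, the statistics module and weight decision network of S2 and S3 are replaced by the hypothetical stand-in `placeholder_weights` (weights proportional to the current loss), the optimizer is plain gradient descent with an assumed learning rate, and S5 is omitted entirely.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward(W, x):
    # Stand-in main model: W has shape [num_classes][num_features].
    return softmax([sum(w * xi for w, xi in zip(row, x)) for row in W])


def placeholder_weights(losses):
    # Hypothetical stand-in for S2-S3: weight proportional to loss,
    # scaled so that the weights sum to the batch size.
    total = sum(losses)
    n = len(losses)
    if total == 0.0:
        return [1.0] * n
    return [loss / total * n for loss in losses]


def train_step(W, xs, ys, lr=0.1):
    # S1: forward pass and per-sample cross entropy loss.
    probs = [forward(W, x) for x in xs]
    losses = [-math.log(max(p[y], 1e-12)) for p, y in zip(probs, ys)]
    # S2-S3 (stand-in): per-sample weights, treated as constants below.
    weights = placeholder_weights(losses)
    # S4: weighted total loss and its gradient with respect to W.
    weighted_total = sum(w * loss for w, loss in zip(weights, losses))
    grad = [[0.0] * len(W[0]) for _ in W]
    for w, p, y, x in zip(weights, probs, ys, xs):
        for c in range(len(W)):
            # d(cross entropy)/d(logit_c) = p_c - 1[c == y]
            g = w * (p[c] - (1.0 if c == y else 0.0))
            for j, xj in enumerate(x):
                grad[c][j] += g * xj
    new_W = [[wcj - lr * gcj for wcj, gcj in zip(w_row, g_row)]
             for w_row, g_row in zip(W, grad)]
    return new_W, weighted_total


if __name__ == "__main__":
    W = [[0.0, 0.0], [0.0, 0.0]]
    xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.2], [0.1, 1.0]]
    ys = [0, 1, 0, 1]
    for step in range(3):
        W, total = train_step(W, xs, ys)
        print(step, round(total, 4))
```

Treating the weights as constants during the model update mirrors the idea that the weight decision network is held fixed while the main classification model is updated.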
In a preferred embodiment of the present invention, in S2, the process of updating the sample level statistics and the batch level statistics by the real-time statistical tracking module includes: S201, calculating a current loss value of each sample by adopting a cross entropy loss function, wherein the current loss value is obtained by comparing the prediction result data with the sample data; S202, updating a historical loss value sequence of each sample by adding the current loss value to the historical loss value sequence of the corresponding sample while maintaining a sequence window of fixed length; S203, calculating a prediction consistency index and a prediction uncertainty index of each sample, wherein the prediction consistency index is obtained by calculating the cosine similarity between the prediction probability distribution of the current iteration and the prediction probability distribution of the previous iteration; S204, calculating the sample number distribution of each category in the current batch, wherein the sample number distribution is obtained by counting and normalizing the number of samples of each category, calculating an average loss value of the current batch, and updating the batch level statistics; and S205, storing the updated sample level statistics and the updated batch level statistics in memory to provide real-time data input for S3.
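Steps S201 to S205 can be illustrated with the following standard-library Python sketch. Several details are assumptions not stated in the source: the class name `StatsTracker` and its fields are hypothetical, the prediction uncertainty index is taken to be the entropy of the predicted distribution, "previous iteration" is read as the last time the same sample was seen, and a sample seen for the first time is given a consistency of 1.0.

```python
import math
from collections import Counter, deque


def cross_entropy(probs, label, eps=1e-12):
    return -math.log(max(probs[label], eps))


def cosine_similarity(p, q):
    dot = sum(a * b for a, b in zip(p, q))
    norm = math.sqrt(sum(a * a for a in p)) * math.sqrt(sum(b * b for b in q))
    return dot / norm if norm > 0.0 else 0.0


def entropy(probs, eps=1e-12):
    # Assumed uncertainty index: Shannon entropy of the prediction.
    return -sum(p * math.log(max(p, eps)) for p in probs)


class StatsTracker:
    def __init__(self, window=5):
        self.window = window
        self.loss_history = {}  # sample id -> fixed-length loss window
        self.prev_probs = {}    # sample id -> last predicted distribution
        self.batch_stats = {}

    def update(self, sample_ids, probs_batch, labels, num_classes):
        sample_stats, losses = {}, []
        for sid, probs, label in zip(sample_ids, probs_batch, labels):
            loss = cross_entropy(probs, label)                  # S201
            hist = self.loss_history.setdefault(
                sid, deque(maxlen=self.window))                 # S202
            hist.append(loss)
            prev = self.prev_probs.get(sid)                     # S203
            consistency = (cosine_similarity(probs, prev)
                           if prev is not None else 1.0)
            self.prev_probs[sid] = list(probs)
            sample_stats[sid] = {
                "loss_history": list(hist),
                "consistency": consistency,
                "uncertainty": entropy(probs),
            }
            losses.append(loss)
        counts = Counter(labels)                                # S204
        self.batch_stats = {
            "class_distribution": [counts.get(c, 0) / len(labels)
                                   for c in range(num_classes)],
            "mean_loss": sum(losses) / len(losses),
        }
        # S205: both levels are kept in memory and returned for S3.
        return sample_stats, self.batch_stats


if __name__ == "__main__":
    tracker = StatsTracker(window=3)
    stats, batch = tracker.update(
        ["a", "b"], [[0.7, 0.3], [0.2, 0.8]], [0, 1], num_classes=2)
    print(stats["a"], batch)
```

A `deque` with `maxlen` discards the oldest loss automatically, which is one simple way to keep the fixed-length sequence window of S202.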
In a preferred embodiment of the present invention, in S3, the process of calculating the real-time weight value by the weight decision network includes: S301, constructing a multidimensional input feature vector comprising sample level features and batch level features; S302, inputting the multidimensional input feature vector into the weight decision network, wherein the weight decision network comprises an input layer, a plurality of hidden layers and an output layer, the hidden layers perform nonlinear transformation by using an activation function, and the output layer uses a Sigmoid function to map the output value into a bounded interval (the open interval (0, 1) for a standard Sigmoid), obtaining an initial weight value of each sample in the interval; S303, normalizing the initial weight values by dividing the initial weight value of each sample by the sum of the initial weight values of all samples and multiplying the quotient by the batch size; and S304, outputting the normalized real-time weight values to S4 for calculating the weighted loss function.
In a preferred embodiment of the present invention, in S4, the process of updating the parameters of the main classification model using the weighted loss function includes: S401, calculating a loss value of each sample by adopting a cross entropy loss function, wherein the loss value is obtained by comparing the prediction result data with the sample data; S402, weighting the l