CN-115577803-B - Federated learning method and system robust to mixed noise
Abstract
The invention provides a federated learning method robust to mixed noise. A client updates local metric model parameters from its local training data using a subjective logic loss function and sends them to a server; the server computes global metric model parameters from the local metric model parameters updated by this client and the other clients, together with the corresponding sample sizes. Using the global metric model parameters and the local training data, the client computes the subjective logic loss of each training sample and fits a local Gaussian mixture distribution, sends these to the server, and obtains interval thresholds computed by the server from the local Gaussian mixture distributions of all clients. Based on the interval thresholds and the local Gaussian mixture distribution, the client performs mixed-noise identification on the training data and then filters the identified open-set noise and corrects the identified closed-set noise, respectively. The invention also provides a federated learning system robust to mixed noise and a data processing device for federated learning.
Inventors
- CHEN YIQIANG
- ZENG BIXIAO
- YANG XIAODONG
- YU HANCHAO
Assignees
- Institute of Computing Technology, Chinese Academy of Sciences (中国科学院计算技术研究所)
Dates
- Publication Date: 2026-05-05
- Application Date: 2022-10-12
Claims (8)
- 1. A federated learning method robust to mixed noise, comprising: Step 1, constructing a local metric model at the client, training it for a designated number of rounds on the local training data, and sending the trained local metric model parameters $\theta^n_b$ and the first sample size $m_n$ of the local training data to the server, the local metric model taking the subjective logic function as its loss function, the loss of a local training sample $i$ being $$\mathcal{L}^{n,b}_i=\sum_{k=1}^{K}y_{ik}\left(\log S^n_i-\log\alpha^n_{ik}\right),\qquad \alpha^n_{ik}=e^n_{ik}+1,\qquad S^n_i=\sum_{k=1}^{K}\alpha^n_{ik},$$ where $\alpha^n_{ik}$ are the coefficients of the Dirichlet distribution, $e^n_{ik}$ is the result of applying the activation function ReLU to the output of the local metric model on the feature data of sample $i$ and represents the evidence supporting the $k$-th class label, $y_{ik}$ is the one-hot label indicator of sample $i$ for class $k$, $K$ is the cardinality of the Dirichlet distribution, $n$ is the client serial number, and the subscript $b$ denotes the metric model; Step 2, obtaining global metric model parameters at the server from all the model parameters $\theta^n_b$ and the first sample sizes; Step 3, obtaining the subjective logic loss value of each local training sample at the client from the global metric model parameters and the local training data, and fitting a local Gaussian mixture distribution over all the local training data; Step 4, generating discrete local Gaussian mixture distributions at the server from all the local Gaussian mixture distributions, aggregating them into a global Gaussian mixture distribution, obtaining the KS distance between each discrete local Gaussian mixture distribution and the global Gaussian mixture distribution, and obtaining discrimination thresholds from the KS distances; Step 5, selecting a training data set from the local training data at the client according to the discrimination thresholds; Step 6, training the local classification model on the training data set, performing label correction on the local training data with the local classification model, and sending the local classification model parameters and the corrected second sample size of the local training data to the server; Step 7, obtaining global classification model parameters at the server from all the local classification model parameters and the second sample sizes; Step 8, updating the local classification model at the client to a global classification model with the global classification model parameters; and repeating steps 6 to 8 until the global classification model at the client converges, taking the converged global classification model as the client's final classification model.
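As a non-authoritative illustration, the subjective-logic loss of Step 1 can be sketched in NumPy, assuming the standard evidential formulation the claim describes (evidence is the ReLU of the model output, Dirichlet coefficients are evidence plus one); the function name and one-hot label handling are illustrative, not taken from the patent:

```python
import numpy as np

def subjective_logic_loss(logits, label, num_classes):
    """Subjective-logic loss for one sample: expected cross-entropy
    under the Dirichlet distribution induced by the model's evidence."""
    evidence = np.maximum(logits, 0.0)      # e_k = ReLU(model output)
    alpha = evidence + 1.0                  # Dirichlet coefficients alpha_k
    strength = alpha.sum()                  # S = sum_k alpha_k
    y = np.eye(num_classes)[label]          # one-hot label indicator
    return float((y * (np.log(strength) - np.log(alpha))).sum())
```

Strong evidence for the labeled class yields a small loss, while a mislabeled sample (evidence concentrated elsewhere) yields a large one, which is what makes the per-sample loss usable for noise identification in the later steps.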
- 2. The federated learning method according to claim 1, wherein in step 2 the global metric model parameters $\theta^g_b$ are obtained by the following formula: $$\theta^g_b=\sum_{n=1}^{N}\frac{m_n}{\sum_{n'=1}^{N}m_{n'}}\,\theta^n_b,$$ where $m_n$ represents the amount of local training data at client $n$, $\sum_{n'=1}^{N}m_{n'}$ represents the sum of the first sample sizes of all the clients, $n$ represents the client serial number, and $N$ represents the number of clients.
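The sample-size-weighted aggregation of claim 2 is a FedAvg-style average; a minimal sketch, with illustrative names and parameters passed as plain arrays:

```python
import numpy as np

def aggregate(local_params, sample_counts):
    """Weighted average of per-client parameter arrays, with weight
    m_n / sum(m) for client n, as in the claimed formula."""
    total = float(sum(sample_counts))
    return sum((m / total) * np.asarray(p)
               for p, m in zip(local_params, sample_counts))
```

For example, two clients holding parameters `[1, 1]` and `[3, 3]` with sample counts 1 and 3 aggregate to `[2.5, 2.5]`.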
- 3. The federated learning method according to claim 1, wherein in step 4 the discrimination thresholds include a first discrimination threshold $\lambda_{\mathrm{I}}$ and a second discrimination threshold $\lambda_{\mathrm{II}}$: $\lambda_{\mathrm{I}}$ is the inflection value of the KS distance, used to divide the clean data subset from the open-set noise data subset in the training data, and $\lambda_{\mathrm{II}}$ is the minimum value of the KS distance, used to divide the open-set noise data subset from the closed-set noise data subset, the KS distance of client $n$ being $$D_n=\sup_x\left|F^g(x)-F^n(x)\right|,$$ where $F^g$ is the cumulative distribution function of the discrete global Gaussian mixture distribution and $F^n$ is the cumulative distribution function of the discrete local Gaussian mixture distribution of client $n$.
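The KS distance of claim 3 measures how far a client's loss distribution drifts from the global one. A sketch using empirical CDFs built from loss samples (standing in for the discrete Gaussian mixture distributions, which is an assumption of this illustration):

```python
import numpy as np

def ks_distance(local_losses, global_losses):
    """Largest absolute gap between the two empirical CDFs:
    D_n = max_x |F_g(x) - F_n(x)|."""
    local_sorted = np.sort(local_losses)
    global_sorted = np.sort(global_losses)
    grid = np.concatenate([local_sorted, global_sorted])
    f_local = np.searchsorted(local_sorted, grid, side="right") / len(local_sorted)
    f_global = np.searchsorted(global_sorted, grid, side="right") / len(global_sorted)
    return float(np.abs(f_global - f_local).max())
```

Identical distributions give a distance of 0, fully disjoint ones give 1; the thresholds $\lambda_{\mathrm{I}}$ (inflection value) and $\lambda_{\mathrm{II}}$ (minimum value) would then be read off the set of per-client distances.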
- 4. A federated learning system robust to mixed noise, comprising: a model construction module, arranged at the client, for constructing a local metric model, training it for a designated number of rounds on the local training data, and sending the trained local metric model parameters $\theta^n_b$ and the first sample size $m_n$ of the local training data to the server, the local metric model taking the subjective logic function as its loss function, the loss of a local training sample $i$ being $$\mathcal{L}^{n,b}_i=\sum_{k=1}^{K}y_{ik}\left(\log S^n_i-\log\alpha^n_{ik}\right),\qquad \alpha^n_{ik}=e^n_{ik}+1,\qquad S^n_i=\sum_{k=1}^{K}\alpha^n_{ik},$$ where $\alpha^n_{ik}$ are the coefficients of the Dirichlet distribution, $e^n_{ik}$ is the result of applying the activation function ReLU to the output of the local metric model on the feature data of sample $i$ and represents the evidence supporting the $k$-th class label, $y_{ik}$ is the one-hot label indicator of sample $i$ for class $k$, $K$ is the cardinality of the Dirichlet distribution, $n$ is the client serial number, and the subscript $b$ denotes the metric model; a metric model aggregation module, arranged at the server, for obtaining global metric model parameters from all the model parameters $\theta^n_b$ and the first sample sizes; a local Gaussian mixture distribution generation module, arranged at the client, for obtaining the subjective logic loss value of each local training sample from the global metric model parameters and the local training data, and fitting a local Gaussian mixture distribution over all the local training data; a threshold calculation module, arranged at the server, for generating discrete local Gaussian mixture distributions from all the local Gaussian mixture distributions, aggregating them into a global Gaussian mixture distribution, obtaining the KS distance between the discrete local Gaussian mixture distributions and the global Gaussian mixture distribution, and obtaining discrimination thresholds from the KS distance; a noise identification module, arranged at the client, for selecting a training data set from the local training data according to the discrimination thresholds; a noise correction module, arranged at the client, for training the local classification model on the training data set, performing label correction on the local training data with the local classification model, and sending the local classification model parameters and the corrected second sample size of the local training data to the server; a classification model aggregation module, arranged at the server, for obtaining global classification model parameters from all the local classification model parameters and the second sample sizes; a model updating module, arranged at the client, for updating the local classification model to a global classification model with the global classification model parameters; and an iteration convergence module, arranged at the client, for repeatedly invoking the noise correction module, the classification model aggregation module, and the model updating module in sequence until the global classification model at the client converges, taking the converged global classification model as the client's final classification model.
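The local Gaussian mixture distribution generation module fits a mixture to the per-sample subjective-logic losses. A self-contained two-component EM sketch in NumPy; the component count, initialization, and variance floor are assumptions of this illustration, not details specified by the patent:

```python
import numpy as np

def fit_two_gaussians(losses, iters=50):
    """Fit a two-component 1-D Gaussian mixture to loss values with a
    small EM loop; the low-mean component models cleaner samples."""
    x = np.asarray(losses, dtype=float)
    mu = np.array([x.min(), x.max()])            # spread initial means apart
    sigma = np.array([x.std() + 1e-6] * 2)
    pi = np.array([0.5, 0.5])
    for _ in range(iters):
        # E-step: responsibility of each component for each point
        dens = pi * np.exp(-0.5 * ((x[:, None] - mu) / sigma) ** 2) \
               / (sigma * np.sqrt(2.0 * np.pi))
        resp = dens / dens.sum(axis=1, keepdims=True)
        # M-step: re-estimate weights, means, and standard deviations
        nk = resp.sum(axis=0)
        pi = nk / len(x)
        mu = (resp * x[:, None]).sum(axis=0) / nk
        sigma = np.sqrt((resp * (x[:, None] - mu) ** 2).sum(axis=0) / nk) + 1e-6
    return pi, mu, sigma
```

On losses that cluster around a low and a high value, the fitted means recover the two cluster centers, which is what allows clean and noisy samples to be separated by their loss.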
- 5. The federated learning system according to claim 4, wherein the metric model aggregation module obtains the global metric model parameters $\theta^g_b$ by the following formula: $$\theta^g_b=\sum_{n=1}^{N}\frac{m_n}{\sum_{n'=1}^{N}m_{n'}}\,\theta^n_b,$$ where $m_n$ represents the amount of local training data at client $n$, $\sum_{n'=1}^{N}m_{n'}$ represents the sum of the first sample sizes of all the clients, $n$ represents the client serial number, and $N$ represents the number of clients.
- 6. The federated learning system according to claim 4, wherein the threshold calculation module comprises: a first threshold calculation module for dividing the clean data subset and the open-set noise data subset in the training data, using the inflection value of the KS distance as the first discrimination threshold $\lambda_{\mathrm{I}}$; and a second threshold calculation module for dividing the open-set noise data subset and the closed-set noise data subset in the training data, using the minimum value of the KS distance as the second discrimination threshold $\lambda_{\mathrm{II}}$; the KS distance of client $n$ being $$D_n=\sup_x\left|F^g(x)-F^n(x)\right|,$$ where $F^g$ is the cumulative distribution function of the discrete global Gaussian mixture distribution and $F^n$ is the cumulative distribution function of the discrete local Gaussian mixture distribution of client $n$.
- 7. A computer-readable storage medium storing computer-executable instructions which, when executed, implement the federated learning method robust to mixed noise according to any one of claims 1 to 3.
- 8. A data processing apparatus comprising the computer-readable storage medium of claim 7 and a processor that retrieves and executes the computer-executable instructions in the computer-readable storage medium to perform federated learning robust to mixed noise when the data processing apparatus acts as a client or a server.
Description
Federated learning method and system robust to mixed noise

Technical Field

The invention relates to the technical field of machine learning, and in particular to a federated learning method and a federated learning system robust to mixed noise.

Background

With the development of distributed machine learning and big data analysis, federated learning has been proposed as a novel distributed machine learning framework that supports collaboratively training a model across multiple clients (institutions) while guaranteeing data privacy. During model training, only intermediate parameters are exchanged between the server and the clients, and no client needs to upload any raw data. In practical federated learning scenarios, the participation of many clients brings more knowledge but also increases the risk of label noise, which degrades model performance. Because in federated learning the data is "available but invisible", handling such noise is even more challenging. In real scenarios, label noise largely falls into two categories: open-set noise and closed-set noise. For example, in an AI-assisted pneumonia diagnosis task, a chest CT labeled as pneumonia whose true class does not belong to the classification task is an open-set noise sample, whereas a normal lung CT labeled as pneumonia is a closed-set noise sample, since its true class (normal) is contained in the classification task. In general, open-set noise is suited to methods such as noise filtering, while closed-set noise, whose features remain reusable, is suited to the family of label-inference methods, which reduce information loss. Owing to labeling subjectivity, labeling deviation, and other causes, the noise level and noise composition are heterogeneous across clients, so in federated learning it becomes important to discriminate the noise type for each client and then handle each client accordingly.
Existing techniques for the label noise problem often rely on an assumption about the noise type. For open-set noise, noise-filtering methods such as sample screening and noise down-weighting are usually adopted, because the original features of open-set samples are irrelevant to the classification task; for closed-set noise, noise-correction methods such as label inference and loss-function correction are usually adopted, because the original features of closed-set samples belong to the classification task and have higher reuse value. However, real federated scenarios are more complex: the multi-source heterogeneity of clients means that label noise is very likely to appear as a mixture of open-set and closed-set noise, with the noise level differing from client to client. Existing methods do not provide a federated learning scheme robust to mixed noise, and how to identify mixed noise and handle each kind separately in a federated learning scenario is a key problem urgently needing to be solved.
Disclosure of Invention

In view of the above problems, the present invention proposes a federated learning method robust to mixed noise, comprising: Step 1, constructing a local metric model at the client, training it for a designated number of rounds on the local training data, and sending the trained local metric model parameters $\theta^n_b$ and the first sample size $m_n$ of the local training data to the server, the local metric model taking the subjective logic function as its loss function, the loss of a local training sample $i$ being $$\mathcal{L}^{n,b}_i=\sum_{k=1}^{K}y_{ik}\left(\log S^n_i-\log\alpha^n_{ik}\right),\qquad \alpha^n_{ik}=e^n_{ik}+1,\qquad S^n_i=\sum_{k=1}^{K}\alpha^n_{ik},$$ where $\alpha^n_{ik}$ are the coefficients of the Dirichlet distribution, $e^n_{ik}$ is the result of applying the activation function ReLU to the output of the local metric model on the feature data of sample $i$ and represents the evidence supporting the $k$-th class label, $y_{ik}$ is the one-hot label indicator of sample $i$ for class $k$, $K$ is the cardinality of the Dirichlet distribution, $n$ is the client serial number, and the subscript $b$ denotes the metric model; Step 2, obtaining global metric model parameters at the server from all the model parameters $\theta^n_b$ and the first sample sizes; obtaining the subjective logic loss value of each local training sample at the client from the global metric model parameters and the local training data and fitting a local Gaussian mixture distribution over all the local training data; generating discrete local Gaussian mixture distributions at the server from all the local Gaussian mixture distributions, aggregating them into a global Gaussian mixture distribution, obtaining the KS distance between the discrete local Gaussian mixture distributions and the global Gaussian mixture distribution, and obtaining discrimination thresholds from the KS distance; selecting a training data set from the local training data at the client according to the discrimination thresholds; training the local classification model