
CN-122021807-A - Federal distillation method based on self-adaptive bidirectional distillation and soft weighted polymerization


Abstract

The invention relates to the technical field of federated learning and provides a federated distillation method based on adaptive bidirectional distillation and soft weighted aggregation. The method comprises the following steps: a server transmits a global model and each client's selection state identifier to the clients; each client builds a local teacher-student network, dynamically adjusts the distillation loss weights of the feature layer and the logit layer according to the received selection state identifier, performs adaptive bidirectional knowledge distillation training, dynamically adjusts the information retention rate to compress its local model parameters, and uploads a compressed parameter packet and a training loss score to the server; the server then calculates soft aggregation weights from the scores using a negative exponential function, restores the local model parameters uploaded by each client, and computes their weighted sum with the soft aggregation weights to generate a new-generation global model. The invention significantly reduces communication cost, improves model convergence accuracy under non-independent and identically distributed (non-IID) data, and enhances the robustness of the system against poisoning attacks by malicious nodes.
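The client-side training objective summarized in this abstract (and detailed in claims 3, 4 and 6) can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the patent's implementation: the rule that an auxiliary-state client multiplies both base weights by the attenuation coefficient `gamma`, the default values of `gamma`, `tau` and the base weights, and the helper names `softmax`, `adaptive_weights`, `per_sample_loss` and `training_loss_score` are all hypothetical. The sketch only evaluates the comprehensive loss (cross-entropy plus KL logit distillation plus MSE feature alignment) and the training loss score; gradient updates of the student and the reverse fine-tuning of the teacher are not shown.

```python
import numpy as np

def softmax(z, tau=1.0):
    """Row-wise softmax of logits z (shape [batch, C]) at temperature tau."""
    s = z / tau
    s = s - s.max(axis=1, keepdims=True)  # shift by row max for numerical stability
    e = np.exp(s)
    return e / e.sum(axis=1, keepdims=True)

def adaptive_weights(is_core, lam_logit0, lam_feat0, gamma):
    """Hypothetical rule: core clients keep the base weights, auxiliary
    clients have both weights attenuated by gamma (0 < gamma < 1)."""
    scale = 1.0 if is_core else gamma
    return scale * lam_logit0, scale * lam_feat0

def per_sample_loss(z_s, z_t, h_s, h_t, y, is_core,
                    lam_logit0=1.0, lam_feat0=1.0, gamma=0.5, tau=2.0):
    """Per-sample loss: CE + lam_logit * KL(teacher || student) + lam_feat * MSE.

    z_s, z_t: student / teacher logits, shape [batch, C]
    h_s, h_t: student / teacher intermediate features, shape [batch, d]
    y: integer class labels, shape [batch]
    """
    eps = 1e-12
    num_classes = z_s.shape[1]
    y_onehot = np.eye(num_classes)[y]                 # one-hot encoding of true labels
    ce = -np.sum(y_onehot * np.log(softmax(z_s) + eps), axis=1)
    p_t, p_s = softmax(z_t, tau), softmax(z_s, tau)   # temperature-softened distributions
    kd = np.sum(p_t * (np.log(p_t + eps) - np.log(p_s + eps)), axis=1)
    feat = np.mean((h_s - h_t) ** 2, axis=1)          # mean squared error per sample
    lam_logit, lam_feat = adaptive_weights(is_core, lam_logit0, lam_feat0, gamma)
    return ce + lam_logit * kd + lam_feat * feat

def training_loss_score(per_sample):
    """Training loss score: arithmetic mean of the per-sample comprehensive loss."""
    return float(np.mean(per_sample))

# Example batch: 8 samples, 10 classes, 16-dimensional intermediate features.
rng = np.random.default_rng(0)
z_s, z_t = rng.normal(size=(8, 10)), rng.normal(size=(8, 10))
h_s, h_t = rng.normal(size=(8, 16)), rng.normal(size=(8, 16))
y = rng.integers(0, 10, size=8)
core = training_loss_score(per_sample_loss(z_s, z_t, h_s, h_t, y, is_core=True))
aux = training_loss_score(per_sample_loss(z_s, z_t, h_s, h_t, y, is_core=False))
```

Because the KL and MSE terms are non-negative, with `gamma < 1` an auxiliary-state client's loss on a batch is never larger than a core-state client's loss on the same batch; only the distillation terms are attenuated, while the cross-entropy term is unchanged.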

Inventors

  • FAN LULU
  • SHAN FANGFANG
  • ZHANG SHUQIN
  • LIU YUHANG
  • CHEN ZHUO
  • MAO YIFAN
  • WANG JIAJIE
  • ZHANG XIAOYU
  • HAN YABO

Assignees

  • 中原工学院 (Zhongyuan University of Technology)

Dates

Publication Date
2026-05-12
Application Date
2026-01-28

Claims (10)

  1. A federated distillation method based on adaptive bidirectional distillation and soft weighted aggregation, characterized by comprising the following steps: S1, starting federated learning, with the server transmitting a global model and each client's selection state identifier to the clients; S2, each client builds a local teacher-student network, dynamically adjusts the distillation loss weights of the feature layer and the logit layer according to the received selection state identifier, performs adaptive bidirectional knowledge distillation training, dynamically adjusts the information retention rate according to the training round to compress its local model parameters, and uploads a compressed parameter packet and a training loss score to the server; S3, the server collects the training loss scores uploaded by all clients and calculates score-based soft aggregation weights using a negative exponential function; S4, the server restores the local model parameters uploaded by each client and computes their weighted sum with the soft aggregation weights to generate a new-generation global model; and S5, steps S1 to S4 are repeated until the given maximum number of communication rounds is reached.
  2. The federated distillation method according to claim 1, wherein the selection state identifier is generated as follows: based on a client screening rate predetermined by the server, the server marks each client participating in the t-th training round as being in either a core state or an auxiliary state, yielding a selection state identifier $s_k^t$ for the k-th client, where one value of $s_k^t$ denotes the selected core state and the other denotes the unselected auxiliary state; in the initial training round, the server uniformly initializes the selection state identifiers of all participating clients to a common initial value; the server transmits the global model parameters, dynamically compressed by singular value decomposition, together with the selection state identifiers; the global model provides high-intensity guidance, through its logit-layer outputs and intermediate feature information, to clients in the core state, while for clients in the auxiliary state the guidance intensity is reduced by an attenuation coefficient.
  3. The federated distillation method based on adaptive bidirectional distillation and soft weighted aggregation according to claim 1 or 2, wherein the client constructs the local teacher-student network as follows: the client obtains the global model parameters issued by the server and uses them as the local teacher network, and uses its original or initialized local model parameters as the student network; during training, a bidirectional distillation strategy is adopted in which the student network imitates the prediction logic and intermediate feature representations of the teacher network while the local teacher network is in turn fine-tuned, so that the student network better adapts to the distribution characteristics of the local data; the adaptive bidirectional knowledge distillation training is performed as follows: the client feeds a local data batch into the local student network $f_S$ and the parameter-frozen global teacher network $f_T$, obtaining the student network's logit-layer output $z^S$ and intermediate feature map $h^S$ and the teacher network's logit-layer output $z^T$ and intermediate feature map $h^T$; according to the selection state identifier $s_k^t$ issued by the server for the t-th training round, the client determines whether it is currently in the core state or the auxiliary state and calculates the logit-layer distillation loss weight $\lambda_{logit}^t$ and the feature-layer distillation loss weight $\lambda_{feat}^t$ for the current round; it then constructs a comprehensive loss function $\mathcal{L}$ comprising a cross-entropy classification loss, a KL-divergence-based logit-layer distillation loss, and a mean-squared-error-based feature-layer alignment loss; finally, it minimizes $\mathcal{L}$ by gradient descent to update the local model parameters.
  4. The federated distillation method based on adaptive bidirectional distillation and soft weighted aggregation according to claim 3, wherein the retention rate $\rho_t$ for the t-th training round is given by a schedule over the training rounds defined in terms of the maximum number of rounds $T$, the initial retention rate $\rho_{start}$, and the final retention rate $\rho_{end}$; and the training loss score is calculated as follows: the client traverses its private local dataset $D_k$, calculates the comprehensive loss function value of every local sample under the current model, and defines the arithmetic mean of these values as the training loss score $S_k^t$ for the round.
  5. The federated distillation method according to claim 4, wherein the compressed parameter packet is generated as follows: the original model parameter matrix is $W = U \Sigma V^{\top} \in \mathbb{R}^{m \times n}$; the truncated rank determined by the energy threshold is $r = \min\{ a \in \{1, \ldots, K\} : \sum_{i=1}^{a} E_i / \sum_{i=1}^{K} E_i \ge \rho_t \}$ with $E_i = \sigma_i^2$; and the finally generated compressed parameter packet is $P = \{ U_{:,1:r}, \; \sigma_{1:r}, \; V^{\top}_{1:r,:} \}$; where $U$ is the left singular vector matrix obtained by SVD of the model parameter matrix $W$, $U_{:,1:r}$ is the submatrix formed by all rows and the first $r$ columns of $U$, $\sigma_{1:r}$ is the vector of the first $r$ singular values in descending order, and $V^{\top}_{1:r,:}$ is the submatrix formed by the first $r$ rows and all columns of the transposed right singular vector matrix obtained by SVD of $W$; $\Sigma = \mathrm{diag}(\sigma_1, \ldots, \sigma_K)$ is the diagonal matrix of singular values in descending order, the total number of singular values being $K = \min(m, n)$, where $m$ and $n$ are the numbers of rows and columns of $W$; $a$ is a candidate singular value index and $r$ is the minimum index at which the cumulative energy ratio is not lower than the target information retention rate $\rho_t$; $E_i$ is the energy of the $i$-th singular value; and $^{\top}$ denotes matrix transposition.
  6. The federated distillation method according to claim 5, wherein dynamically adjusting the distillation loss weights of the feature layer and the logit layer according to the received selection state identifier comprises adaptively adjusting the local logit-layer base weight $\lambda_{logit}^{0}$ and the feature-layer base weight $\lambda_{feat}^{0}$ to obtain the logit-layer distillation loss weight $\lambda_{logit}^{t}$ and the feature-layer distillation loss weight $\lambda_{feat}^{t}$ for round $t$, where for a client in the auxiliary state the weights are reduced by a preset attenuation coefficient $\gamma$; the comprehensive loss function $\mathcal{L}$ combines the cross-entropy classification loss $\mathcal{L}_{CE}$, the logit-layer distillation loss $\mathcal{L}_{KD}$ computed at the distillation temperature coefficient $\tau$, and the feature-layer alignment loss $\mathcal{L}_{feat}$, weighted by $\lambda_{logit}^{t}$ and $\lambda_{feat}^{t}$ respectively; the cross-entropy classification loss is $\mathcal{L}_{CE} = -\sum_{c=1}^{C} y_c \log \mathrm{Softmax}(z^S)_c$, where $y$ is the one-hot encoding of the true label, $\mathrm{Softmax}$ is the Softmax activation function, $C$ is the total number of classes, and $z^S_c$ is the raw (pre-Softmax) confidence score output on the $c$-th class after the input data is propagated forward through the local student network; the logit-layer distillation loss is $\mathcal{L}_{KD} = \mathrm{KL}\big(\mathrm{Softmax}(z^T/\tau) \,\|\, \mathrm{Softmax}(z^S/\tau)\big)$, where $z^T_c$ is the raw (pre-Softmax) confidence score output by the global teacher network on class $c$ for the input data; the feature-layer alignment loss is $\mathcal{L}_{feat} = \mathrm{MSE}(h^S, h^T)$, where $h^S$ is the intermediate feature map of the local student network and $h^T$ is the intermediate feature map of the global teacher network; the current model is the local student network obtained after the client completes local adaptive bidirectional distillation training by optimizing the comprehensive loss function, together with the teacher network that has been updated by gradient-descent back-propagation and incorporates both local data features and global knowledge; the training loss score is $S_k^t = \frac{1}{n_k} \sum_{i1=1}^{n_k} \mathcal{L}(x_{i1}, y_{i1})$, where $n_k$ is the number of samples in the local dataset $D_k$ of the $k$-th client, $x_{i1}$ is the $i1$-th input sample in the local dataset, $y_{i1}$ is its true label, and $\mathcal{L}(x_{i1}, y_{i1})$ is the comprehensive loss function value of the $i1$-th input sample; and after completing the preset number of local training rounds, the client uploads its updated local model parameters to the server, namely the student network parameters obtained after multiple iterations of minimizing the comprehensive loss function by gradient descent.
  7. The federated distillation method according to any one of claims 4 to 6, wherein the score-based soft aggregation weights are calculated with a negative exponential function as follows: the server collects the training loss scores $S_i^t$ uploaded by all participating clients in the current round and sums the negative exponential terms $e^{-S_i^t}$ of all training loss scores to obtain the negative exponential cumulative sum $Z^t$; using $Z^t$ as the normalization factor, the server calculates the soft aggregation weight of the $i$-th client; a numerical stability protection mechanism is introduced, in which the server checks whether $Z^t$ is below a preset minimum threshold $\epsilon$ to detect a numerical anomaly; if the value is anomalous, a fallback strategy is triggered automatically and the weights are assigned by plain average aggregation; if the value is normal, each client's negative exponential term is divided by the denominator $Z^t$ to obtain normalized weights, following standard Softmax logic.
  8. The federated distillation method based on adaptive bidirectional distillation and soft weighted aggregation according to claim 7, wherein the soft aggregation weight of the $i$-th client is $w_i^t = 1/N$ if $Z^t < \epsilon$, and $w_i^t = e^{-S_i^t} / Z^t$ otherwise, where $S_i^t$ is the local training loss score uploaded by the $i$-th client, $N$ is the total number of participating clients, $\epsilon$ is a preset numerical stability threshold, and $Z^t$ is the negative exponential cumulative sum of the scores of all participating clients in the current training round, $Z^t = \sum_{j1 \in \mathcal{K}_t} e^{-S_{j1}^t}$, where $\mathcal{K}_t$ is the set of clients participating in aggregation in the current round and $S_{j1}^t$ is the training loss score of the $j1$-th client in $\mathcal{K}_t$.
  9. The federated distillation method according to any one of claims 4 to 6 and 8, wherein the local model parameters uploaded by each client are restored and weighted-summed as follows: the client packages the compressed parameter packet, processed by singular value decomposition and energy truncation, together with its local training loss score, and uploads them to the server; after receiving the packets, the server applies negative-exponential Softmax normalization to the scores uploaded by all participating clients to calculate each client's soft aggregation weight; that is, after each client completes local training and generates its compressed parameter packet and training loss score, the server receives the packets uploaded by all clients participating in aggregation, extracts their training loss scores, calculates each client's soft aggregation weight through the negative exponential function and Softmax normalization, restores each client's compressed parameters, and performs a layer-by-layer weighted summation of all clients' reconstructed parameter matrices using the soft aggregation weights to generate a new-generation global model; when traversing the local model parameters uploaded by each client, the server reconstructs each compressed layer from its singular value triple into a parameter matrix of the original dimensions by matrix multiplication, and keeps each uncompressed layer in its original state.
  10. The federated distillation method based on adaptive bidirectional distillation and soft weighted aggregation according to claim 9, wherein the compressed data packet uploaded by the $k$-th client is $P_k^t = \{ S_k^t, \{ (U_{k,l}^t, \Sigma_{k,l}^t, V_{k,l}^t) \}_{l \in \mathcal{M}} \}$, where $S_k^t$ is the training loss score of the $k$-th client; $U_{k,l}^t$ is the left singular vector submatrix obtained after the layer-$l$ parameter matrix of the $k$-th client is decomposed by SVD and truncated to rank $r$; $\Sigma_{k,l}^t$ is the singular value diagonal matrix of the layer-$l$ parameter matrix truncated to rank $r$; $V_{k,l}^t$ is the right singular vector submatrix of the layer-$l$ parameter matrix truncated to rank $r$; and $\mathcal{M}$ is the set of all layers of the client's student network that require SVD compression and uploading; the layer-$l$ parameter matrix of the $k$-th client for the $t$-th training round, as restored by the server, is $\hat{W}_{k,l}^t = U_{k,l}^t \Sigma_{k,l}^t (V_{k,l}^t)^{\top}$; and the layer-$l$ parameters of the global model for round $t+1$ generated after aggregation are $W_l^{t+1} = \sum_{k \in \mathcal{K}_t} w_k^t \hat{W}_{k,l}^t$.
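Claims 4, 5 and 7 to 10 together describe a pipeline of compressing, uploading, weighting, restoring and summing. The NumPy sketch below walks through it for illustration only, under several assumptions: the linear schedule in `retention_rate` is a hypothetical stand-in for the retention-rate schedule of claim 4, energy is taken as the squared singular value, the threshold `eps` and all function names are illustrative, and any layer stored as a plain array (rather than a singular value triple) is treated as uncompressed.

```python
import math
import numpy as np

def retention_rate(t, T, rho_start=0.99, rho_end=0.90):
    """Hypothetical linear schedule from rho_start (t = 0) to rho_end (t = T)."""
    return rho_start + (rho_end - rho_start) * min(t, T) / T

def svd_compress(W, rho):
    """Truncate W to the smallest rank r whose cumulative energy ratio reaches rho."""
    U, s, Vt = np.linalg.svd(W, full_matrices=False)  # s is sorted in descending order
    energy = np.cumsum(s ** 2) / np.sum(s ** 2)        # cumulative energy ratio
    r = min(int(np.searchsorted(energy, rho)) + 1, len(s))  # first index with ratio >= rho
    return U[:, :r], s[:r], Vt[:r, :]

def restore(packet):
    """Rebuild the original-shape matrix from a (U_r, s_r, Vt_r) triple."""
    U_r, s_r, Vt_r = packet
    return (U_r * s_r) @ Vt_r  # scaling the columns of U_r equals U_r @ diag(s_r)

def soft_weights(scores, eps=1e-12):
    """Negative-exponential (Softmax-style) weights with an averaging fallback."""
    z = sum(math.exp(-s) for s in scores)
    if z < eps:  # numerical anomaly: fall back to plain averaging
        return [1.0 / len(scores)] * len(scores)
    return [math.exp(-s) / z for s in scores]

def aggregate(uploads):
    """uploads: list of (score, {layer_name: triple or ndarray}); returns new global layers."""
    weights = soft_weights([score for score, _ in uploads])
    global_layers = {}
    for w, (_, layers) in zip(weights, uploads):
        for name, p in layers.items():
            W = restore(p) if isinstance(p, tuple) else p  # uncompressed layers kept as-is
            global_layers[name] = global_layers.get(name, 0.0) + w * W
    return global_layers

# Two clients upload the same rank-1 layer "fc" plus an uncompressed "bias" layer.
rng = np.random.default_rng(0)
W = rng.normal(size=(64, 1)) @ rng.normal(size=(1, 32))
packet = svd_compress(W, retention_rate(t=5, T=10))
uploads = [(0.3, {"fc": packet, "bias": np.ones(32)}),
           (2.5, {"fc": svd_compress(W, 0.95), "bias": np.zeros(32)})]
new_global = aggregate(uploads)
```

A rank-1 layer compresses to a single singular value, so each upload of "fc" carries 64 + 1 + 32 numbers instead of 64 x 32. Subtracting the minimum score before exponentiating (the log-sum-exp shift) would give identical normalized weights without underflow; `soft_weights` instead mirrors the claimed approach of detecting underflow with a threshold and falling back to averaging.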

Description

Federal distillation method based on self-adaptive bidirectional distillation and soft weighted polymerization

Technical Field

The invention relates to the intersection of artificial intelligence and distributed computing, and in particular to a federated distillation learning method based on adaptive bidirectional distillation and soft weighted aggregation, addressing collaborative training optimization of heterogeneous models and poisoning-resistant robust aggregation in federated learning.

Background

In the era of ubiquitous intelligent networking, in which 5G communication, the Internet of Things (IoT) and edge computing are deeply integrated, the number of intelligent terminal devices is growing exponentially, continuously generating high-value data such as images, speech and behavior logs. Traditional machine learning relies on a centralized training paradigm in which data is physically aggregated; this not only subjects backbone networks to intolerable bandwidth and storage pressure but also exposes users' personal privacy, enterprises' trade secrets and national data assets to the risk of leakage from centralized storage. Federated learning, proposed by Google in 2016, uses a distributed collaboration mechanism in which the data stays in place and the model moves: terminal devices train models locally and upload only encrypted parameter updates, which breaks down data silos while significantly reducing the risk of privacy leakage. To cope with the practical constraints of limited computing power and insufficient communication bandwidth on edge devices, federated knowledge distillation has emerged as a lightweight communication and optimization paradigm. Instead of exchanging the full set of model parameters, it exchanges model outputs or intermediate-layer features, with the aim of reducing communication overhead and enabling personalized learning.
However, in practical applications, client data is often severely non-IID, so the local models of different clients converge in different directions during training, producing serious "model drift". Most existing federated distillation schemes adopt static distillation strategies: regardless of how well a client model is currently performing, it is forced to imitate the global teacher model with a fixed weight. This lack of adaptivity is often counterproductive. For "weak clients" with poor data quality, or early in training, an excessively high distillation strength can cause local features to be lost and can even make training diverge, whereas for "dominant clients" holding unique data, an excessively low weight limits their contribution to the global model. In addition, model aggregation, as the core hub through which federated learning coordinates knowledge from multiple parties, directly determines the convergence accuracy and robustness of the global model. The traditional federated averaging algorithm weights clients mainly by their number of samples, a mechanism that assumes all participants are honest and that data quality is balanced. In an open federated network, however, malicious participants may launch poisoning attacks, attempting to destroy the usability of the global model or to implant backdoors by uploading carefully crafted anomalous parameters or label-flipped data. To defend against such attacks, the prior art typically adopts a threshold-based "hard truncation" strategy, directly discarding client updates whose loss values are high or that differ greatly from the global model.
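For contrast with the prior-art mechanisms just described, the plain-Python sketch below compares sample-count weighting (as in federated averaging), a threshold-based hard truncation rule, and negative-exponential soft weighting on the same three clients. The threshold value, the uniform reweighting of surviving clients, the minimum-score shift in `soft_weights`, and the example numbers are illustrative assumptions, not details taken from any cited scheme.

```python
import math

def fedavg_weights(sample_counts):
    """Federated averaging: weight each client by its share of the total samples."""
    total = sum(sample_counts)
    return [n / total for n in sample_counts]

def hard_truncation_weights(losses, threshold):
    """Illustrative hard truncation: drop clients above threshold, average the rest."""
    kept = [loss <= threshold for loss in losses]
    n_kept = sum(kept)
    if n_kept == 0:
        return [0.0] * len(losses)
    return [1.0 / n_kept if k else 0.0 for k in kept]

def soft_weights(losses):
    """Negative-exponential weights: a higher loss gives a smaller but nonzero weight."""
    shift = min(losses)  # subtracting the minimum avoids underflow without changing the result
    terms = [math.exp(-(loss - shift)) for loss in losses]
    total = sum(terms)
    return [t / total for t in terms]

# Three honest clients; the third holds long-tail data and has a somewhat higher loss.
losses = [0.40, 0.60, 1.10]
samples = [500, 300, 200]
print(fedavg_weights(samples))               # ignores loss entirely
print(hard_truncation_weights(losses, 1.0))  # long-tail client is dropped
print(soft_weights(losses))                  # long-tail client keeps a reduced weight
```

The hard rule assigns the long-tail client a weight of exactly zero as soon as its loss crosses the threshold, whereas the soft rule only shrinks its weight smoothly, which is the behavior the background argues is needed to keep knowledge about difficult samples.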
Although this approach filters out some noise, it also crudely discards normal clients in the long tail of the distribution (which often hold key knowledge for recognizing difficult samples), significantly reducing the global model's ability to generalize to minority-class samples. The invention patent with publication number CN120596951A provides a modality-balanced collaborative task enhancement method for federated large and small models, which introduces a large-model regulator that performs bidirectional distillation with local small models and uses server-side data augmentation to alleviate inconsistent data distributions. However, that scheme still has significant limitations. First, it focuses on balancing modality differences through data augmentation and regulator updates and ignores the dynamic changes of a client's "selection state" during training; it lacks a mechanism for adaptively adjusting the logit-layer and feature-layer distillation weights according to each node's contribution (core or auxiliary state), so weak nodes are prone to knowledge forgetting or feature collapse when forcibly aligned. Second, in the aggregation stage the scheme still relies on traditional regulator updates or adapter synchronization and does not introduce a negative-exponential soft weighting strategy based on training loss scores, so that th