CN-122021715-A - Space-time two-channel mask interpretation method and related equipment for dynamic weighted graph neural network

Abstract

The embodiments of the application provide a space-time two-channel mask interpretation method and related equipment for a dynamic weighted graph neural network, belonging to the field of artificial intelligence and machine learning. The method comprises: obtaining a pre-trained dynamic graph neural network model, a dynamic weighted graph sequence, and the model's original prediction for that sequence; initializing a spatial importance mask matrix and a temporal importance mask vector; jointly adjusting the two masks through iterative optimization while the model parameters are held fixed, wherein each iteration generates a gating signal from the masks, modulates the original edge weights to construct a perturbed graph, obtains a perturbed prediction through a forward pass of the model, and updates the masks by back-propagation so as to minimize a total loss function that combines a fidelity loss, a sparsity loss and a temporal smoothness loss; and generating an importance interpretation tensor fusing space-time dimensions based on the optimized masks. Through collaborative learning of decoupled space-time masks and weight-aware perturbation, the method achieves fine-grained, high-fidelity and smooth visual interpretation of dynamic weighted graph decisions.
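The optimization loop summarized above can be sketched in a few lines of dependency-free Python. This is only an illustration under stated assumptions, not the patented implementation: the gate is assumed to be the product of the sigmoid-activated spatial and temporal mask entries, the fidelity loss is assumed to be a squared error, the sparsity loss is taken over the activated mask values, and the smoothness loss is assumed to be the sum of squared differences between adjacent time steps. The frozen `model` is a hypothetical linear readout standing in for a pre-trained dynamic GNN, and the names `W`, `V`, `explain`, `lam_sparse` and `lam_smooth` are invented for the example. Gradients are approximated by central finite differences to avoid a dependency; a real implementation would presumably use automatic differentiation.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

T, N = 3, 3  # number of time steps, number of nodes

# Toy dynamic weighted graph: A[t][i][j] is the connection strength at step t.
A = [
    [[0.0, 0.9, 0.1], [0.9, 0.0, 0.2], [0.1, 0.2, 0.0]],
    [[0.0, 0.8, 0.1], [0.8, 0.0, 0.3], [0.1, 0.3, 0.0]],
    [[0.0, 0.1, 0.1], [0.1, 0.0, 0.2], [0.1, 0.2, 0.0]],
]

# Hypothetical frozen model: a fixed linear readout standing in for a
# pre-trained dynamic GNN. W and V are never updated.
W = [[0.0, 1.0, 0.1], [1.0, 0.0, -0.2], [0.1, -0.2, 0.0]]
V = [1.0, 0.5, 0.2]

def model(seq):
    return sum(V[t] * W[i][j] * seq[t][i][j]
               for t in range(T) for i in range(N) for j in range(N))

def split(params):
    """Flat parameter list -> (spatial mask N x N, temporal mask of length T)."""
    ms = [params[i * N:(i + 1) * N] for i in range(N)]
    mt = params[N * N:]
    return ms, mt

def gates(params):
    """Assumed gate form: sigmoid(spatial entry) * sigmoid(temporal entry)."""
    ms, mt = split(params)
    return [[[sigmoid(ms[i][j]) * sigmoid(mt[t]) for j in range(N)]
             for i in range(N)] for t in range(T)]

def total_loss(params, y_orig, lam_sparse=0.05, lam_smooth=0.1):
    g = gates(params)
    # Weight-aware perturbation: each original edge weight is scaled by its gate.
    perturbed = [[[g[t][i][j] * A[t][i][j] for j in range(N)]
                  for i in range(N)] for t in range(T)]
    fidelity = (model(perturbed) - y_orig) ** 2       # assumed squared error
    act = [sigmoid(p) for p in params]                # activated mask values
    sparsity = sum(act)                               # L1 norm (all values > 0)
    u = act[N * N:]                                   # activated temporal mask
    smooth = sum((u[t + 1] - u[t]) ** 2 for t in range(T - 1))
    return fidelity + lam_sparse * sparsity + lam_smooth * smooth

def explain(steps=200, lr=0.05, eps=1e-4):
    """Jointly optimize both masks by gradient descent; the model stays fixed."""
    y_orig = model(A)
    params = [0.0] * (N * N + T)
    history = [total_loss(params, y_orig)]
    for _ in range(steps):
        grad = []
        for k in range(len(params)):
            up = params[:k] + [params[k] + eps] + params[k + 1:]
            down = params[:k] + [params[k] - eps] + params[k + 1:]
            grad.append((total_loss(up, y_orig) - total_loss(down, y_orig)) / (2 * eps))
        params = [p - lr * gk for p, gk in zip(params, grad)]
        history.append(total_loss(params, y_orig))
    return params, history

params, history = explain()
print(f"total loss: {history[0]:.3f} -> {history[-1]:.3f}")
```

Both masks live in one flat parameter list so that a single update step adjusts them jointly, mirroring the joint optimization described in the abstract; the weights `lam_sparse` and `lam_smooth` are arbitrary illustrative values.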

Inventors

  • CHEN HE
  • XIAO JIAQING
  • WANG YONG
  • ZHANG YE
  • GUAN LONGZHOU
  • LI XIAOLI

Assignees

  • South China University of Technology (华南理工大学)

Dates

Publication Date
2026-05-12
Application Date
2025-12-25

Claims (10)

  1. A space-time two-channel mask interpretation method for a dynamic weighted graph neural network, characterized by comprising the following steps: acquiring a pre-trained dynamic graph neural network model, a dynamic weighted graph sequence to be explained, and an original prediction result of the model on the sequence, wherein the dynamic weighted graph sequence comprises graph data of T time steps, and the graph data of each time step comprises a node connection strength matrix; initializing a learnable spatial importance mask matrix and a learnable temporal importance mask vector; jointly adjusting the spatial importance mask matrix and the temporal importance mask vector through iterative optimization while the parameters of the dynamic graph neural network model are held fixed, wherein each iteration comprises: calculating a gating signal for each edge in the dynamic weighted graph sequence at each time step based on the current spatial importance mask matrix and temporal importance mask vector; modulating the corresponding original node connection strengths in the dynamic weighted graph sequence with the gating signals to generate a perturbed dynamic graph sequence; inputting the perturbed dynamic graph sequence into the fixed dynamic graph neural network model to obtain a perturbed prediction result; calculating a fidelity loss based at least on the difference between the perturbed prediction result and the original prediction result; and calculating gradients of the fidelity loss with respect to the spatial importance mask matrix and the temporal importance mask vector, and updating the two masks accordingly; and generating, based on the optimized spatial importance mask matrix and temporal importance mask vector, an importance interpretation tensor fusing space-time dimensions, wherein the element values in the importance interpretation tensor represent the contribution of the corresponding node connection at the corresponding time step to the original prediction result.
  2. The method of claim 1, wherein calculating the gating signal for each edge in the dynamic weighted graph sequence at each time step based on the current spatial importance mask matrix and temporal importance mask vector comprises: for the connection between node $i$ and node $j$ at the $t$-th time step, calculating the gating signal $g_{ij}^{(t)}$ by the following formula: [formula missing from the source text], whose terms are the corresponding element of the spatial importance mask matrix, the corresponding element of the temporal importance mask vector, and the Sigmoid function.
  3. The method of claim 2, wherein modulating the corresponding original node connection strength in the dynamic weighted graph sequence with the gating signal comprises: multiplying the gating signal $g_{ij}^{(t)}$ by the corresponding original node connection strength $A_{ij}^{(t)}$ in the dynamic weighted graph sequence to obtain the perturbed connection strength $\tilde{A}_{ij}^{(t)}$, namely: $\tilde{A}_{ij}^{(t)} = g_{ij}^{(t)} \cdot A_{ij}^{(t)}$.
  4. The method of claim 1, wherein the loss function in the iterative optimization further comprises a sparsity loss calculated based on the L1 norms of the spatial importance mask matrix and the temporal importance mask vector; and/or the loss function in the iterative optimization further comprises a temporal smoothness loss for constraining the smoothness of the change of the temporal importance mask vector between adjacent time steps.
  5. The method of claim 4, wherein the temporal smoothness loss is calculated by the following formula: [formula missing from the source text], which is expressed in terms of the value of the temporal importance mask vector at the $t$-th time step.
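  6. The method of claim 4, wherein the total loss $\mathcal{L}_{total}$ of the iterative optimization is a weighted sum of the fidelity loss $\mathcal{L}_{fid}$, the sparsity loss $\mathcal{L}_{sparse}$ and the temporal smoothness loss $\mathcal{L}_{smooth}$, namely: $\mathcal{L}_{total} = \mathcal{L}_{fid} + \lambda_1 \mathcal{L}_{sparse} + \lambda_2 \mathcal{L}_{smooth}$, wherein $\lambda_1$ and $\lambda_2$ are preset weight coefficients; and calculating gradients from the fidelity loss is specifically: calculating the gradients of the total loss $\mathcal{L}_{total}$ with respect to the spatial importance mask matrix and the temporal importance mask vector.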
  7. The method of claim 2, wherein the importance interpretation tensor fusing space-time dimensions is calculated by the following formula: [formula missing from the source text], whose terms are the optimized spatial importance mask matrix, the optimized temporal importance mask vector, and the original node connection strength.
  8. The method of claim 1, further comprising a step of visually interpreting the importance interpretation tensor, comprising at least one of: aggregating the importance interpretation tensor along the time dimension to generate a global spatial importance map; aggregating the importance interpretation tensor along the spatial node dimensions to generate a global temporal importance curve; and extracting the curve of the importance of a specific node connection in the importance interpretation tensor over time to generate a dynamic contribution graph of that connection.
  9. An electronic device comprising a memory and a processor, the memory storing a computer program, wherein the processor implements the method of any one of claims 1 to 8 when executing the computer program.
  10. A computer-readable storage medium storing a computer program, characterized in that the computer program, when executed by a processor, implements the method of any one of claims 1 to 8.
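The importance interpretation tensor of claim 7 and the three views of claim 8 can be illustrated with a small sketch. Because the formula of claim 7 is not preserved in this text, the form used here (product of the two sigmoid-activated mask entries and the absolute original connection strength) is an assumption, as is the choice of summation as the aggregation operator. The function names `importance_tensor`, `spatial_map`, `temporal_curve` and `edge_curve` are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def importance_tensor(ms, mt, A):
    """Assumed form: imp[t][i][j] = sigmoid(ms[i][j]) * sigmoid(mt[t]) * |A[t][i][j]|."""
    T, N = len(A), len(A[0])
    return [[[sigmoid(ms[i][j]) * sigmoid(mt[t]) * abs(A[t][i][j])
              for j in range(N)] for i in range(N)] for t in range(T)]

def spatial_map(imp):
    """Aggregate (here: sum) over time -> one N x N global spatial importance map."""
    T, N = len(imp), len(imp[0])
    return [[sum(imp[t][i][j] for t in range(T)) for j in range(N)]
            for i in range(N)]

def temporal_curve(imp):
    """Aggregate (here: sum) over all node pairs -> one value per time step."""
    return [sum(sum(row) for row in step) for step in imp]

def edge_curve(imp, i, j):
    """Importance of the single connection (i, j) at each time step."""
    return [step[i][j] for step in imp]

# Usage: zero masks give a gate of 0.25 on every edge of a 2-step, 2-node graph.
A = [[[0.0, 4.0], [4.0, 0.0]], [[0.0, 8.0], [8.0, 0.0]]]
ms = [[0.0, 0.0], [0.0, 0.0]]   # spatial importance mask (pre-sigmoid)
mt = [0.0, 0.0]                 # temporal importance mask (pre-sigmoid)
imp = importance_tensor(ms, mt, A)
print(spatial_map(imp))       # [[0.0, 3.0], [3.0, 0.0]]
print(temporal_curve(imp))    # [2.0, 4.0]
print(edge_curve(imp, 0, 1))  # [1.0, 2.0]
```

Summation keeps the three views on the same scale as the tensor itself; averaging would serve equally well where sequences of different lengths need to be compared.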

Description

Space-time two-channel mask interpretation method and related equipment for dynamic weighted graph neural network

Technical Field

The application relates to the field of artificial intelligence and machine learning, and in particular to a space-time two-channel mask interpretation method and related equipment for a dynamic weighted graph neural network.

Background

Graph neural networks and their dynamic variants have become core technologies for processing time-series graph-structured data such as electroencephalograms, traffic flow and social networks. In these applications, the edges of the graph typically carry important weight information (i.e., the graph is weighted); in a brain function network, for example, the edge weights represent the functional connection strength between brain regions. However, the "black box" nature of dynamic weighted graph neural networks severely hampers their application in high-reliability scenarios such as medical diagnosis and financial risk control.

Existing interpretability methods have obvious shortcomings when applied to dynamic weighted graph neural networks. The closest prior art, such as GNNExplainer and its naive extension to dynamic graphs, suffers from the following fundamental drawbacks:

1) Neglect of the contribution of edge weights: existing methods are mainly designed for unweighted graphs, or treat edge weights as ordinary input features, and therefore cannot explicitly quantify how much the connection strength, a key attribute, contributes to the model decision.
2) When a dynamic graph is processed, an interpretation algorithm is run independently at each time step, so the interpretation results are fragmented and lack temporal consistency, and the unified influence of a key pattern on spatial topology and temporal evolution cannot be jointly revealed.
3) Lack of temporal continuity constraints: independent optimization causes unreasonable, severe jitter in the interpretation results of adjacent time steps, which contradicts the continuity of real physical processes such as brain activity and traffic flow.
4) Low computational efficiency: the whole time series must be independently and iteratively optimized many times, so the computational cost is high.

Disclosure of Invention

The main purpose of the embodiments of the application is to provide a space-time two-channel mask interpretation method, electronic equipment, a storage medium and a program product for a dynamic weighted graph neural network, which can generate fused, fine-grained spatio-temporal importance interpretations, natively support edge weight quantification, produce smooth and reliable interpretation results through temporal consistency constraints, and achieve efficient end-to-end computation and intuitive visualization.

To achieve the above purpose, one aspect of the embodiments of the application provides a space-time two-channel mask interpretation method for a dynamic weighted graph neural network, the method comprising: acquiring a pre-trained dynamic graph neural network model, a dynamic weighted graph sequence to be explained, and an original prediction result of the model on the sequence, wherein the dynamic weighted graph sequence comprises graph data of T time steps, and the graph data of each time step comprises a node connection strength matrix; initializing a learnable spatial importance mask matrix and a learnable temporal importance mask vector; jointly adjusting the spatial importance mask matrix and the temporal importance mask vector through iterative optimization while the parameters of the dynamic graph neural network model are held fixed, wherein each iteration comprises: calculating a gating signal for each edge in the dynamic weighted graph sequence at each time step based on the current spatial importance mask matrix and temporal importance mask vector; modulating the corresponding original node connection strengths in the dynamic weighted graph sequence with the gating signals to generate a perturbed dynamic graph sequence; inputting the perturbed dynamic graph sequence into the fixed dynamic graph neural network model to obtain a perturbed prediction result; calculating a fidelity loss based at least on the difference between the perturbed prediction result and the original prediction result; and calculating gradients of the fidelity loss with respect to the spatial importance mask matrix and the temporal importance mask vector, and updating the two masks accordingly; and generating, based on the optimized spatial importance mask matrix and temporal importance mask vector, an importance interpretation tensor fusing space-time dimensions, wherein the element values in the importance interpretation tensor represent the contri