
US-20260127485-A1 - METHODS AND SYSTEMS FOR TRAINING A MACHINE LEARNING MODEL WITH GRAPH STRUCTURE INFORMATION


Abstract

Methods and systems for training a Machine Learning (ML) model with graph structure information are disclosed. The method, performed by a server system, includes accessing, for each node in a graph, node features, a class label, and an attention score from a database; determining a difficulty metric for each node; and generating a sequence of node batches for training a student ML model. Each node batch includes a subset of nodes within a predefined difficulty metric range associated with that node batch. The method further includes training the student ML model by iteratively performing a first set of operations: selecting a node batch; generating node embeddings; determining positive embedding pairs and negative embedding pairs based on the attention scores; computing one or more losses, including at least an attention-aided contrastive loss computed by an attention-aided contrastive loss function; and optimizing the student model parameters based on the losses. For each subsequent iteration, a subsequent node batch is selected from the sequence.
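The batching step described in the abstract can be illustrated with a minimal sketch. It assumes a difficulty score has already been computed for each node and that the predefined difficulty metric ranges are supplied as `(low, high)` pairs. The function names `make_curriculum_batches` and `batch_schedule` are hypothetical, and cycling through the batches for a fixed number of iterations is only one possible stand-in for the unspecified "predefined criterion".

```python
from typing import Dict, Iterator, List, Sequence, Tuple


def make_curriculum_batches(
    difficulty: Dict[int, float],
    ranges: Sequence[Tuple[float, float]],
) -> List[List[int]]:
    """Group node ids into one batch per difficulty range, easiest range first.

    Ranges are treated as half-open [low, high); the final range also includes
    its upper bound so the hardest nodes are not dropped (an assumption).
    """
    ordered = sorted(ranges)  # order ranges from easy (low) to hard (high)
    batches: List[List[int]] = []
    for i, (low, high) in enumerate(ordered):
        is_last = i == len(ordered) - 1
        batch = sorted(
            node
            for node, d in difficulty.items()
            if low <= d < high or (is_last and d == high)
        )
        batches.append(batch)
    return batches


def batch_schedule(batches: List[List[int]], max_iterations: int) -> Iterator[List[int]]:
    """Yield batches in curriculum order, wrapping around, for max_iterations steps.

    The fixed iteration budget is a hypothetical stand-in for the
    'predefined criterion' that ends training.
    """
    for step in range(max_iterations):
        yield batches[step % len(batches)]
```

For example, with difficulties `{0: 0.1, 1: 0.55, 2: 0.9, 3: 0.3, 4: 1.0}` and ranges `(0.0, 0.34)`, `(0.34, 0.67)`, `(0.67, 1.0)`, the batches come out as `[[0, 3], [1], [2, 4]]`, so the student sees the easiest nodes first.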

Inventors

  • Ushmita Pareek
  • Sonia Gupta
  • Sanjay Kumar Patnala
  • Krisha Ketan SHAH
  • Siddhartha Asthana

Assignees

  • MASTERCARD INTERNATIONAL INCORPORATED

Dates

Publication Date
20260507
Application Date
20241104

Claims (20)

  1. A computer-implemented method for training a student Machine Learning (ML) model, comprising: accessing, by a server system, for each node of a set of nodes in a graph, a set of node features, a class label, and an attention score from a database associated with the server system, the class label comprising one of a predefined label and a hard label prediction, the attention score indicating an importance of each node with respect to a reference node in the graph; determining, by the server system, a difficulty metric for each node based, at least in part, on the corresponding set of node features and the corresponding class label; generating, by the server system, a sequence of node batches for training the student ML model based, at least in part, on the difficulty metric of each node, each node batch comprising a subset of nodes from the set of nodes in a predefined difficulty metric range associated with each node batch; initializing, by the server system, the student ML model based, at least in part, on one or more student model parameters; and training, by the server system, the student ML model to obtain a trained student ML model based, at least in part, on performing a first set of operations iteratively until a predefined criterion is met, the first set of operations comprising: selecting, by the server system, a node batch from the sequence of node batches; generating, by the student ML model, a set of node embeddings for the subset of nodes based, at least in part, on the set of node features of each node in the selected node batch; determining, by the student ML model, a set of positive embedding pairs and a set of negative embedding pairs from the set of node embeddings based, at least in part, on the attention score of each node in the subset of nodes; computing one or more losses comprising at least an attention-aided contrastive loss, wherein the attention-aided contrastive loss is computed by an attention-aided contrastive loss function based, at least in part, on the set of positive embedding pairs and the set of negative embedding pairs; and optimizing the one or more student model parameters based, at least in part, on the one or more losses, wherein for a subsequent iteration, a subsequent node batch is selected from the sequence of node batches.
  2. The computer-implemented method as claimed in claim 1, wherein computing the one or more losses comprising at least a cross-entropy loss comprises: generating, by the student ML model, a set of probability scores for the subset of nodes based, at least in part, on the corresponding set of node embeddings; generating, by the student ML model, a node class prediction for each node in the subset of nodes based, at least in part, on the set of probability scores, the node class prediction comprising a student-hard label prediction; and computing, by a cross-entropy loss function, the cross-entropy loss for each node based, at least in part, on the node class prediction and a ground truth label associated with the corresponding node.
  3. The computer-implemented method as claimed in claim 1, wherein computing the one or more losses comprising at least a Kullback-Leibler (KL) divergence loss comprises: generating, by the student ML model, a probability score for each node in the subset of nodes based, at least in part, on the corresponding set of node embeddings; extracting, from a teacher ML model associated with the server system, a teacher probability score associated with the hard label prediction; and computing, by a KL divergence loss function, the KL divergence loss for each node based, at least in part, on the probability score and the teacher probability score of the corresponding node.
  4. The computer-implemented method as claimed in claim 1, wherein determining the difficulty metric for each node comprises: determining, by the server system, a label metric for each node based, at least in part, on the corresponding class label; determining, by the server system, a feature metric for each node based, at least in part, on the corresponding set of node features; and computing, by the server system, the difficulty metric based, at least in part, on the label metric and the feature metric.
  5. The computer-implemented method as claimed in claim 4, wherein determining the label metric for each node comprises: identifying, by the server system, one or more neighbor nodes of each node; determining, by the server system, a class label corresponding to each neighbor node of the one or more neighbor nodes; and computing, by the server system, the label metric based, at least in part, on the corresponding class label of each node and the class label corresponding to each neighbor node.
  6. The computer-implemented method as claimed in claim 4, wherein determining the feature metric for each node comprises: segregating, by the server system, a first subset of nodes associated with a first class label and a second subset of nodes associated with a second class label from the set of nodes based, at least in part, on the class label associated with each node; extracting, by the server system, from a teacher ML model, a first subset of teacher node embeddings for the corresponding first subset of nodes and a second subset of teacher node embeddings for the corresponding second subset of nodes based, at least in part, on a set of teacher node embeddings of the set of nodes; generating, by the server system, a first class representation representing a first class of the first subset of nodes based, at least in part, on an aggregation of the first subset of teacher node embeddings; generating a second class representation representing a second class of the second subset of nodes based, at least in part, on an aggregation of the second subset of teacher node embeddings; and computing, by the server system, the feature metric based, at least in part, on comparing the first class representation, the second class representation, and a teacher node embedding corresponding to each node.
  7. The computer-implemented method as claimed in claim 1, wherein determining the set of positive embedding pairs comprises: randomly selecting, by the server system, at least one node from the node batch as the reference node; accessing, by the server system, the set of node features associated with the reference node from the database; generating, by the server system, a reference node embedding for the reference node based, at least in part, on the set of reference node features; and identifying, by the server system, a first subset of node embeddings from the set of node embeddings that are related to the reference node embedding based, at least in part, on the class label of each node in the node batch to obtain the set of positive embedding pairs.
  8. The computer-implemented method as claimed in claim 7, wherein determining the set of negative embedding pairs comprises: identifying, by the server system, a second subset of node embeddings from the set of node embeddings that are unrelated to the reference node embedding based, at least in part, on the class label of each node in the node batch to obtain the set of negative embedding pairs.
  9. The computer-implemented method as claimed in claim 1, further comprising: accessing, by the server system, an entity-related dataset from the database, the entity-related dataset comprising information related to a plurality of entities; generating, by the server system, the set of features corresponding to each entity of the plurality of entities based, at least in part, on the information related to the plurality of entities; and generating, by the server system, the graph based, at least in part, on the set of features for each entity, wherein each particular node of the graph corresponds to each particular entity of the plurality of entities.
  10. The computer-implemented method as claimed in claim 1, further comprising: accessing, by the server system, a training graph from the database, wherein the training graph comprises a set of training nodes comprising a set of training labeled nodes and a set of training unlabeled nodes connected through a set of training edges, wherein each training node in the set of training nodes is associated with a set of training node features and a training positional encoding and each training labeled node in the set of training labeled nodes is associated with a predefined label; initializing, by the server system, a teacher ML model based, at least in part, on one or more teacher model parameters; and training, by the server system, the teacher ML model based, at least in part, on performing, for the set of training nodes, iteratively until a teacher predefined criterion is met, a second set of operations comprising: generating, by the teacher ML model, a set of teacher node embeddings based, at least in part, on the corresponding set of training node features and a corresponding training positional encoding of each training node; determining, by the teacher ML model, a set of attention scores based, at least in part, on the set of teacher node embeddings; generating, by the teacher ML model, a teacher probability score for each training unlabeled node in the set of training labeled nodes based, at least in part, on the set of teacher node embeddings; generating, by the teacher ML model, a teacher node class prediction for each training unlabeled node based, at least in part, on the teacher probability score, the teacher node class prediction comprising the hard label prediction; computing, by a cross-entropy loss function, a teacher cross-entropy loss for each training unlabeled node based, at least in part, on the teacher node class prediction and a ground truth label associated with the corresponding unlabeled node; and optimizing the one or more teacher model parameters based, at least in part, on the teacher cross-entropy loss.
  11. The computer-implemented method as claimed in claim 1, further comprising: accessing, by the server system, the graph from the database, wherein the graph comprises the set of nodes comprising a set of labeled nodes and a set of unlabeled nodes connected through a set of edges, wherein each node is associated with the set of node features and a positional encoding and each labeled node is associated with the predefined label; determining, by a teacher ML model associated with the server system, the attention score for each node based, at least in part, on the corresponding set of node features and the corresponding positional encoding of each node; and generating, by the teacher ML model, the hard label prediction for each unlabeled node in the set of unlabeled nodes based, at least in part, on the corresponding set of node features and the attention score.
  12. The computer-implemented method as claimed in claim 1, further comprising: receiving, by the server system, a prediction request related to the downstream task for an entity associated with an individual node from the set of nodes; and generating, by the trained student ML model associated with the server system, a task-specific prediction corresponding to the downstream task for the individual node based, at least in part, on a corresponding plurality of node features of the individual node.
  13. A server system, comprising: a communication interface; a memory comprising executable instructions; and a processor communicably coupled to the communication interface and the memory, the processor configured to cause the server system to at least: access for each node of a set of nodes in a graph, a set of node features, a class label, and an attention score from a database associated with the server system, the class label comprising one of a predefined label and a hard label prediction, the attention score indicating an importance of each node with respect to a reference node in the graph; determine a difficulty metric for each node based, at least in part, on the corresponding set of node features and the corresponding class label; generate a sequence of node batches for training a student ML model based, at least in part, on the difficulty metric of each node, each node batch comprising a subset of nodes from the set of nodes in a predefined difficulty metric range associated with each node batch; initialize a student ML model based, at least in part, on one or more student model parameters; and train the student ML model based, at least in part, on a first set of operations that is performed iteratively until a predefined criterion is met, wherein the first set of operations comprises: select a node batch from the sequence of node batches; generate, by the student ML model, a set of node embeddings for the subset of nodes based, at least in part, on the set of node features of each node in the selected node batch; determine, by the student ML model, a set of positive embedding pairs and a set of negative embedding pairs from the set of node embeddings based, at least in part, on the attention score of each node in the subset of nodes; compute one or more losses comprising at least an attention-aided contrastive loss, wherein the attention-aided contrastive loss is computed by an attention-aided contrastive loss function based, at least in part, on the set of positive embedding pairs and the set of negative embedding pairs; and optimize the one or more student model parameters based, at least in part, on the one or more losses, wherein for a subsequent iteration, a subsequent node batch is selected from the sequence of node batches.
  14. The server system as claimed in claim 13, wherein to compute the one or more losses comprising at least a cross-entropy loss, the server system is further caused, at least in part, to: generate, by the student ML model, a set of probability scores for the subset of nodes based, at least in part, on the corresponding set of node embeddings; generate, by the student ML model, a node class prediction for each node in the subset of nodes based, at least in part, on the set of probability scores, the node class prediction comprising a student-hard label prediction; and compute, by a cross-entropy loss function, the cross-entropy loss for each node based, at least in part, on the node class prediction and a ground truth label associated with the corresponding node.
  15. The server system as claimed in claim 13, wherein to compute the one or more losses comprising at least a Kullback-Leibler (KL) divergence loss, the server system is further caused, at least in part, to: generate, by the student ML model, a probability score for each node in the subset of nodes based, at least in part, on the corresponding set of node embeddings; extract, from a teacher ML model associated with the server system, a teacher probability score associated with the hard label prediction; and compute, by a KL divergence loss function, the KL divergence loss for each node based, at least in part, on the probability score and the teacher probability score of the corresponding node.
  16. The server system as claimed in claim 13, wherein to determine the difficulty metric for each node, the server system is further caused, at least in part, to: determine a label metric for each node based, at least in part, on the corresponding class label; determine a feature metric for each node based, at least in part, on the corresponding set of node features; and compute the difficulty metric based, at least in part, on the label metric and the feature metric.
  17. The server system as claimed in claim 13, wherein the server system is further caused, at least in part, to: access a training graph from the database, wherein the training graph comprises a set of training nodes comprising a set of training labeled nodes and a set of training unlabeled nodes connected through a set of training edges, wherein each training node in the set of training nodes is associated with a set of training node features and a training positional encoding and each training labeled node in the set of training labeled nodes is associated with a predefined label; initialize a teacher ML model based, at least in part, on one or more teacher model parameters; and train the teacher ML model based, at least in part, on a second set of operations that is performed, for the set of training nodes, iteratively until a teacher predefined criterion is met, the second set of operations comprising: generate, by the teacher ML model, a set of teacher node embeddings based, at least in part, on the corresponding set of training node features and a corresponding training positional encoding of each training node; determine, by the teacher ML model, a set of attention scores based, at least in part, on the set of teacher node embeddings; generate, by the teacher ML model, a teacher probability score for each training unlabeled node in the set of training labeled nodes based, at least in part, on the set of teacher node embeddings; generate, by the teacher ML model, a teacher node class prediction for each training unlabeled node based, at least in part, on the teacher probability score, the teacher node class prediction comprising the hard label prediction; compute, by a cross-entropy loss function, a teacher cross-entropy loss for each training unlabeled node based, at least in part, on the teacher node class prediction and a ground truth label associated with the corresponding unlabeled node; and optimize the one or more teacher model parameters based, at least in part, on the teacher cross-entropy loss.
  18. The server system as claimed in claim 13, wherein the server system is further caused, at least in part, to: access the graph from the database, wherein the graph comprises the set of nodes comprising a set of labeled nodes and a set of unlabeled nodes connected through a set of edges, wherein each node is associated with the set of node features and a positional encoding, and each labeled node is associated with the predefined label; determine, by a teacher ML model associated with the server system, the attention score for each node based, at least in part, on the corresponding set of node features and the corresponding positional encoding of each node; and generate, by the teacher ML model, the hard label prediction for each unlabeled node in the set of unlabeled nodes based, at least in part, on the corresponding set of node features and the attention score.
  19. The server system as claimed in claim 13, wherein the server system is further caused, at least in part, to: receive a prediction request related to the downstream task for an entity associated with an individual node from the set of nodes; and generate, by the trained student ML model associated with the server system, a task-specific prediction corresponding to the downstream task for the individual node based, at least in part, on a corresponding plurality of node features of the individual node.
  20. A non-transitory computer-readable storage medium comprising computer-executable instructions that, when executed by at least a processor of a server system, cause the server system to perform a method comprising: accessing for each node of a set of nodes in a graph, a set of node features, a class label, and an attention score from a database associated with the server system, the class label comprising one of a predefined label and a hard label prediction, the attention score indicating an importance of each node with respect to a reference node in the graph; determining a difficulty metric for each node based, at least in part, on the corresponding set of node features and the corresponding class label; generating a sequence of node batches for training a student ML model based, at least in part, on the difficulty metric of each node, each node batch comprising a subset of nodes from the set of nodes in a predefined difficulty metric range associated with each node batch; initializing the student ML model based, at least in part, on one or more student model parameters; and training the student ML model based, at least in part, on performing a first set of operations iteratively until a predefined criterion is met, the first set of operations comprising: selecting a node batch from the sequence of node batches; generating, by the student ML model, a set of node embeddings for the subset of nodes based, at least in part, on the set of node features of each node in the selected node batch; determining, by the student ML model, a set of positive embedding pairs and a set of negative embedding pairs from the set of node embeddings based, at least in part, on the attention score of each node in the subset of nodes; computing one or more losses comprising at least an attention-aided contrastive loss, wherein the attention-aided contrastive loss is computed by an attention-aided contrastive loss function based, at least in part, on the set of positive embedding pairs and the set of negative embedding pairs; and optimizing the one or more student model parameters based, at least in part, on the attention-aided contrastive loss, wherein for a subsequent iteration, a subsequent node batch is selected from the sequence of node batches.
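Claim 5 leaves the exact form of the label metric open. One plausible reading, sketched here under that assumption, scores each node by the fraction of its neighbor nodes whose class label differs from its own, so nodes in label-consistent neighborhoods count as easier. The function name `label_metric` and the choice of 0.0 for isolated nodes are hypothetical.

```python
from typing import Dict, List


def label_metric(
    labels: Dict[int, int],
    neighbors: Dict[int, List[int]],
) -> Dict[int, float]:
    """Hypothetical label metric: share of a node's neighbors with a different class label.

    0.0 means every neighbor agrees with the node (easy); 1.0 means none do (hard).
    Nodes without neighbors are assigned 0.0 (an assumption).
    """
    metric: Dict[int, float] = {}
    for node, label in labels.items():
        nbrs = neighbors.get(node, [])
        if not nbrs:
            metric[node] = 0.0
            continue
        disagree = sum(1 for n in nbrs if labels[n] != label)
        metric[node] = disagree / len(nbrs)
    return metric
```

On a four-node graph with labels `{0: 0, 1: 0, 2: 1, 3: 1}` and edges 0-1, 0-2, 2-3, nodes 0 and 2 (each touching the other class once) get 0.5, while nodes 1 and 3 get 0.0.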
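Claims 4 and 6 can be sketched in the same spirit. In this sketch each class representation is the mean of the teacher node embeddings of that class, and the feature metric compares a node's distance to its own class representation with its distance to the nearest other class representation: values near 0 mean the node sits close to its own class, and values above 0.5 mean it is closer to another class. The Euclidean distance, the ratio form, the weight `alpha`, and all function names are assumptions; the claims only require that the label metric and feature metric be combined.

```python
import math
from typing import Dict, List, Sequence


def class_centroids(
    emb: Dict[int, Sequence[float]], labels: Dict[int, int]
) -> Dict[int, List[float]]:
    """Mean teacher embedding per class, used as the class representation."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for node, vec in emb.items():
        c = labels[node]
        if c not in sums:
            sums[c], counts[c] = [0.0] * len(vec), 0
        sums[c] = [s + v for s, v in zip(sums[c], vec)]
        counts[c] += 1
    return {c: [s / counts[c] for s in total] for c, total in sums.items()}


def feature_metric(
    emb: Dict[int, Sequence[float]], labels: Dict[int, int]
) -> Dict[int, float]:
    """own / (own + nearest_other) distance ratio for each node (hypothetical form)."""
    centroids = class_centroids(emb, labels)
    if len(centroids) < 2:
        raise ValueError("at least two classes are needed")
    metric: Dict[int, float] = {}
    for node, vec in emb.items():
        own = math.dist(vec, centroids[labels[node]])
        other = min(
            math.dist(vec, cen) for c, cen in centroids.items() if c != labels[node]
        )
        total = own + other
        metric[node] = own / total if total > 0 else 0.0
    return metric


def difficulty_metric(
    label_m: Dict[int, float], feature_m: Dict[int, float], alpha: float = 0.5
) -> Dict[int, float]:
    """Convex combination of the two metrics; alpha is a hypothetical weight."""
    return {n: alpha * label_m[n] + (1.0 - alpha) * feature_m[n] for n in label_m}
```

A weighted sum keeps the difficulty metric on the same [0, 1] scale as its inputs, which makes it straightforward to split into the predefined difficulty metric ranges used for batching.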
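Claims 10 and 11 have the teacher ML model derive attention scores from its node embeddings, and claim 1 describes each score as the importance of a node with respect to a reference node. A minimal sketch of one standard way to obtain such scores is single-head scaled dot-product attention from the reference node to every node. The identity query and key projections and the name `attention_scores` are simplifying assumptions, not the disclosed teacher architecture.

```python
import math
from typing import List, Sequence


def attention_scores(
    reference: Sequence[float], embeddings: List[Sequence[float]]
) -> List[float]:
    """softmax(q_ref . k_i / sqrt(d)) over all nodes.

    Query and key projections are taken to be the identity (an assumption),
    so the reference embedding is the query and each node embedding is a key.
    """
    d = len(reference)
    logits = [
        sum(r * e for r, e in zip(reference, emb)) / math.sqrt(d) for emb in embeddings
    ]
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]
```

The scores sum to 1 across nodes, and nodes whose embeddings point in the same direction as the reference node receive the largest share, which is the property the student's positive and negative pair selection can draw on.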
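Claims 1, 7, and 8 describe selecting a random reference node from the batch, treating embeddings of nodes that share its class label as positives and the rest as negatives, and computing an attention-aided contrastive loss. The claims do not give the loss formula; this sketch assumes an InfoNCE-style loss over cosine similarities with temperature `tau`, in which each positive pair's term is weighted by that node's attention score. All function and parameter names are hypothetical.

```python
import math
import random
from typing import Dict, List, Optional, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; assumes neither vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def split_pairs(
    batch: List[int], labels: Dict[int, int], ref: int
) -> Tuple[List[int], List[int]]:
    """Positives share the reference node's class label; negatives do not."""
    positives = [n for n in batch if n != ref and labels[n] == labels[ref]]
    negatives = [n for n in batch if labels[n] != labels[ref]]
    return positives, negatives


def attention_contrastive_loss(
    emb: Dict[int, Sequence[float]],
    labels: Dict[int, int],
    attention: Dict[int, float],
    batch: List[int],
    tau: float = 0.5,
    reference: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> float:
    """Attention-weighted InfoNCE loss for one reference node (hypothetical form)."""
    if reference is None:
        reference = (rng or random.Random()).choice(batch)  # random reference node
    positives, negatives = split_pairs(batch, labels, reference)
    if not positives or not negatives:
        return 0.0  # no contrast is possible for this reference node
    anchor = emb[reference]
    neg_mass = sum(math.exp(cosine(anchor, emb[n]) / tau) for n in negatives)
    weighted_sum, weight_total = 0.0, 0.0
    for p in positives:
        pos = math.exp(cosine(anchor, emb[p]) / tau)
        # Each positive pair contributes -log(pos / (pos + negatives)), scaled by attention.
        weighted_sum += attention[p] * -math.log(pos / (pos + neg_mass))
        weight_total += attention[p]
    return weighted_sum / weight_total if weight_total > 0 else 0.0
```

The loss falls when same-class embeddings are pulled toward the reference node and other-class embeddings are pushed away, and the attention weights let positives the teacher considers more important dominate the gradient.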
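Claims 2 and 3 add a cross-entropy loss against the ground truth label and a Kullback-Leibler (KL) divergence loss against the teacher probability scores. A minimal sketch of both, together with a weighted sum that includes the contrastive term, follows. It computes cross-entropy from the student's probability for the ground truth class (the usual differentiable form), takes KL in the teacher-to-student direction common in distillation, and uses hypothetical weights `w_ce`, `w_kl`, and `w_con`; none of these choices is specified by the claims.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one node's class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def cross_entropy(probs: Sequence[float], target: int, eps: float = 1e-12) -> float:
    """-log p(target); eps guards against log(0)."""
    return -math.log(max(probs[target], eps))


def kl_divergence(teacher: Sequence[float], student: Sequence[float], eps: float = 1e-12) -> float:
    """KL(teacher || student); terms with zero teacher mass contribute nothing."""
    return sum(t * math.log(t / max(s, eps)) for t, s in zip(teacher, student) if t > 0)


def total_loss(
    student_logits: Sequence[float],
    target: int,
    teacher_probs: Sequence[float],
    contrastive: float,
    w_ce: float = 1.0,
    w_kl: float = 1.0,
    w_con: float = 1.0,
) -> float:
    """Hypothetical weighted combination of the three losses for one node."""
    p = softmax(student_logits)
    return (
        w_ce * cross_entropy(p, target)
        + w_kl * kl_divergence(teacher_probs, p)
        + w_con * contrastive
    )
```

For a student that outputs uniform probabilities over two classes and a teacher that agrees, the KL term vanishes and the total reduces to `log 2` plus the weighted contrastive term.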
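Claim 12 has the trained student ML model answer a prediction request using only the node features of the individual node. Assuming, as the Background suggests, that the student is an MLP, inference needs no adjacency information at all, as in this minimal sketch. The two-layer shape, ReLU activation, row-major weight layout, and the name `predict_class` are assumptions for illustration.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def linear(x: Sequence[float], w: Matrix, b: Sequence[float]) -> List[float]:
    """y = W x + b, with W stored row-major as (out_dim x in_dim)."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


def relu(v: Sequence[float]) -> List[float]:
    return [max(0.0, x) for x in v]


def predict_class(
    features: Sequence[float], w1: Matrix, b1: Sequence[float], w2: Matrix, b2: Sequence[float]
) -> int:
    """Two-layer MLP forward pass on one node's features; returns the argmax class."""
    hidden = relu(linear(features, w1, b1))
    logits = linear(hidden, w2, b2)
    return max(range(len(logits)), key=logits.__getitem__)
```

Because the forward pass touches only the node's own feature vector, its cost does not grow with the number of nodes or edges in the graph, which is the inference-time advantage over a graph transformer teacher that the disclosure targets.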

Description

TECHNICAL FIELD

The present disclosure relates to artificial intelligence-based processing systems and, more particularly, to electronic methods and complex processing systems for training a Machine Learning (ML) model, such as a student ML model, with graph structure information.

BACKGROUND

With the advent of technology, Machine Learning (ML) models have evolved to analyze and interpret complex datasets structured as networks or graphs. Because graphs capture relational information between elements, they can be used to represent complex datasets. A wide range of applications involve datasets that can be represented as graphs, such as molecular structures in chemistry, social and commercial connections in social networks and payment networks, and citation networks. Conventionally, several Graph Neural Networks (GNNs) have been developed to learn insights from graph-structured data. GNNs leverage node features and graph structure to learn representations that capture the relational dependencies and patterns in the data, and they can be used for various graph-related tasks such as node classification, link prediction, graph classification, and recommendation. However, GNNs fail to capture the global structure of graphs due to over-smoothing and over-squashing issues. As a result, Graph Transformers (GTs) have been developed as powerful alternatives to traditional GNNs, excelling in various graph-related tasks due to their ability to capture global information. More specifically, GTs, through their global attention mechanisms, can overcome the local structure bias of GNNs, offering State-Of-The-Art (SOTA) performance in various graph-related tasks. However, their adoption in resource-constrained environments is limited by high inference times, which stem primarily from the quadratic computational complexity of the attention mechanism in the number of nodes.
On the other hand, Multilayer Perceptron (MLP)-based models and other ML models with simpler architectures are preferred for rapid inference, even though they cannot effectively use a graph's structural and relational information, which compromises their performance in relational learning tasks. Furthermore, although model compression through pruning and quantization has been explored to accelerate transformer inference, these techniques often involve trade-offs. For example, structured pruning can streamline the model to suit deployment constraints, yet it might not always preserve optimal accuracy, especially for complex graph structures or larger node sets. This computational cost, driven by the attention mechanism's exhaustive node-to-node interactions, underscores the challenge of balancing performance and efficiency in GT deployments. Several conventional approaches attempt to address this problem by combining the benefits of graph-based models and MLPs through knowledge distillation. Knowledge distillation refers to the process of transferring knowledge learned by a larger model (i.e., a teacher model) to a smaller model (i.e., a student model). Notably, the conventional approaches distill knowledge from GNNs or Graph Convolutional Networks (GCNs) to MLPs. One such approach uses logits to distill knowledge from the teacher model to the student MLP, which cannot fully capture graph structure information. To address this limitation, another approach extracts node position features from a graph, along with the node features, and uses them to supply structural information to the student MLP during inference. However, this approach is also associated with several drawbacks.
One such drawback is that the student MLP requires graph structure information during inference. Another drawback is that the approach learns latent representations from local structural information obtained through truncated random walks. Such local information is better suited to message-passing GNNs, which also rely on local structure, than to GTs, which rely on attention mechanisms to capture global structure information, especially for large graphs. Thus, there exists a need for technical solutions, such as improved methods and systems for training an ML model with graph structure information, that overcome the aforementioned technical drawbacks.

SUMMARY

Various embodiments of the present disclosure provide methods and systems for training a Machine Learning (ML) model with graph structure information. In an embodiment, a computer-implemented method for training a Machine Learning (ML) model with graph structure information is disclosed. The computer-implemented method performed by a server system includes accessing for each node of a set of nodes in a graph, a set of node features, a class label, and an attention score f