CN-120565113-B - Training method and classifying method for multi-modal classification model and multi-modal classification model

Abstract

The application discloses a training method for a multi-modal classification model, a classification method, and the multi-modal classification model. In the training method, features are extracted separately from a first image sample, a second image sample, a first health-related sample and a second health-related sample. Multi-granularity contrastive learning and personality-commonality (modality-specific and modality-shared) representation learning are performed on the extracted first image features and first health-related features to obtain a contrastive learning loss and a personality-commonality representation learning loss, which are used to iteratively update the multi-modal classification model. Personality-commonality representation learning is then performed on the extracted second image features and second health-related features to obtain personality representations and commonality representations. These representations are weighted and concatenated into concatenated features, the concatenated features are classified to obtain a classification loss, and the multi-modal classification model is iteratively updated again. This two-stage training strategy improves the model's adaptability to complex multi-modal data and its classification accuracy.
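The two-stage schedule summarized above, together with the stopping rule stated in claim 1, can be sketched as plain control flow. This is a minimal sketch under assumptions: `train_two_stage`, `stage1_step` and `stage2_step` are hypothetical names, each step callable stands in for one parameter update and returns the loss it computed, and the `max_iters` cap is an addition not found in the source (read literally, the claim-1 condition never terminates if the loss stays above the threshold).

```python
"""Two-stage training schedule: a control-flow sketch only (hypothetical names)."""
from typing import Callable, List, Tuple


def train_two_stage(
    stage1_step: Callable[[int], float],
    stage2_step: Callable[[int], float],
    loss_threshold: float,
    min_iters: int,
    stage2_iters: int,
    max_iters: int = 10_000,
) -> Tuple[int, float, List[float]]:
    # Stage 1 (first image + first health-related samples): each step is assumed
    # to return the total learning loss = contrastive loss + personality-commonality loss.
    # Claim 1 keeps updating while loss > threshold OR iterations < preset number.
    # max_iters is an added safety cap (assumption, not in the source).
    it, loss = 0, float("inf")
    while (loss > loss_threshold or it < min_iters) and it < max_iters:
        loss = stage1_step(it)
        it += 1
    # Stage 2 (second samples): the same model is updated again from the
    # classification loss computed on the weighted, concatenated representations.
    cls_losses = [stage2_step(e) for e in range(stage2_iters)]
    return it, loss, cls_losses
```

For example, with `stage1_step = lambda i: 1.0 / (i + 1)`, a threshold of 0.25 and a preset minimum of 2 iterations, stage 1 stops after 4 updates, once the loss has fallen to 0.25.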

Inventors

  • WANG CHANGMIAO
  • WANG ZHIPENG
  • TIAN YUAN
  • LUO XUEJIAO
  • ZHOU XIULAN
  • GE RUIQUAN
  • PENG BAO
  • SHAO ZHUHONG
  • WAN XIANG

Assignees

  • Shenzhen Research Institute of Big Data (深圳市大数据研究院)

Dates

Publication Date
20260508
Application Date
20250526

Claims (10)

  1. A method of training a multi-modal classification model, applied to the multi-modal classification model, the method comprising: extracting features from a first image sample, a second image sample, a first health-related sample and a second health-related sample, respectively, to obtain a first image feature, a second image feature, a first health-related feature and a second health-related feature; performing multi-granularity contrastive learning and personality-commonality representation learning on the first image feature and the first health-related feature, which comprises: aligning features at different levels with a contrastive learning loss function to determine an inter-modal loss, an intra-modal loss and a sample-level loss, and calculating a contrastive learning loss from the inter-modal loss, the intra-modal loss and the sample-level loss; decomposing, through an encoder, the cross-modal features obtained by contrastive learning into commonality features and personality features, calculating a commonality loss with a commonality loss function and a personality loss with a personality loss function, and obtaining a personality-commonality representation learning loss from the personality loss and the commonality loss; obtaining a total learning loss from the contrastive learning loss and the personality-commonality representation learning loss, and iteratively updating the model if the total learning loss is greater than a preset loss threshold or the number of iterations is smaller than a preset number; performing personality-commonality representation learning on the second image feature and the second health-related feature to obtain corresponding personality representations and commonality representations under different loss constraints; weighting and then concatenating the personality representations and the commonality representations to obtain concatenated features, and performing classification based on the concatenated features to obtain a classification loss; and iteratively updating the multi-modal classification model again based on the classification loss.
  2. The method for training the multi-modal classification model according to claim 1, wherein performing multi-granularity contrastive learning and personality-commonality representation learning on the first image feature and the first health-related feature to obtain the contrastive learning loss and the personality-commonality representation learning loss comprises: performing multi-granularity contrastive learning on the first image feature and the first health-related feature to obtain the contrastive learning loss and cross-modal features, the contrastive learning loss comprising an intra-modal contrastive learning loss, an inter-modal contrastive learning loss and a sample-level contrastive learning loss; and performing personality-commonality representation learning on the cross-modal features to obtain the personality-commonality representation learning loss.
  3. The method for training the multi-modal classification model according to claim 2, wherein performing personality-commonality representation learning on the cross-modal features to obtain the personality-commonality representation learning loss comprises: determining a personality representation learning loss based on the personality representation; determining a commonality representation learning loss based on the commonality representation; and obtaining the personality-commonality representation learning loss based on the personality representation learning loss and the commonality representation learning loss.
  4. The method for training a multi-modal classification model according to claim 1, wherein weighting and concatenating the personality representations and the commonality representations to obtain the concatenated features comprises: determining the sample quality of the image sample and the health-related sample corresponding to the personality representation; determining, based on the sample quality, a personality weight and a commonality weight for the corresponding personality representation; weighting the personality representation by the personality weight and the commonality representation by the commonality weight; and concatenating the weighted personality representation and the weighted commonality representation to obtain the concatenated features.
  5. The method according to claim 4, wherein determining the sample quality of the image sample and the health-related sample corresponding to the personality representation comprises: creating a positive sample prototype and a negative sample prototype from the mean personality features of all training samples; determining similarity scores between the personality representation and the positive sample prototype and the negative sample prototype, respectively; and determining the sample quality of the image sample and the health-related sample corresponding to the personality representation based on the similarity scores and the ground-truth label.
  6. The method according to claim 5, wherein the classification loss comprises a prototype update loss and a prediction loss, and performing classification based on the concatenated features to obtain the classification loss comprises: updating the positive and negative sample prototypes based on high-quality samples and determining the prototype update loss; performing classification prediction on the concatenated features to obtain a classification prediction result; obtaining the prediction loss based on the classification prediction result and the ground-truth label; and obtaining the classification loss based on the prototype update loss and the prediction loss.
  7. The method of training a multi-modal classification model according to any one of claims 1-6, wherein the personality representation includes a personality image representation and a personality health-related representation, the commonality representation includes a commonality image representation and a commonality health-related representation, and weighting and concatenating the personality representation and the commonality representation to obtain the concatenated features and performing classification based on the concatenated features comprises: weighting and then concatenating the personality image representation and the commonality image representation to obtain image concatenated features, and weighting and then concatenating the personality health-related representation and the commonality health-related representation to obtain health-related concatenated features; predicting an image classification result from the image concatenated features and a health-related classification result from the health-related concatenated features; and weighting the image classification result and the health-related classification result, respectively, and obtaining a classification result from the weighted results.
  8. A classification method using a multi-modal classification model, applied to a multi-modal classification model trained by the training method according to any one of claims 1-7, the method comprising: acquiring image data to be examined and health-related data to be examined, and extracting features from each to obtain image features and health-related features to be examined; performing personality-commonality representation learning on the image features and the health-related features, respectively, to obtain a corresponding personality image representation, commonality image representation, personality health-related representation and commonality health-related representation under different loss constraints; concatenating the personality image representation and the commonality image representation to obtain image concatenated features, and concatenating the personality health-related representation and the commonality health-related representation to obtain health-related concatenated features; predicting an image classification result from the image concatenated features and a health-related classification result from the health-related concatenated features; and weighting the image classification result and the health-related classification result, respectively, and obtaining a comprehensive classification result from the weighted results.
  9. A multi-modal classification model, characterized in that it is used to implement the training method of the multi-modal classification model according to any one of claims 1-7 and comprises a multi-modal classification model for a first training stage and a multi-modal classification model for a second training stage, the model for the second training stage being obtained from the model for the first training stage; the model for the first training stage comprises an image feature extraction module, a health-related feature extraction module and a feature learning module, the feature learning module performing multi-granularity contrastive learning and personality-commonality representation learning on the extracted image features and health-related features; the model for the second training stage comprises an image feature extraction module, a health-related feature extraction module, a feature learning module and a classification module, the feature learning module performing personality-commonality representation learning on the extracted image features and health-related features, and the classification module performing classification prediction on the personality representations and commonality representations learned by the feature learning module.
  10. The multi-modal classification model of claim 9, wherein the multi-modal classification model for the second training stage includes an image-modality prediction branch and a health-related-modality prediction branch; the image-modality prediction branch includes an image feature extraction module, an image feature learning module, and an image classification module for outputting an image prediction result; the health-related-modality prediction branch includes a health-related feature extraction module, a health-related feature learning module, and a health-related classification module for outputting a health-related prediction result; and the image prediction result and the health-related prediction result are fused by weighting to obtain a final prediction result.
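Claims 1 and 2 name three contrastive granularities (intra-modal, inter-modal and sample-level) without giving the loss functions. The sketch below is one plausible instantiation, not the patented formulation: it assumes InfoNCE with cosine similarity and a temperature `tau`, treats two augmented views of the same modality as intra-modal positives, the image and health-related embeddings of the same patient as inter-modal positives, and other same-label samples (on averaged image/health-related embeddings) as sample-level positives. All function names and the weights `w_intra`, `w_inter` and `w_sample` are hypothetical.

```python
"""Multi-granularity contrastive loss: a minimal pure-Python sketch (assumed InfoNCE)."""
import math
from typing import List, Sequence

Vec = Sequence[float]


def cosine(a: Vec, b: Vec) -> float:
    """Cosine similarity; the small epsilon avoids division by zero."""
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return sum(x * y for x, y in zip(a, b)) / (na * nb + 1e-12)


def log_sum_exp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def info_nce(anchors: List[Vec], positives: List[Vec], tau: float = 0.1) -> float:
    """Row i of `positives` is the positive for anchor i; all other rows are negatives."""
    total = 0.0
    for i, a in enumerate(anchors):
        logits = [cosine(a, p) / tau for p in positives]
        total += log_sum_exp(logits) - logits[i]
    return total / len(anchors)


def sample_level_loss(fused: List[Vec], labels: List[int], tau: float = 0.1) -> float:
    """Supervised contrastive loss: other samples with the same label are positives."""
    total, count = 0.0, 0
    for i, a in enumerate(fused):
        others = [j for j in range(len(fused)) if j != i]
        pos = [j for j in others if labels[j] == labels[i]]
        if not pos:
            continue  # anchors with no same-label partner contribute nothing
        logits = {j: cosine(a, fused[j]) / tau for j in others}
        denom = log_sum_exp(list(logits.values()))
        total += sum(denom - logits[j] for j in pos) / len(pos)
        count += 1
    return total / max(count, 1)


def multi_granularity_loss(img, img_aug, ehr, ehr_aug, labels, tau=0.1,
                           w_intra=1.0, w_inter=1.0, w_sample=1.0) -> float:
    """Weighted sum of the intra-modal, inter-modal and sample-level terms."""
    intra = 0.5 * (info_nce(img, img_aug, tau) + info_nce(ehr, ehr_aug, tau))
    inter = 0.5 * (info_nce(img, ehr, tau) + info_nce(ehr, img, tau))
    fused = [[(x + y) / 2 for x, y in zip(a, b)] for a, b in zip(img, ehr)]
    sample = sample_level_loss(fused, labels, tau)
    return w_intra * intra + w_inter * inter + w_sample * sample
```

In a quick check with `a = [[1.0, 0.0], [0.0, 1.0]]` and `tau = 0.1`, `info_nce(a, a)` is close to 0, while pairing `a` with its row-swapped copy gives about 10, so misaligned views are penalized as intended.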
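Claims 4 to 8 and 10 describe prototype-based sample quality, quality-dependent weighting before concatenation, and weighted fusion of the image-branch and health-related-branch predictions. The sketch below fills the unspecified details with hypothetical choices: prototypes are per-class means of personality features for binary labels (0 = negative, 1 = positive), quality is the softmax probability of the true-label prototype over cosine similarities, the personality weight is `q` and the commonality weight is `1 - q`, and fusion uses a fixed `alpha`. Claim 6 speaks of a prototype update loss; `update_prototype` instead substitutes a simple momentum (EMA) update and does not model that loss. Because quality depends on the ground-truth label, this weighting only applies during training; claim 8 concatenates the representations without it at inference.

```python
"""Prototype-based quality weighting and two-branch fusion (hypothetical sketch)."""
import math
from typing import List, Sequence, Tuple

Vec = Sequence[float]


def mean(vectors: List[Vec]) -> List[float]:
    """Element-wise mean of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def cosine(a: Vec, b: Vec) -> float:
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return sum(x * y for x, y in zip(a, b)) / (na * nb + 1e-12)


def build_prototypes(personality: List[Vec], labels: List[int]) -> Tuple[List[float], List[float]]:
    """Negative (label 0) and positive (label 1) prototypes as class means."""
    neg = mean([p for p, y in zip(personality, labels) if y == 0])
    pos = mean([p for p, y in zip(personality, labels) if y == 1])
    return neg, pos


def sample_quality(p: Vec, label: int, neg: Vec, pos: Vec, tau: float = 0.1) -> float:
    """Softmax probability (over the two prototype similarities) of the true label."""
    scores = [cosine(p, neg) / tau, cosine(p, pos) / tau]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return exps[label] / sum(exps)


def weighted_concat(personality: Vec, commonality: Vec, q: float) -> List[float]:
    """Scale personality by q and commonality by 1 - q, then concatenate."""
    return [q * x for x in personality] + [(1.0 - q) * x for x in commonality]


def update_prototype(proto: Vec, high_quality: List[Vec], momentum: float = 0.9) -> List[float]:
    """EMA update from high-quality samples; a stand-in for the claimed update loss."""
    batch = mean(high_quality)
    return [momentum * a + (1.0 - momentum) * b for a, b in zip(proto, batch)]


def fuse_predictions(p_img: float, p_ehr: float, alpha: float = 0.5) -> float:
    """Weighted fusion of image-branch and health-related-branch probabilities."""
    return alpha * p_img + (1.0 - alpha) * p_ehr
```

A sample whose personality feature lies close to its own class prototype receives a quality near 1 and therefore leans on its personality representation; a sample closer to the opposite prototype is weighted toward the commonality representation instead.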

Description

Technical Field

The application relates to the technical field of multi-modal fusion, and in particular to a training method and a classification method for a multi-modal classification model, and to the multi-modal classification model.

Background

Currently, doctors often make diagnoses by combining multi-modal information such as 3D medical images, electronic health records (EHRs) and medical reports, which gives a more comprehensive picture of the patient's health and improves diagnostic accuracy and efficiency. 3D imaging such as CT and MRI produces high-resolution images that reveal fine internal detail, while EHRs provide broad background information such as clinical history, demographic data and medication use. However, traditional clinical diagnosis depends heavily on physicians' experience and manual analysis and has clear limitations when handling complex multi-modal data: repeatedly reading large numbers of CT images is time-consuming and laborious, workload and time pressure make errors likely, and CT image quality, which is limited by equipment and technique, can further reduce diagnostic accuracy. In addition, traditional methods demand a high level of specialist knowledge, making it hard for non-specialists to take part; multi-source data (such as images, text and laboratory indicators) take a long time to collect, and integrating them adds further to physicians' workload. Multi-modal AI technology has emerged to overcome these limitations. However, existing AI models still have significant shortcomings in processing cross-modal data.
These models can integrate multiple data modalities, but because they lack an effective design for inter-modal interaction, they struggle to mine deep semantic associations between modalities. Existing methods also usually ignore the multi-granularity nature of features and the quality differences within each modality, so there is still considerable room to improve the models' interaction detail and stability.

Disclosure of Invention

In view of the foregoing, it is necessary to provide a training method and a classification method for a multi-modal classification model, and the multi-modal classification model, to solve at least one of the problems in the prior art.

In a first aspect, a training method for a multi-modal classification model is provided, applied to the multi-modal classification model, the method comprising: extracting features from a first image sample, a second image sample, a first health-related sample and a second health-related sample, respectively, to obtain a first image feature, a second image feature, a first health-related feature and a second health-related feature; performing multi-granularity contrastive learning and personality-commonality representation learning on the first image feature and the first health-related feature to obtain a contrastive learning loss and a personality-commonality representation learning loss, and iteratively updating the multi-modal classification model based on these two losses; performing personality-commonality representation learning on the second image feature and the second health-related feature to obtain corresponding personality representations and commonality representations under different loss constraints; weighting and then concatenating the personality representations and the commonality representations to obtain concatenated features, and performing classification based on the concatenated features to obtain a classification loss; and iteratively updating the multi-modal classification model again based on the classification loss.

In one possible implementation, performing multi-granularity contrastive learning and personality-commonality representation learning on the first image feature and the first health-related feature to obtain the contrastive learning loss and the personality-commonality representation learning loss comprises: performing multi-granularity contrastive learning on the first image feature and the first health-related feature to obtain the contrastive learning loss and cross-modal features, the contrastive learning loss comprising an intra-modal contrastive learning loss, an inter-modal contrastive learning loss and a sample-level contrastive learning loss; and performing personality-commonality representation learning on the cross-modal features to obtain the personality-commonality representation learning loss.

In one possible implementation, performing personality-commonality representation learning on the cross-modal features to obtain the personality-commonality representation learning loss includ