CN-122023789-A - Semi-supervised medical image segmentation method based on diffusion-driven hard-soft prototype contrast learning
Abstract
A semi-supervised medical image segmentation method based on diffusion-driven hard-soft prototype contrast learning belongs to the field of medical image processing. The method comprises: obtaining medical images and dividing them into a labeled dataset and an unlabeled dataset; generating pixel-level confidence maps with a discriminator; guiding a diffusion model with the confidence maps to adaptively denoise and correct noisy pseudo-labels, and generating corrected pseudo-labels with a local detail enhancement mechanism; constructing a hard-soft collaborative prototype contrast module, in which class prototypes are built from the corrected pseudo-labels, a hard matching indicator matrix is generated for high-confidence pixels, and a soft guidance constraint is applied to low-confidence pixels, enhancing the intra-class compactness and inter-class separability of the feature space; iteratively optimizing the model with a joint objective function; and, after training, inputting the medical image to be segmented into the segmentation network to obtain the segmentation result. The method solves the problems of pseudo-label noise accumulation and feature-discrimination degradation in semi-supervised learning, and reduces the dependence on large-scale annotated data.
Inventors
- DU XIAOGANG
- LI CHUNLIANG
- ZHANG JIAWEI
- ZHOU LEI
- LEI TAO
- LIU TONGFEI
- WANG YINGBO
Assignees
- Shaanxi University of Science and Technology (陕西科技大学)
Dates
- Publication Date
- 2026-05-12
- Application Date
- 2025-12-31
Claims (5)
- 1. A semi-supervised medical image segmentation method based on diffusion-driven hard-soft prototype contrast learning, characterized by comprising the following steps: Step 1, acquiring a medical image dataset and dividing it into a labeled dataset and an unlabeled dataset; preprocessing the datasets, wherein weak augmentation (random flipping and small-angle rotation) is applied to the labeled data, while both weak augmentation (random flipping and small-angle rotation) and strong augmentation (color jitter and Gaussian blur) are applied to the unlabeled data, generating the sample data used for model training; Step 2, initializing a segmentation network, a discriminator and a diffusion model, and setting the hyperparameters required for training, the hyperparameters at least comprising a confidence threshold characterizing the reliability of predicted pseudo-labels, the number of time steps of the diffusion model, and the weight coefficient of each loss function, wherein the segmentation network is a neural network with an encoder-decoder structure; the preprocessed labeled data and unlabeled data of Step 1 are respectively input into the segmentation network, which correspondingly outputs a weak-augmentation pseudo-label for the labeled data, and a weak-augmentation pseudo-label and a strong-augmentation pseudo-label for the unlabeled data; Step 3, generating confidence maps with the discriminator: for the labeled data, the discriminator discriminates between the ground-truth label GT and the weak-augmentation pseudo-label, generating a pixel-level confidence map reflecting the prediction reliability of each pixel; for the unlabeled data, based on the consistency between the pseudo-labels predicted by the segmentation network under different augmentation conditions, the discriminator generates a pixel-level confidence map of the reliability of the weak-augmentation pseudo-label relative to the strong-augmentation pseudo-label; Step 4, guiding the diffusion model with the confidence maps to adaptively denoise and correct the noisy pseudo-labels; simultaneously, introducing a local detail enhancement mechanism into the encoder of the diffusion model, dynamically computing the window size according to the size of the feature map, dividing the feature map into overlapping local windows, and executing a multi-head self-attention mechanism and depthwise separable convolution in parallel within each window to capture fine-grained anatomical structure features, thereby generating diffusion-model-corrected pseudo-labels refined by denoising and structural features; Step 5, constructing a hard-soft collaborative prototype contrast module: first, performing spatial-scale alignment between the corrected pseudo-labels obtained in Step 4 and the pixel-level feature map extracted from the last layer of the segmentation network decoder; constructing instant class prototypes by aggregating the pixel features of each class, and temporally updating the global prototype of each class with an exponential moving average strategy to obtain stable global semantic representations and pixel-level class indices; computing the similarity between each pixel feature vector and the global class prototypes to obtain, for each pixel, a similarity distribution over the class prototypes that describes the class association of the pixel in feature space; combining the pixel-level class indices to establish a differentiated constraint mechanism, namely: for high-confidence pixels, generating a hard matching indicator matrix that explicitly specifies the positive-sample prototype and the negative-sample prototypes of each pixel; for low-confidence pixels, constructing a soft guidance constraint mechanism that allows a pixel to establish weighted associations with several candidate class prototypes, so as to reduce the interference of noisy pseudo-labels on feature learning; finally, jointly feeding the similarity results and the positive/negative prototype relations into a pixel-prototype contrastive loss function for optimization, enhancing the intra-class compactness and inter-class separability of the feature space by constraining the consistency between pixel features and the global prototypes, and improving the localization accuracy of lesion region boundaries and the overall segmentation performance; Step 6, constructing a joint optimization objective function, jointly training the segmentation network, the discriminator and the diffusion model, iteratively updating the model parameters based on the joint objective, and repeatedly executing Steps 3 to 5 until training converges; Step 7, inputting the medical image to be segmented into the trained segmentation network, and directly generating the corresponding medical image segmentation result through end-to-end forward inference.
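The differentiated constraint in Step 5 hinges on splitting pixels into a hard-constraint set and a soft-guidance set by a confidence threshold. A minimal sketch of that split, assuming a NumPy confidence map in [0, 1] and a hypothetical threshold of 0.9 (the patent leaves the actual value as a hyperparameter):

```python
import numpy as np

def split_by_confidence(conf_map, threshold=0.9):
    """Partition pixels into a hard (high-confidence) set and a soft
    (low-confidence) set, as in the hard-soft constraint mechanism.

    conf_map: (H, W) array of per-pixel confidences in [0, 1].
    Returns boolean masks (hard_mask, soft_mask)."""
    hard_mask = conf_map >= threshold   # explicit positive/negative matching
    soft_mask = ~hard_mask              # weighted multi-prototype guidance
    return hard_mask, soft_mask

conf = np.array([[0.95, 0.40],
                 [0.91, 0.88]])
hard, soft = split_by_confidence(conf, threshold=0.9)
# hard selects pixels (0,0) and (1,0); the rest fall under soft guidance
```

Every pixel lands in exactly one of the two sets, so the hard matching indicator matrix and the soft guidance weights never compete on the same pixel.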
- 2. The semi-supervised medical image segmentation method based on diffusion-driven hard-soft prototype contrast learning of claim 1, wherein generating the confidence maps with the discriminator in Step 3 comprises: Step 3.1, for labeled data, the discriminator receives the ground-truth label (in one-hot form) and the weak-augmentation pseudo-label, and outputs a pixel-level confidence map through a Sigmoid function, wherein a value closer to 1 indicates a higher probability that the pixel comes from the real label, and a value closer to 0 indicates a higher probability that the pixel comes from the weak-augmentation pseudo-label; Step 3.2, for unlabeled data, the weak-augmentation pseudo-label generated by the segmentation network is first encoded into a latent space and subjected to noising and denoising by the diffusion model to obtain an initially optimized weak-augmentation pseudo-label; the discriminator then receives the diffusion-optimized weak-augmentation pseudo-label and the strong-augmentation pseudo-label generated by the segmentation network, and outputs a pixel-level confidence map representing, at each pixel, the confidence of the pseudo-label of the unlabeled sample, a larger value indicating that the weak-augmentation pseudo-label at that pixel position is more reliable.
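The two confidence sources of Step 3 can be sketched as follows. The exact discriminator architecture is not given in the text above, so the code below only assumes a per-pixel logit map from the discriminator (labeled branch) and hard class maps for the weak/strong pseudo-labels (unlabeled branch); the agreement-based score is a simple stand-in for the consistency information the claim describes:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def labeled_confidence_map(disc_logits):
    """Pixel-level confidence for labeled data: the Sigmoid of the
    discriminator's per-pixel logit. Values near 1 mean the discriminator
    believes the pixel comes from the real label; values near 0 mean the
    pixel looks like it comes from the weak-augmentation pseudo-label."""
    return sigmoid(disc_logits)

def unlabeled_consistency(weak_pseudo, strong_pseudo):
    """Per-pixel agreement between the weak- and strong-augmentation
    pseudo-labels (hard class-index maps); agreement serves as a
    reliability proxy for the weak-augmentation pseudo-label."""
    return (weak_pseudo == strong_pseudo).astype(np.float64)

lab_conf = labeled_confidence_map(np.zeros((2, 2)))   # logit 0 -> 0.5
unl_conf = unlabeled_consistency(np.array([[0, 1], [2, 2]]),
                                 np.array([[0, 2], [2, 1]]))
```

A zero logit maps to 0.5, the natural "undecided" midpoint, which is why the adaptive weighting in claim 3 truncates its weights to the interval [0.5, 1].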
- 3. The semi-supervised medical image segmentation method based on diffusion-driven hard-soft prototype contrast learning of claim 2, wherein the process of Step 4 is: Step 4.1, defining an adaptive weighting formula that converts the pixel-level confidence map into an adaptive adjustment weight of the diffusion model, the weight being truncated to the interval [0.5, 1]; through this mapping, higher correction weights are assigned to low-confidence pixels while high-confidence pixels receive lower weights, avoiding excessive perturbation of reliable regions and thereby preserving the overall stability of the pseudo-label; Step 4.2, further combining a pixel-level Dice loss to characterize the segmentation error, and constructing an adaptive correction index from a basic calibration coefficient and a maximum calibration coefficient; Step 4.3, first mapping the adaptive correction index into a high-dimensional embedding through sine-cosine positional encoding, then performing feature enhancement through two linear layers with Swish activation to obtain an embedding vector; the embedding vector is then linearly projected, and the projection result is fused with the features extracted by the network backbone of the diffusion model at the current denoising time step, the fused features being used to guide the reverse denoising process of the diffusion model; Step 4.4, in the forward diffusion stage, the noisy image of the original image at a given time step is expressed as a combination of the original image and Gaussian noise drawn from the standard normal distribution; in the reverse denoising process, samples are gradually recovered under the joint guidance of the condition input and the fused features, wherein the condition input is the corresponding ground-truth label GT when processing labeled data, and the pseudo-label preliminarily corrected by the diffusion model when processing unlabeled data; the noise prediction function constructed by the diffusion model predicts the noise residual based on the noisy sample at the current time step, the time step index, and the pseudo-label information, so as to guide the gradual denoising process; the diffusion model uses the time step index to control the global noise dynamics, while the embedding vector, being embedded into the fused features, indirectly regulates the diffusion process, thereby achieving pixel-level adaptive correction and fine-grained pseudo-label optimization; Step 4.5, implementing progressive pseudo-label diffusion correction: first, the weak-augmentation data are processed by mapping the initial weak-augmentation pseudo-label generated by the segmentation network into the latent space; for labeled data, the ground-truth label GT is used as a prior condition to guide the diffusion denoising process and learn the noise distribution; for unlabeled data, the diffusion model adaptively denoises and corrects the weak-augmentation pseudo-label to generate an optimized weak-augmentation pseudo-label; the strong-augmentation data are then processed, with the optimized weak-augmentation pseudo-label input to the diffusion model as the guiding condition, whose reliable structural information is used to perform a secondary correction of the strong-augmentation pseudo-label, finally obtaining the corrected strong-augmentation pseudo-label; Step 4.6, introducing a local detail enhancement mechanism into the encoder structure of the diffusion model: in order to prevent the diffusion process from over-smoothing fine anatomical structures, the window size and stride are dynamically computed according to the size of the intermediate feature map, and the intermediate feature map is divided into several overlapping local windows; within each window, fine-grained features are captured by a multi-head self-attention mechanism and depthwise separable convolution to obtain enhanced local window features, the computation employing a GELU activation function, a depthwise separable convolution operation, a multi-head self-Attention mechanism, a residual connection between the self-attention output and the original window features, and linear projections; after the enhancement of each local window feature is completed, the enhanced features in overlapping regions are averaged per pixel position, by traversing and accumulating over the set of local windows whose spatial extent covers that pixel location, to generate an enhancement feature map of the same size as the original intermediate feature map.
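Steps 4.1 and 4.4 can be sketched as follows. The claim does not reproduce the exact confidence-to-weight formula, so the `1 - confidence` form below is an assumption consistent with "higher correction weight for low-confidence pixels, truncated to [0.5, 1]"; the forward noising is the standard DDPM formulation x_t = sqrt(ᾱ_t)·x_0 + sqrt(1-ᾱ_t)·ε with ε ~ N(0, I):

```python
import numpy as np

def correction_weight(conf_map):
    """Step 4.1 sketch: map pixel confidence to a diffusion correction
    weight truncated to [0.5, 1]. Low-confidence pixels get weights near 1
    (strong correction); high-confidence pixels get weights near 0.5
    (mild perturbation). The `1 - conf` mapping is an assumed form, not
    quoted from the claim."""
    return np.clip(1.0 - conf_map, 0.5, 1.0)

def forward_diffuse(x0, alpha_bar_t, eps=None, rng=None):
    """Step 4.4: standard DDPM forward noising at time step t,
    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    if eps is None:
        if rng is None:
            rng = np.random.default_rng(0)
        eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * eps

w = correction_weight(np.array([0.0, 0.2, 0.9, 1.0]))
xt = forward_diffuse(np.ones((2, 2)), 0.25, eps=np.zeros((2, 2)))
```

As ᾱ_t decreases toward 0 the noisy sample approaches pure Gaussian noise, which is what lets the reverse process rewrite low-confidence pseudo-label regions while the [0.5, 1] clamp keeps reliable regions close to their original state.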
- 4. The semi-supervised medical image segmentation method based on diffusion-driven hard-soft prototype contrast learning of claim 3, wherein the process of the hard-soft collaborative prototype contrast module in Step 5 is: Step 5.1, extracting the pixel-level feature map from the last layer of the segmentation network decoder, whose dimensions respectively represent the batch size, channel number, height and width of the feature map and which represents the high-level semantic information of the medical image; adjusting the diffusion-model-corrected strong-augmentation pseudo-label to the same spatial resolution as the pixel-level feature map through a spatial-scale alignment operation to obtain a corrected pixel-level label map, establishing a one-to-one correspondence between pixel-level features and class labels; Step 5.2, based on the pseudo-label aligned in Step 5.1, aggregating the pixel-level features by class, and constructing the instant class prototype of the current training iteration by a weighted average of the pixel features belonging to the same class, the computation involving the feature vector of each sample at each pixel coordinate, an indicator function selecting the pixels of the class, and a small constant for avoiding division by zero; Step 5.3, adopting an exponential moving average strategy for the temporal update of the global prototype of each class: when a class first appears, its global prototype is directly initialized to the corresponding instant class prototype; when the class already has a historical prototype, the current instant class prototype and the historical global prototype are fused by a weighted combination with a preset momentum coefficient to obtain the updated global class prototype; Step 5.4, flattening the pixel-level feature map into pixel feature vectors, applying L2 normalization to the pixel feature vectors and the global class prototypes of the corresponding classes to eliminate scale differences, and computing the matching degree of the normalized vectors by cosine similarity, with a temperature parameter adjusting the sharpness of the Softmax distribution; Step 5.5, flattening the corrected pixel-level label map along the spatial dimensions into pixel-level class indices, and constructing a hard matching indicator matrix according to the class label of each pixel, wherein a 1 in the matrix denotes a positive sample pair and a 0 denotes a negative sample pair, explicitly defining the positive and negative sample relations; only pixels whose prediction confidence is higher than a preset threshold are selected to participate in this hard-constraint matching, so as to reduce the influence of noisy pseudo-labels on subsequent steps; Step 5.6, on the basis of the hard matching indicator matrix, normalizing the matching relations between each pixel and the several class prototypes to which it may belong, and constructing a soft guidance constraint mechanism that allows the pixel features to establish weighted associations with other candidate class prototypes while aggregating toward the prototype of their own class, so as to alleviate the misleading caused by blurred class boundaries and pseudo-label uncertainty; Step 5.7, computing the hard-soft collaborative prototype contrastive loss: based on the pixel-prototype similarity computation, the hard matching relations and the soft target distribution, constructing a pixel-prototype contrastive loss function over the set of high-confidence pixels participating in prototype contrastive learning, which, through the joint hard matching indicator matrix and soft guidance constraint mechanism, guides pixel-level features to aggregate toward the prototype of their own class and move away from the prototypes of other classes, thereby enhancing the intra-class compactness and inter-class separability of the feature representations; Step 5.8, constructing a boundary loss based on the predicted boundaries generated from the pixel-level feature map and the corrected pixel-level label map, and fusing the pixel-prototype contrastive loss and the boundary loss by a weighted combination with a balance coefficient to form the hard-soft collaborative prototype contrastive loss.
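The prototype machinery of Steps 5.3-5.7 can be sketched as follows. The symbols and exact loss form were lost in translation, so this is a minimal sketch under assumptions: an InfoNCE-style cross-entropy against a one-hot hard target for the high-confidence set, a hypothetical momentum of 0.99, and a hypothetical temperature of 0.1:

```python
import numpy as np

def ema_prototype(global_proto, instant_proto, momentum=0.99):
    """Step 5.3: initialize directly when no history exists; otherwise fuse
    the historical global prototype with the instant prototype by momentum."""
    if global_proto is None:
        return instant_proto
    return momentum * global_proto + (1.0 - momentum) * instant_proto

def pixel_prototype_logits(feats, protos, tau=0.1):
    """Step 5.4: L2-normalize features (N, D) and prototypes (C, D), then
    compute temperature-scaled cosine similarity -> (N, C) logits."""
    f = feats / (np.linalg.norm(feats, axis=1, keepdims=True) + 1e-8)
    p = protos / (np.linalg.norm(protos, axis=1, keepdims=True) + 1e-8)
    return (f @ p.T) / tau

def hard_contrast_loss(feats, protos, labels, conf, thresh=0.9, tau=0.1):
    """Steps 5.5-5.7 sketch: cross-entropy between the similarity Softmax
    and the one-hot hard matching target, restricted to pixels whose
    confidence exceeds the threshold."""
    logits = pixel_prototype_logits(feats, protos, tau)
    logits -= logits.max(axis=1, keepdims=True)              # stable Softmax
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    idx = np.where(conf >= thresh)[0]                        # hard-match set
    if idx.size == 0:
        return 0.0
    return float(-log_prob[idx, labels[idx]].mean())

protos = np.eye(2)                       # two toy class prototypes
feats = np.eye(2)                        # two pixels, perfectly aligned
loss_good = hard_contrast_loss(feats, protos, np.array([0, 1]),
                               np.array([1.0, 1.0]))
loss_bad = hard_contrast_loss(feats, protos, np.array([1, 0]),
                              np.array([1.0, 1.0]))
```

Pixels matching their own prototype yield a near-zero loss, while swapped labels are heavily penalized, which is the intra-class-compactness / inter-class-separability pressure the claim describes; the soft-guidance branch would replace the one-hot target with a normalized multi-prototype distribution for low-confidence pixels.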
- 5. The semi-supervised medical image segmentation method based on diffusion-driven hard-soft prototype contrast learning of claim 4, wherein the joint optimization objective function in Step 6 is constructed as follows: Step 6.1, constructing the segmentation network loss, whose optimization objective consists of a supervised loss and an unsupervised consistency loss; the supervised loss is computed on the labeled dataset and jointly constrains the overlap consistency between the weak-augmentation pseudo-label and the ground-truth label GT and the pixel classification error (averaged over the number of labeled samples, with a weight coefficient on the cross-entropy term) to improve the basic segmentation ability of the model; at the same time, by comparing the predicted pseudo-labels of the unlabeled dataset under the weak-augmentation and strong-augmentation conditions and applying a consistency constraint (averaged over the number of unlabeled samples), the unsupervised consistency loss is constructed to suppress prediction instability and improve model generalization; Step 6.2, constructing the discriminator loss: on the labeled dataset, the discriminator learns pixel-level discrimination ability by distinguishing the ground-truth label GT from the weak-augmentation pseudo-label generated by the segmentation network, with a loss function indicating whether each pixel comes from the real label or from the pseudo-label of the segmentation network; for pixels in the unlabeled data, a pixel-level consistency score is computed based on the pseudo-labels under weak and strong augmentation, a higher value indicating that the pixel's pseudo-label is more reliable and a lower value indicating prediction uncertainty; on the unlabeled dataset, this score is used to construct a binary cross-entropy loss that smooths the training process and provides progressive supervision; Step 6.3, constructing the diffusion model correction loss, comprising a weak-augmentation correction loss and a strong-augmentation correction loss: for labeled data, the ground-truth label GT is used as a deterministic prior condition to supervise the denoising of the weak-augmentation pseudo-label in the latent space, the weak-augmentation correction loss being defined on the latent representations of the labeled samples at the current and previous diffusion steps; for unlabeled data, the weak-augmentation pseudo-label is used as the guiding condition to drive the diffusion model to perform a consistency correction of the strong-augmentation pseudo-label, forcing the pseudo-label under strong perturbation to align with the highly reliable weak-augmentation pseudo-label by minimizing the noise prediction residual, the strong-augmentation correction loss being defined on the latent representations of the unlabeled samples at the current diffusion step and the previous step after weak pseudo-label optimization; Step 6.4, on the basis of the above sub-losses, constructing the joint optimization objective function of the segmentation network, the discriminator and the diffusion model, while introducing the hard-soft collaborative prototype contrastive loss to enhance the intra-class compactness and inter-class separability of the feature space.
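The structure of the joint objective in Steps 6.1-6.4 can be sketched as a weighted sum of the sub-losses. The overlap term of Step 6.1 is shown as a standard soft Dice loss; the weight values are hypothetical placeholders, since the patent only states that each loss carries its own weight coefficient:

```python
import numpy as np

def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss between a probability map and a binary target,
    standing in for the overlap-consistency term of Step 6.1."""
    inter = (pred * target).sum()
    return 1.0 - (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)

def joint_objective(l_sup, l_cons, l_disc, l_diff, l_proto,
                    w_cons=1.0, w_disc=0.1, w_diff=1.0, w_proto=0.1):
    """Step 6.4 sketch: weighted sum of the supervised, consistency,
    discriminator, diffusion-correction and hard-soft prototype contrast
    losses. The weight defaults here are assumptions, not patent values."""
    return (l_sup + w_cons * l_cons + w_disc * l_disc
            + w_diff * l_diff + w_proto * l_proto)

perfect = dice_loss(np.ones((2, 2)), np.ones((2, 2)))   # full overlap
disjoint = dice_loss(np.ones((2, 2)), np.zeros((2, 2)))  # no overlap
total = joint_objective(1.0, 1.0, 1.0, 1.0, 1.0,
                        w_cons=1.0, w_disc=1.0, w_diff=1.0, w_proto=1.0)
```

Because all five terms feed one scalar objective, the segmentation network, discriminator and diffusion model are updated jointly in each iteration of Steps 3-5, which is what lets pseudo-label quality and feature discriminability improve together rather than in separate training stages.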
Description
Semi-supervised medical image segmentation method based on diffusion-driven hard-soft prototype contrast learning
Technical Field
The invention belongs to the technical field of medical image processing and computer-aided diagnosis, and particularly relates to a semi-supervised medical image segmentation method based on diffusion-driven hard-soft prototype contrast learning.
Background
Medical image segmentation is a fundamental step in computer-aided medical analysis; it aims to divide a medical image into anatomical regions or regions of interest with well-defined semantics. Accurate organ and tissue segmentation provides reliable quantitative indices and structural information for clinical diagnosis, and is an important precondition for clinical applications such as early disease screening, lesion evaluation and surgical planning. In recent years, with the development of deep learning, fully-supervised medical image segmentation methods based on convolutional neural networks (CNNs) and the Transformer architecture have achieved good segmentation results in various medical imaging tasks. However, such methods generally rely on a large amount of high-quality pixel-level annotation data; the fine annotation of medical images is not only costly and time-consuming but also highly dependent on the experience of domain experts, and the annotation difficulty rises markedly for complex anatomical structures or multi-organ scenes, which limits the clinical adoption of fully-supervised medical image segmentation. To reduce the dependence on large-scale annotated data, semi-supervised medical image segmentation has attracted wide attention. It trains models by combining a small amount of labeled data with a large amount of unlabeled data, mainly along two technical routes: pseudo-label learning and consistency regularization.
The consistency regularization approach is based on the assumption that the model's predictions under different perturbations or augmentations should remain consistent, and improves generalization by constraining multi-view prediction consistency. Although these methods alleviate the shortage of annotated data to some extent, two interrelated key technical bottlenecks remain in practical application and limit segmentation performance. 1. Pseudo-label quality is insufficient and an effective adaptive correction mechanism is lacking. Because the initial training stage relies only on limited annotated samples, the pseudo-labels the model generates on unlabeled data often contain considerable noise, manifesting as blurred segmentation boundaries, incomplete structures or incorrect anatomical relations. The prior art screens pseudo-labels with a fixed confidence threshold or with uncertainty estimation, but such strategies are usually static or global, can hardly describe reliability differences at the pixel or local-structure level, and adapt poorly to complex anatomical regions. Low-quality pseudo-labels are easily and continuously amplified during iterative training, weakening the constraining effect of consistency regularization in low-confidence regions. Although some schemes attempt to introduce a diffusion model to generate or optimize pseudo-labels, for lack of a pixel-level, dynamic confidence-guidance mechanism the denoising process still tends to over-smooth reliable regions or under-correct noisy regions, making it difficult to balance boundary precision with overall structural consistency. 2.
Feature discriminability degrades and the feature-space distribution becomes chaotic. Most existing semi-supervised medical image segmentation methods focus on supervision at the label level but pay insufficient attention to the semantic distribution structure of the feature space. Introducing noisy pseudo-labels destroys the discriminability of the feature representation, resulting in loose intra-class feature distributions and blurred inter-class boundaries. Although some methods attempt to enhance feature discrimination through prototype contrastive learning, most rely on static prototype construction and adopt a single hard-assignment or simple clustering strategy when modeling the association between pixels and prototypes, lacking a mechanism for differentiated constraints on high- and low-confidence regions and thus struggling to adapt to complex feature distributions. The degraded feature discriminability in turn further reduces the quality of the pseudo-labels.