3D图像分割:Thickened 2D Networks for 3D Medical Image Segmentation

Thickened 2D Networks for 3D Medical Image Segmentation

https://arxiv.org/pdf/1904.01150.pdf

 

本文研究的是血管分割(blood vessels),使用的是3D CT图像。将深度学习用于3D image segmentation,具体到医学图像分割中是三维空间的二分类问题。即确定3D voxel体素是属于前景(foreground)还是属于背景(background)。将深度学习用到三维医学图像分割主要有如下两种思路:(1)将3D图像切成一个个的2D切片slices,然后对于每个2D slices使用自然图像/计算机视觉中常用的分割模型,最后将对于每个2D slices分割后的结果拼接成3D形式。2D切片的解决方案充分利用了整张2D切片中的信息,但是却忽略了几张相邻切片图像之间的关系,损失了全局的slices序列之间的关联性(2)使用3D CNN直接分割3D图像,由于3D CNN网络模型的计算量本身就很大,直接在原始高分辨率3D块上使用3D CNN很可能导致显存不够,故而需要将3D 原始分辨率切成一个个的3D patches,再在每个patches上应用3D CNN。优点是能够应用全局切片序列信息,缺点是缺少3D CNN的预训练模型参数。

本文所提出的模型是在使用2D CNN分割CT图像slices的模型上进行改进。对于3D CT扫描图像进行分割的问题描述为:输入3D volume X,shape=H*W*L,L表示切片厚度,它的ground truth label是shape=H*W*L的binary segmentation map,其中每个体素点的取值为1或者0,1表示当前体素点是前景,0表示背景。将y记作3D块的ground truth segmentation labels,z记作3D块的prediction segmentation map。

一、原始的2D CNN分割3D图像baseline模型

将shape=H*W*L的3D体素输入切片成多个2D slices,具体会在3个坐标轴上都进行切分,将在H轴,W轴和L轴上的2D片集合分别记作X(S,h),X(C,w),X(A,l)。在每个轴上分别应用2D CNN。假设L轴为例,与L轴相对应的会有一套分割网络权值参数,将原始的3D块切分成L个H*W的2D切片,每个2D切片都会作为训练数据送入L轴对应的分割网络中训练,最终对于H,W,L三个坐标轴将会训练出3套模型参数,但是三个坐标轴使用的是相同的模型结构。测试时,将测试的3D CT块分别沿H,W和L轴切割成三个部分,每个部分(2D图像序列)分别送入当前轴所对应的网络中,再将在当前轴下的所有分割结果concatenate,最终得到当前轴的模型参数下,对于3D切片所预测的输出结果,将在3个坐标轴权值参数下得到的3个分割结果每个体素处的前景概率值取平均,作为最终的预测结果。很显然,这样只是将3D扫描图像切分成了一个个2D切片而已,并没有考虑到slices之间的instra-slice information。

二、使用k个相邻slices预测出k个分割结果

在baseline模型的基础上,为了更好地利用2D切片之间的关联信息,(3D分割任务想要获得良好的性能,就要充分利用切片内部inter-slice information和切片之间intra-slice information),以一个坐标轴为例,得到在当前坐标轴上的每个2D切片之后,并不是将每个2D切片单独送入2D CNN segmentation network中,而是将当前切片与当前切片之后的k个切片都送入网络中(baseline模型中网络的输入通道数为1,而现在改进的模型中网络的输入通道数为k),这样对于L坐标轴下的L个切片,将会得到L个H*W*k张预测的segmentation map,其中会有重叠,在overlap的区域同样取平均值,然后再将L个切片的预测结果concatenate,则对于L,H,W轴将会得到3个3D CT scan的预测分割图(图中的每个体素点值表示体素点属于前景的概率),最终将3个预测图进行element wise 平均值,得到最终预测图。这样对于2D CNN分割网络输入的是k*H*W的图像,输出的预测也是k*H*W的分割概率图,以训练分割网络。但是这样也会带来信息损失,即输出预测概率图的某个slice通道应该更关注于跟它相关联的那个输入通道,也可以表现为预测某个slice的输出将会更关注于特征图的某个通道。在这里并没有加入这样的信息(slice sensitive information)

三、本文提出的模型:mini-group和channel wise attention

1. mini-group和postpone fusion

在(二)中的模型是将相邻slices的图像特征在input layer时就进行了融合,这样将不同slices的特征很早就混淆,则网络在预测某个slice时并不能有效利用与当前slices最为相关的那些特征,为了使得每个slice单独的特征能够被保留,则希望能够让不同的slice分别经过CNN网络提取特征之后,将所有slice的特征映射到相同的特征空间之后,在进行特征融合,而不是像(二),输入图像包含k个slice,则第一个卷积层的input channel=k,卷积操作实际上是将滑动窗口所对应的空间位置上的每个局部区域所有通道上的响应值加权求和,故而不同slice图像在第一个卷积层就进行了融合,这将导致在后续的操作中丢失slice sensitive information。对于切片l,获取l以及l后面的总共k个H*W的2D slices,然后将这k个slices分别送入2D CNN segmentation network,得到k个256*H*W的特征图,然后进行融合输出256*H*W的特征图(这个特征图将是对于预测第l个slice有利的信息),将融合后的特征图送入prediction layer输出对于第l个slice切片预测的segmentation prediction probability。然后L个切片的预测结果进行concatenate(这时候因为没有overlap,故而不用求平均值),得到H*W*L预测概率图,然后在3个坐标维度上的预测概率图进行平均,得到最终输出。

2.channel wise non-local attention module

(1)中只是将融合之后的特征图作为最终的预测结果,但是对于不同的slice切片,所需要的slice sensitive information不同,故而为了预测出每个slice的分割概率图,还需要针对于特定的切片,产生不同的融合信息,也就是说,可以将每个slice的特征图,k个融合之后的特征图进行attention操作,attention输出的特征图才是对于预测当前slice最为有用的信息,以输出对于当前切片的预测。

    将slice i的256*H*W特征图与所有group的输出特征图concatenate的(256*n)*H*W特征图进行attention计算。输出的结果图标用来预测slice i的前景概率图。