Unet图像分割在PyTorch上的实现

Unet图像分割在PyTorch上的实现

Unet是一个最近比较火的网络结构。它的理论已经有很多大佬在讨论了。本文主要从实际操作的层面,讲解如何使用pytorch实现unet图像分割。

通常我会在粗略了解某种方法之后,就进行实际操作。在操作过程中,也许会遇到一些疑问,再回过头去仔细研究某个理论。这样的学习方法,是我比较喜欢的方式。这也是fast.ai推崇的自上而下的学习方式。

本文将先简单介绍Unet的理论基础,然后使用pytorch一步一步地实现Unet图像分割。因为主要目的是提供一个baseline模型给大家,所以代码主要关注在如何构造Unet的网络结构。

当你学会了如何用代码实现Unet,我相信你对Unet的理解已经比较深刻了。

本文完整的代码:github.com/Qiuyan918/Un


Unet

图1: Unet的网络结构

Unet主要用于图像分割问题。图1是Unet论文中的网络结构图。可以看出Unet是一个对称的结构,左半边是Encoder,右半边是Decoder。图像会先经过Encoder处理,再经过Decoder处理,最终实现图像分割。它们分别的作用如下:

  • Encoder:使得模型理解了图像的内容,但是丢弃了图像的位置信息。
  • Decoder:使模型结合Encoder对图像内容的理解,恢复图像的位置信息。

Encoder的部分和传统的网络结构类似,可以选择图中的结构,也可以选择VGG,ResNet等。随着卷积层的加深,特征图的长宽减小,通道增加。虽然Encoder提取了图像的高级特征,但是丢弃了图像的位置信息。所以在图像识别问题中,模型只需要Encoder的部分。因为图像识别不需要位置信息,只需要提取图像的内容信息。

Decoder的部分是Unet的重点。Decoder中涉及upconvolution这个概念。关于upconvolution,这里不做详细介绍,简单来说就是convolution的反向运算。Decoder的每一层都通过upconvolution(图中绿色箭头),并且和Encoder相对应的初级特征结合(图中的灰色箭头),逐渐恢复图像的位置信息。在Decoder中,随着卷积层的加深,特征图的长宽增大,通道减少。


数据:

图2: Kaggle盐体分割比赛

本文用到的数据来源于Kaggle盐体分割比赛。这次比赛的问题是一个非常典型的图像分割问题。比赛中的大佬们基本上都用的Unet。

我们的目标就是将图片中的盐体找出来。盐体有一些我不太懂的经济价值,反正是很有意义的

编辑于 2022-09-12 20:45