以U-NET为例的网络构建代码实现

写在前面
最近在读U-Net论文时,网上看到从零构建网络模型的代码 。代码足够间接,而且结构比较完整,因此记录一下学习结果 。
本文重点在于如何代码的实现,对于U-Net论文中的细节未涉略,关于论文的讨论可移步 。
学习的资源链接在文章末尾 。
U-net模型
首先对于模型有一个简单的认识:
对于U-net模型的构建,主要的在于卷积层和转置卷积(下采样和上采样)的实现,以及如何实现镜像对应部分的连接 。请各位读者理解U-net模型,并且牢记每一步的通道数 。
代码实现
按照工业上或者竞赛上常见的解决问题的步骤,主要包括数据集的获取、模型的构建、模型的训练(损失函数的选择、模型的优化)、训练结果的验证 。因此接下来将从这几方面对代码进行解读 。
【以U-NET为例的网络构建代码实现】数据集的获取
数据集网址: Image|
百度网链接: 提取码:4t3y
其中需要读者根据自己的需求先训练集中分出部分的数据用作验证集 。
数据集的读取
import osfrom PIL import Imagefrom torch.utils.data import Datasetimport numpy as np
class CarvanaDataset(Dataset):def __init__(self, image_dir, mask_dir, transform=None):self.image_dir = image_dirself.mask_dir = mask_dirself.transform = transformself.images = os.listdir(image_dir)def __len__(self):return len(self.images)def __getitem__(self, idx):image_path = os.path.join(self.image_dir, self.images[idx])mask_path = os.path.join(self.mask_dir, self.images[idx].replace('.jpg', '_mask.gif'))image = np.array(Image.open(image_path).convert('RGB'))mask = np.array(Image.open(mask_path).convert('L'), dtype=np.float32)mask[mask == 255.0] = 1.0if self.transform is not None:augmentations = self.transform(image=image, mask=mask)image = augmentations['image']mask = augmentations['mask']return image, mask
为了后续操作比较方便,直接继承,然后返回image和对应的mask 。
os.(path) 返回指定路径下的文件(文件夹),在上面代码中,返回整个训练集图片对应的列表 。
os.path.join() 该操作直接获得每一张图片对应的储存路径
image.open().(), 该函数将图片按照指定的模式转变图片,例如RGB图像,或者灰度图像 。(具体的官方释义我还没找到,如果有官网的解释,请赐教)

以U-NET为例的网络构建代码实现

文章插图
mask[mask==255.0] = 1.0 方便后续的()函数的计算?(存疑)
模型的构建
首先观察U-net模型的构建,在pool层之前,总会有有两次卷积,将原图片的通道数增加 。
因此首先建立类 。
import torchimport torch.nn as nnimport torchvision.transforms.functional as TFclass DoubleConv(nn.Module):def __init__(self,in_channels,out_channels):super(DoubleConv,self).__init__()self.conv=nn.Sequential(nn.Conv2d(in_channels,out_channels,3,1,1,bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels,out_channels,3,1,1,bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),)def forward(self,x):return self.conv(x)
下一步,观察U-net模型,由于具有对称性显得格外的优雅,而且每一步的处理显得很有规律,正是因为有这样的规律,因此我们在写代码的时候可以不那么繁琐,重复的卷积-池化-卷积-池化 。
class UNET(nn.Module):def __init__(self, in_channels=3, out_channels=1, features=[64,128,256,512]):super(UNET,self).__init__()self.ups = nn.ModuleList()self.downs = nn.ModuleList()self.pool = nn.MaxPool2d(kernel_size=2,stride=2)for feature in features:self.downs.append(DoubleConv(in_channels,feature))in_channels = featurefor feature in reversed(features):self.ups.append(nn.ConvTranspose2d(feature*2,feature,kernel_size=2,stride=2))self.ups.append(DoubleConv(feature*2, feature))self.bottleneck = DoubleConv(features[-1], features[-1]*2)self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)