PyTorch Image Segmentation Models: segmentation_models_pytorch (Part 2)


Teaching someone to fish beats giving them a fish: as if releasing nine model architectures were not enough, the library's Russian author also thoughtfully provides a worked example of training on a real dataset (see the cars segmentation (camvid).ipynb notebook under examples/ in the repository).
The CamVid dataset is a widely used computer-vision dataset with 12 annotated classes such as car, building, pavement, sky, and tree, and is typically used for street-scene segmentation. For convenience, we take a ready-made, pre-annotated copy of the dataset from GitHub for training and testing; it can be downloaded from the SegNet-Tutorial repository (https://github.com/alexgkendall/SegNet-Tutorial), which the training script below clones automatically.
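After cloning, the data sits under ./data/CamVid/ as paired image/annotation folders whose file names match one-to-one. The train/val folder names below are the ones used by the training script; the test split also ships with the repository:

./data/CamVid/
    train/        training images (PNG)
    trainannot/   training masks, same file names as train/
    val/          validation images
    valannot/     validation masks
    test/         test images
    testannot/    test masks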
Pay particular attention to the format of the training data: both the raw images and the annotation masks must be PNG files. PNG compression is lossless, so it shrinks file size without altering any pixel values. This matters especially for the masks, whose pixel values encode class indices and would be corrupted by the artifacts of a lossy format such as JPEG. The CamVid data downloaded above is already entirely in PNG format.
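As a quick sanity check, you can confirm that a mask PNG decodes to clean integer class indices. A minimal sketch (it simply picks the first file in trainannot/, so run it after the data has been downloaded):

import os
import cv2
import numpy as np

masks_dir = './data/CamVid/trainannot'
first_mask = os.path.join(masks_dir, sorted(os.listdir(masks_dir))[0])

# read the label mask as a single-channel (grayscale) image
mask = cv2.imread(first_mask, 0)

# every pixel should be an integer class index in [0, 11] for the 12
# CamVid classes; lossy compression (e.g. JPEG) would smear these values
# into arbitrary intermediate intensities and break the class lookup
print(np.unique(mask))  # expected: a subset of 0..11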
segmentation_models_pytorch provides nine model architectures for binary and multi-class segmentation. Here we use only UNet++, with the se_resnext50_32x4d encoder pretrained on ImageNet as the backbone.
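Before the full training script, here is a minimal sketch of what instantiating this model looks like and what it expects. The 320x320 dummy input is just an illustrative choice; the key constraint is that spatial dimensions fed to the encoder must be divisible by 32:

import torch
import segmentation_models_pytorch as smp

# UNet++ with a pretrained SE-ResNeXt50 encoder; a single output channel
# because we will segment one class ('car') with a sigmoid activation
model = smp.UnetPlusPlus(
    encoder_name='se_resnext50_32x4d',
    encoder_weights='imagenet',
    classes=1,
    activation='sigmoid',
)

# dummy batch: 2 RGB images of 320x320 (divisible by 32, matching the
# encoder's five 2x downsampling stages)
x = torch.randn(2, 3, 320, 320)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([2, 1, 320, 320])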
The code to train the UNet++ model is shown below:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2
import matplotlib.pyplot as plt
import albumentations as albu
import torch
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset


# ---------------------------------------------------------------
### Load data

class Dataset(BaseDataset):
    """CamVid dataset. Reads images, applies augmentation and preprocessing.

    Args:
        images_dir (str): path to the folder of images
        masks_dir (str): path to the folder of segmentation masks
        classes (list): names of the classes to extract from the masks
        augmentation (albumentations.Compose): augmentation pipeline
        preprocessing (albumentations.Compose): preprocessing pipeline
    """

    # all label classes used for segmentation in the CamVid dataset
    CLASSES = ['sky', 'building', 'pole', 'road', 'pavement',
               'tree', 'signsymbol', 'fence', 'car',
               'pedestrian', 'bicyclist', 'unlabelled']

    def __init__(
        self,
        images_dir,
        masks_dir,
        classes=None,
        augmentation=None,
        preprocessing=None,
    ):
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        # read data
        image = cv2.imread(self.images_fps[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks_fps[i], 0)

        # extract the requested classes from the mask (e.g. cars)
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.ids)


# ---------------------------------------------------------------
### Image augmentation

def get_training_augmentation():
    # note: the IAA* transforms require albumentations < 1.0; newer
    # versions replace them with GaussNoise, Perspective and Sharpen
    train_transform = [
        albu.HorizontalFlip(p=0.5),
        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),
        albu.PadIfNeeded(min_height=320, min_width=320, always_apply=True, border_mode=0),
        albu.RandomCrop(height=320, width=320, always_apply=True),
        albu.IAAAdditiveGaussianNoise(p=0.2),
        albu.IAAPerspective(p=0.5),
        albu.OneOf(
            [
                albu.CLAHE(p=1),
                albu.RandomBrightness(p=1),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),
        albu.OneOf(
            [
                albu.IAASharpen(p=1),
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),
        albu.OneOf(
            [
                albu.RandomContrast(p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Pad the image so that its height and width are divisible by 32"""
    test_transform = [albu.PadIfNeeded(384, 480)]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing transform

    Args:
        preprocessing_fn (callable): normalization function
            (specific to each pretrained network)
    Return:
        transform: albumentations.Compose
    """
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)


# ---------------------------------------------------------------
### Create the model and train

if __name__ == '__main__':
    # directory holding the dataset
    DATA_DIR = './data/CamVid/'

    # clone the CamVid dataset if it is not present yet
    if not os.path.exists(DATA_DIR):
        print('Loading data...')
        os.system('git clone https://github.com/alexgkendall/SegNet-Tutorial ./data')
        print('Done!')

    # training set
    x_train_dir = os.path.join(DATA_DIR, 'train')
    y_train_dir = os.path.join(DATA_DIR, 'trainannot')

    # validation set
    x_valid_dir = os.path.join(DATA_DIR, 'val')
    y_valid_dir = os.path.join(DATA_DIR, 'valannot')

    ENCODER = 'se_resnext50_32x4d'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = ['car']
    ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation
    DEVICE = 'cuda'

    # build the segmentation model with a pretrained encoder
    # FPN variant:
    # model = smp.FPN(
    #     encoder_name=ENCODER,
    #     encoder_weights=ENCODER_WEIGHTS,
    #     classes=len(CLASSES),
    #     activation=ACTIVATION,
    # )
    # UNet++ variant:
    model = smp.UnetPlusPlus(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=len(CLASSES),
        activation=ACTIVATION,
    )

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    # training dataset
    train_dataset = Dataset(
        x_train_dir,
        y_train_dir,
        augmentation=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
    )

    # validation dataset
    valid_dataset = Dataset(
        x_valid_dir,
        y_valid_dir,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
    )

    # Tune these for your GPU: batch_size is the number of images per
    # training step, num_workers the number of data-loading processes.
    # If GPU memory is tight, lower batch_size and set num_workers to 0.
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=0)

    loss = smp.utils.losses.DiceLoss()
    metrics = [
        smp.utils.metrics.IoU(threshold=0.5),
    ]
    optimizer = torch.optim.Adam([
        dict(params=model.parameters(), lr=0.0001),
    ])

    # simple loop runners for iterating over the data samples
    train_epoch = smp.utils.train.TrainEpoch(
        model,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )
    valid_epoch = smp.utils.train.ValidEpoch(
        model,
        loss=loss,
        metrics=metrics,
        device=DEVICE,
        verbose=True,
    )

    # train the model for 40 epochs
    max_score = 0
    for i in range(0, 40):
        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)

        # keep the best model seen so far (the source snippet is cut off
        # at this point; this completion follows the upstream
        # segmentation_models_pytorch CamVid example)
        if max_score < valid_logs['iou_score']:
            max_score = valid_logs['iou_score']
            torch.save(model, './best_model.pth')
            print('Model saved!')

        if i == 25:
            optimizer.param_groups[0]['lr'] = 1e-5
            print('Decrease decoder learning rate to 1e-5!')
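Once training finishes, the best checkpoint can be loaded back for inference. A minimal sketch, assuming it runs in the same script (it reuses Dataset, get_validation_augmentation, get_preprocessing, preprocessing_fn, DATA_DIR, DEVICE and CLASSES from above) and that the test/ and testannot/ folders exist as in the SegNet-Tutorial repository:

# test set, prepared exactly like the validation set
x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'testannot')

test_dataset = Dataset(
    x_test_dir,
    y_test_dir,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

# torch.save(model, ...) stored the whole module, so torch.load
# restores it directly
best_model = torch.load('./best_model.pth')

image, gt_mask = test_dataset[0]
x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
pr_mask = best_model.predict(x_tensor)              # eval mode + no_grad
pr_mask = pr_mask.squeeze().cpu().numpy().round()   # binary car mask

Because the training loop saved the entire model object rather than a state_dict, torch.load returns a ready-to-use module; smp models also expose a predict() helper that switches to eval mode and disables gradients for you.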