PyTorch图像分割模型——segmentation_models_pytorch（三）

< valid_logs['iou_score']:max_score = valid_logs['iou_score']torch.save(model, './best_model.pth')print('Model saved!')if i == 25:optimizer.param_groups[0]['lr'] = 1e-5print('Decrease decoder learning rate to 1e-5!')
这里我们只针对 'car' 这一类别进行图像分割训练，设置好数据集后即可开始训练，训练过程如图所示。
由于数据集规模不大，40 个轮次（epoch）的训练很快就结束了。此时文件夹中会出现一个 .pth 文件（best_model.pth），即 40 个轮次训练过程中验证集 IoU 最高的模型。
3.2 UNet++测试
训练完成后，我们对 UNet++ 训练得到的最佳模型进行测试，测试代码如下：
import os

# Must be set before torch initialises CUDA so only GPU 0 is visible.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2
import matplotlib.pyplot as plt
import albumentations as albu
import torch
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset as BaseDataset


# ---------------------------------------------------------------
### Data loading

class Dataset(BaseDataset):
    """CamVid dataset: reads images and masks, then applies augmentation and preprocessing.

    Args:
        images_dir (str): path to the folder containing the input images
        masks_dir (str): path to the folder containing the segmentation label
            images; each mask must have the same file name as its image
        classes (list): names of the classes to extract from the masks (each
            must be an entry of ``CLASSES``, case-insensitive); ``None``
            selects all classes
        augmentation (albumentations.Compose): data augmentation pipeline
        preprocessing (albumentations.Compose): data preprocessing pipeline

    Each item is an ``(image, mask)`` pair where ``mask`` has one binary
    channel per selected class (channel-last before preprocessing).
    """

    # All label classes in the CamVid dataset. A class's index in this list is
    # the pixel value that marks it in the mask images (see __getitem__).
    CLASSES = ['sky', 'building', 'pole', 'road', 'pavement',
               'tree', 'signsymbol', 'fence', 'car',
               'pedestrian', 'bicyclist', 'unlabelled']

    def __init__(
            self,
            images_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
    ):
        # Sort so the index -> file mapping is deterministic; os.listdir order
        # is arbitrary and could differ between two Dataset instances.
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

        # The default used to crash ([... for cls in None]); treat None as "all classes".
        if classes is None:
            classes = self.CLASSES
        # Convert class names to the pixel values used in the masks.
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        # Read data. cv2.imread returns None instead of raising on failure.
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise OSError(f"Could not read image file: {self.images_fps[i]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(self.masks_fps[i], 0)  # 0 -> load as single-channel grayscale
        if mask is None:
            raise OSError(f"Could not read mask file: {self.masks_fps[i]}")

        # Extract the selected classes (e.g. cars) as one binary channel each.
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # Apply augmentation
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # Apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.ids)


# ---------------------------------------------------------------
### Augmentation

def get_validation_augmentation():
    """Pad images to at least 384x480 so height and width are divisible by 32.

    Assumes inputs are no larger than 384x480 (CamVid frames are presumably
    360x480); larger inputs would not be cropped.
    """
    test_transform = [
        albu.PadIfNeeded(384, 480)
    ]
    return albu.Compose(test_transform)


def to_tensor(x, **kwargs):
    """Convert an HWC array to a float32 CHW array.

    ``**kwargs`` absorbs the extra arguments albumentations passes to Lambda targets.
    """
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline.

    Args:
        preprocessing_fn (callable): normalization function matching the
            pretrained encoder

    Return:
        transform: albumentations.Compose
    """
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)


def _load_full_model(path, device):
    """Load a whole model saved with ``torch.save(model, path)`` and put it in eval mode.

    SECURITY: this unpickles arbitrary objects; only load checkpoints you trust.
    """
    try:
        # torch >= 2.6 defaults to weights_only=True, which rejects a pickled nn.Module.
        model = torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the weights_only keyword.
        model = torch.load(path, map_location=device)
    return model.to(device).eval()


# Visualization of segmentation results
def visualize(**images):
    """Plot images in one row; each keyword name becomes a subplot title."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()


# ---------------------------------------------------------------
if __name__ == '__main__':

    DATA_DIR = './data/CamVid/'

    # Test set
    x_test_dir = os.path.join(DATA_DIR, 'test')
    y_test_dir = os.path.join(DATA_DIR, 'testannot')

    ENCODER = 'se_resnext50_32x4d'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = ['car']
    ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation
    # Fall back to CPU when no GPU is available (map_location below makes this work).
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    # ---------------------------------------------------------------
    # Evaluate the best trained model

    # Load the best model
    best_model = _load_full_model('./best_model.pth', DEVICE)

    # Create the test dataset
    test_dataset = Dataset(
        x_test_dir,
        y_test_dir,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
    )

    # ---------------------------------------------------------------
    # Visualize segmentation results

    # Same test set without augmentation/preprocessing, for displaying raw images.
    # NOTE: these images are not padded, while gt_mask/pr_mask below went through
    # PadIfNeeded(384, 480), so their heights may differ in the plot.
    test_dataset_vis = Dataset(
        x_test_dir, y_test_dir,
        classes=CLASSES,
    )

    # Pick 3 random test images
    for _ in range(3):
        n = np.random.choice(len(test_dataset))

        image_vis = test_dataset_vis[n][0].astype('uint8')
        image, gt_mask = test_dataset[n]

        gt_mask = gt_mask.squeeze()

        x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
        # predict() presumably comes from segmentation_models_pytorch; no_grad is
        # added defensively so no autograd graph is built during inference.
        with torch.no_grad():
            pr_mask = best_model.predict(x_tensor)
        pr_mask = (pr_mask.squeeze().cpu().numpy().round())

        visualize(
            image=image_vis,
            ground_truth_mask=gt_mask,
            predicted_mask=pr_mask
        )