Network Construction Code Implementation Using U-Net as an Example (Part 2)


In the code above, the U-Net model is divided into convolutional blocks (downsampling), transposed convolutions (upsampling), pooling layers, a bottleneck layer, and a final convolutional layer.
For the downsampling stage, the blocks are held in an nn.ModuleList(); the input and output channels of each convolution are determined first, and then a loop builds the blocks:
```python
features = [64, 128, 256, 512]
self.downs = nn.ModuleList()
for feature in features:
    self.downs.append(DoubleConv(in_channels, feature))
    in_channels = feature
```
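A minimal sketch of how this loop builds the encoder. The `DoubleConv` here is a hypothetical stand-in for the double-convolution block defined elsewhere in the tutorial (assumed to be two 3x3 convolutions with ReLU):

```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Hypothetical stand-in: two 3x3 convolutions, each followed by ReLU."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

# Build the encoder exactly as in the loop above.
in_channels = 3
features = [64, 128, 256, 512]
downs = nn.ModuleList()
for feature in features:
    downs.append(DoubleConv(in_channels, feature))
    in_channels = feature  # the next block consumes this block's output

# Channel progression: 3 -> 64 -> 128 -> 256 -> 512.
x = torch.randn(1, 3, 64, 64)
for down in downs:
    x = down(x)
print(x.shape)
```

Reassigning `in_channels = feature` at the end of each iteration is what chains the blocks: each block's output channel count becomes the next block's input channel count.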
```python
def forward(self, x):
    skip_connections = []
    for down in self.downs:
        x = down(x)
        skip_connections.append(x)
        x = self.pool(x)
    x = self.bottleneck(x)
    skip_connections = skip_connections[::-1]
    for idx in range(0, len(self.ups), 2):
        x = self.ups[idx](x)
        skip_connection = skip_connections[idx // 2]
        if x.shape != skip_connection.shape:
            x = TF.resize(x, size=skip_connection.shape[2:])
        concat_skip = torch.cat((skip_connection, x), dim=1)
        x = self.ups[idx + 1](concat_skip)
    return self.final_conv(x)
```
In the forward pass, note that every level of the U-Net carries a skip connection:
`skip_connections = []`: the output of each convolutional block is saved to this list so it can be concatenated in during upsampling.
`skip_connections = skip_connections[::-1]`: the features are saved in encoder order but consumed in the opposite order during decoding, so the list must be reversed.
`concat_skip = torch.cat((skip_connection, x), dim=1)`: the saved feature map and the upsampled feature map are concatenated along the channel dimension.
Some practical operations
I think the main reason our code often looks messy is that we fail to package each function and operation into its own unit. Here is a concrete example.
```python
def save_checkpoint(state, filename='my_checkpoint.pth.tar'):
    print('=> Saving checkpoint')
    torch.save(state, filename)
```
A function that saves the training state to disk; see the official documentation for torch.save() for details.
```python
def load_checkpoint(checkpoint, model):
    print('=> Loading checkpoint')
    model.load_state_dict(checkpoint['state_dict'])
```

Loading a checkpoint lets you resume training a model from where a previous run left off.
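A round-trip sketch using the two helpers above, with a toy nn.Linear standing in for the U-Net (the 'state_dict' key matches what load_checkpoint reads):

```python
import os
import tempfile
import torch
import torch.nn as nn

def save_checkpoint(state, filename='my_checkpoint.pth.tar'):
    print('=> Saving checkpoint')
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    print('=> Loading checkpoint')
    model.load_state_dict(checkpoint['state_dict'])

model = nn.Linear(4, 2)
path = os.path.join(tempfile.gettempdir(), 'demo_checkpoint.pth.tar')
save_checkpoint({'state_dict': model.state_dict()}, filename=path)

# A freshly constructed model starts with different random weights;
# loading the checkpoint restores the saved ones.
restored = nn.Linear(4, 2)
load_checkpoint(torch.load(path), restored)
```

In practice the checkpoint dict usually also carries the optimizer state (e.g. an 'optimizer' key) so that resumed training continues with the same momentum and learning-rate state.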
```python
def get_loader(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=1,
    pin_memory=True,
):
    train_ds = CarvanaDataset(
        image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )
    val_ds = CarvanaDataset(
        image_dir=val_dir, mask_dir=val_maskdir, transform=val_transform
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )
    return train_loader, val_loader
```
A common data-loading helper. CarvanaDataset is a custom Dataset class; a built-in dataset class could be used instead. The DataLoader() parameters used here are visible in the calls above: batch_size, num_workers, pin_memory, and shuffle.
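A minimal sketch of the same DataLoader pattern, with a built-in TensorDataset standing in for the custom CarvanaDataset (all shapes and values here are made up for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten fake image/mask pairs stand in for the Carvana data.
images = torch.randn(10, 3, 8, 8)
masks = torch.randint(0, 2, (10, 8, 8)).float()
ds = TensorDataset(images, masks)

loader = DataLoader(
    ds,
    batch_size=4,      # samples per batch
    num_workers=0,     # 0 = load in the main process (portable default)
    pin_memory=False,  # set True when training on a GPU
    shuffle=True,      # shuffle for training; False for validation
)

# Ten samples with batch_size=4 yield batches of 4, 4, and 2.
for batch_imgs, batch_masks in loader:
    print(batch_imgs.shape)
```

pin_memory=True keeps batches in page-locked host memory, which speeds up host-to-GPU transfers; it buys nothing on CPU-only runs, which is why it is a parameter rather than hard-coded.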
Training the model
Setting the hyperparameters:
```python
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKER = 2
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "data/train/"
TRAIN_MASK_DIR = "data/train_masks/"
VAL_IMG_DIR = "data/val/"
VAL_MASK_DIR = "data/val_masks/"
```
The training function train_fn():
```python
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward: mixed-precision training
        with torch.cuda.amp.autocast():
            preds = model(data)
            loss = loss_fn(preds, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())
```
loop = tqdm(loader): tqdm can be understood simply as a fast, extensible progress bar.
loop.set_postfix() sets the content displayed at the end of the progress bar.
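The body of the loop above can be exercised end-to-end on CPU; a sketch with a toy 1x1-conv model, a hypothetical BCEWithLogitsLoss (not specified in this part of the tutorial), tqdm omitted, and the scaler/autocast disabled when no GPU is present:

```python
import torch
import torch.nn as nn

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
USE_AMP = torch.cuda.is_available()

model = nn.Conv2d(3, 1, kernel_size=1).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.BCEWithLogitsLoss()  # hypothetical choice for binary masks
# GradScaler becomes a no-op when disabled, so the same code runs on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

# One fake batch: 2 images and 2 binary masks.
data = torch.randn(2, 3, 8, 8, device=DEVICE)
targets = torch.randint(0, 2, (2, 8, 8), device=DEVICE).float().unsqueeze(1)

# forward under autocast (fp16 on GPU, plain fp32 when disabled)
with torch.cuda.amp.autocast(enabled=USE_AMP):
    preds = model(data)
    loss = loss_fn(preds, targets)

# backward: scale the loss, step the optimizer, then update the scale factor
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

print(loss.item())
```

Scaling the loss before backward keeps small fp16 gradients from underflowing to zero; scaler.step() unscales them before the optimizer update, and scaler.update() adjusts the scale factor for the next iteration.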