Complete ResNet Code

This article walks through the ResNet (residual network) code and uses it to explain how residual networks are structured.
Contents

Conv3×3
Conv1×1
Code walkthrough
Complete code
You can call the official implementation that ships with torchvision directly:
```python
from torchvision.models import resnet50

model = resnet50()
print("model:", model)
```
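Whichever builder you call (resnet18, resnet34, resnet50, resnet101 or resnet152), each one constructs the same ResNet class, so that class is the only one we need to read.
The ResNet class is in turn built from Bottleneck (the bottleneck block), 3×3 convolution layers, 1×1 convolution layers, and BasicBlock. The sections below explain each of them in turn.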

拼劲层这个类在及之后的系列用这个,、用
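Parameters:
expansion = 4: the block outputs planes × 4 channels
inplanes: number of input channels
planes: base channel count; the block outputs planes × expansion channels, and its middle layers use width channels (equal to planes by default)
stride: stride of the 3×3 convolution (conv2)
downsample: optional module applied to the shortcut so that the identity matches the output shape
groups: number of groups for grouped convolution
base_width: base width used to compute the middle-layer width, width = int(planes × base_width / 64) × groups
dilation: dilation rate for dilated (atrous) convolution
norm_layer: normalization layer to use; nn.BatchNorm2d is used if None is passed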
```python
import torch.nn as nn

# conv1x1 and conv3x3 are the helper functions defined in the Conv1×1 and Conv3×3 sections below.


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution (self.conv2)
    # while original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # If the input and output shapes differ (channel count, or spatial size when
        # stride != 1), project the identity with downsample (a 1x1 conv + BN) before
        # the addition; otherwise the input is added directly.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        # With downsample:
        # x -->conv_1x1-->bn-->relu-->conv_3x3-->bn-->relu-->conv_1x1-->bn--add-->relu-->out
        # |___________downsample____________________________________________|
        # Without downsample:
        # x -->conv_1x1-->bn-->relu-->conv_3x3-->bn-->relu-->conv_1x1-->bn--add-->relu-->out
        # |__________________________________________________________________|
        return out
```
The figure below shows the structure of a Bottleneck block whose shortcut (residual) branch is a 1×1 convolution.
Conv3×3
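Parameters:
in_planes: number of input channels
out_planes: number of output channels
stride: stride
groups: number of convolution groups
dilation: dilation rate, which controls dilated (atrous) convolution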

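As you can see, this conv3x3 uses kernel_size=3 and bias=False, and its padding is set equal to dilation. A dilated 3×3 kernel spans 2 × dilation + 1 pixels, so with stride 1 this padding keeps the spatial size unchanged.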
```python
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)
```
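The padding = dilation choice can be checked numerically. The sketch repeats the conv3x3 definition so it runs on its own and compares real output sizes with the standard Conv2d size formula, floor((H + 2·padding − dilation·(kernel_size − 1) − 1) / stride) + 1. conv_out_size is a hypothetical helper written for this check. The last lines show how groups shrinks the weight tensor.

```python
import torch
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding (same as above)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv_out_size(size, kernel_size, stride, padding, dilation):
    # Hypothetical helper: Conv2d output size along one spatial dimension.
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


x = torch.randn(1, 64, 56, 56)
sizes = {}
for stride, dilation in [(1, 1), (1, 2), (2, 1)]:
    y = conv3x3(64, 64, stride=stride, dilation=dilation)(x)
    assert y.shape[-1] == conv_out_size(56, 3, stride, dilation, dilation)
    sizes[(stride, dilation)] = y.shape[-1]
print(sizes)  # {(1, 1): 56, (1, 2): 56, (2, 1): 28}

# With groups=32, each output channel only sees 64 / 32 = 2 input channels.
grouped = conv3x3(64, 64, groups=32)
print(tuple(grouped.weight.shape))  # (64, 2, 3, 3)
```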
Conv1×1
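in_planes: number of input channels
out_planes: number of output channels
stride: stride (1 by default)
As you can see, kernel_size is 1 and bias is False.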
```python
import torch.nn as nn


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
```
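A 1×1 convolution is a per-pixel linear map over channels, which is why ResNet uses it to shrink or expand the channel count cheaply. The sketch repeats the conv1x1 definition and checks that equivalence against torch.einsum; the tensor sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (same as above)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


conv = conv1x1(64, 256)
x = torch.randn(2, 64, 8, 8)
with torch.no_grad():
    y = conv(x)
    # Drop the two trailing 1x1 kernel dimensions: the weight becomes a (256, 64) matrix.
    w = conv.weight[:, :, 0, 0]
    # At every (n, h, w) position, output channels = W @ input channels.
    y_ref = torch.einsum("oc,nchw->nohw", w, x)
    strided = conv1x1(64, 256, stride=2)(x)

match = torch.allclose(y, y_ref, atol=1e-5)
n_weights = conv.weight.numel()
print(match)                 # True
print(n_weights)             # 16384 = 64 * 256, no bias term
print(tuple(strided.shape))  # (2, 256, 4, 4)
```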
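BasicBlock
BasicBlock is the block used by ResNet-18 and ResNet-34.
Parameters:
inplanes: number of input channels
planes: number of output channels
stride: stride
downsample: downsampling module applied to the shortcut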