五 CNN经典网络模型:ResNet简介及代码实现(PyTorch超详细注释版( 四 )


3. .py
import osimport jsonimport torchfrom PIL import Imagefrom torchvision import transformsimport matplotlib.pyplot as pltfrom model import resnet34def main():# 如果有NVIDA显卡,转到GPU训练,否则用CPUdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 将多个transforms的操作整合在一起data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 加载图片img_path = "../tulip.jpg"# 确定图片存在,否则反馈错误assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)# imshow():对图像进行处理并显示其格式,show()则是将imshow()处理后的函数显示出来plt.imshow(img)# [C, H, W],转换图像格式img = data_transform(img)# [N, C, H, W],增加一个维度Nimg = torch.unsqueeze(img, dim=0)# 获取结果类型json_path = './class_indices.json'# 确定路径存在,否则反馈错误assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)# 读取内容with open(json_path, "r") as f:class_indict = json.load(f)# 模型实例化,将模型转到device,结果类型有5种model = resnet34(num_classes=5).to(device)# 载入模型权重weights_path = "./resNet34.pth"# 确定模型存在,否则反馈错误assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)# 加载训练好的模型参数model.load_state_dict(torch.load(weights_path, map_location=device))# 进入验证阶段model.eval()with torch.no_grad():# 预测类别# squeeze():维度压缩,返回一个tensor(张量),其中input中大小为1的所有维都已删除output = torch.squeeze(model(img.to(device))).cpu()# softmax:归一化指数函数,将预测结果输入进行非负性和归一化处理,最后将某一维度值处理为0-1之内的分类概率predict = torch.softmax(output, dim=0)# argmax(input):返回指定维度最大值的序号# .numpy():把tensor转换成numpy的格式predict_cla = torch.argmax(predict).numpy()# 输出的预测值与真实值print_res = "class: {}prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())# 图片标题plt.title(print_res)for i in range(len(predict)):print("class: {:10}prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()
4. .py
import osfrom shutil import copy, rmtreeimport randomdef mk_file(file_path: str):if os.path.exists(file_path):# 如果文件夹存在,则先删除原文件夹再重新创建rmtree(file_path)os.makedirs(file_path)def main():# 保证随机可复现random.seed(0)# 将数据集中10%的数据划分到验证集中split_rate = 0.1# 指向解压后的flower_photos文件夹# getcwd():该函数不需要传递参数,获得当前所运行脚本的路径cwd = os.getcwd()# join():用于拼接文件路径,可以传入多个路径data_root = os.path.join(cwd, "flower_data")origin_flower_path = os.path.join(data_root, "flower_photos")# 确定路径存在,否则反馈错误assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)# isdir():判断某一路径是否为目录# listdir():返回指定的文件夹包含的文件或文件夹的名字的列表flower_class = [cla for cla in os.listdir(origin_flower_path)if os.path.isdir(os.path.join(origin_flower_path, cla))]# 创建训练集train文件夹,并由类名在其目录下创建子目录train_root = os.path.join(data_root, "train")mk_file(train_root)for cla in flower_class:# 建立每个类别对应的文件夹mk_file(os.path.join(train_root, cla))# 创建验证集val文件夹,并由类名在其目录下创建子目录val_root = os.path.join(data_root, "val")mk_file(val_root)for cla in flower_class:# 建立每个类别对应的文件夹mk_file(os.path.join(val_root, cla))# 遍历所有类别的图像并按比例分成训练集和验证集for cla in flower_class:cla_path = os.path.join(origin_flower_path, cla)# iamges列表存储了该目录下所有图像的名称images = os.listdir(cla_path)num = len(images)# 随机采样验证集的索引# 从images列表中随机抽取k个图像名称# random.sample:用于截取列表的指定长度的随机数,返回列表# eval_index保存验证集val的图像名称eval_index = random.sample(images, k=int(num*split_rate))for index, image in enumerate(images):if image in eval_index:# 将分配至验证集中的文件复制到相应目录image_path = os.path.join(cla_path, image)new_path = os.path.join(val_root, cla)copy(image_path, new_path)else:# 将分配至训练集中的文件复制到相应目录image_path = os.path.join(cla_path, image)new_path = os.path.join(train_root, cla)copy(image_path, new_path)# '\r'回车,回到当前行的行首,而不会换到下一行,如果接着输出,本行以前的内容会被逐一覆盖# end="":将print自带的换行用end中指定的str代替print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")print()print("processing done!")if __name__ == '__main__':main()