CRNN、OpenCV、PyTorch 字符识别（二）


3.数据读取
读取训练数据的代码如下所示（保存为 mydataset.py，后面的训练脚本会从该文件导入 CRNNDataSet）：
import os

import cv2
import imgaug as ia
import imgaug.augmenters as iaa
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class CRNNDataSet(Dataset):
    """Grayscale text-line dataset for CRNN training with CTC loss.

    Each entry of ``lines`` is a whitespace-separated record
    ``<image_path> <label_0> <label_1> ...`` where every label token is an
    integer class index. Labels are right-padded with -1 to a fixed length
    ``T`` so samples can be stacked by the default DataLoader collate.

    Args:
        lines: Annotation lines, e.g. ``open('train.txt').readlines()``.
        train: If True, apply random augmentation (which ends with a resize);
            otherwise only resize.
        img_width: Target image width in pixels; the height is always 32.
    """

    def __init__(self, lines, train=True, img_width=100):
        super().__init__()
        self.lines = lines
        self.train = train
        self.img_width = img_width
        # Fixed padded label length. NOTE(review): presumably chosen to match
        # the CRNN output sequence length for this width — confirm in model.py.
        self.T = img_width // 4 + 1
        # Augmentation pipeline, built lazily on first use (so each DataLoader
        # worker builds its own) instead of once per sample.
        self._augmenter = None

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        ``image`` is a float32 tensor of shape (1, 32, img_width) scaled to
        [0, 1]; ``label`` is an int32 tensor of length ``T`` containing the
        class indices followed by -1 padding.

        Raises:
            FileNotFoundError: if the image file cannot be read.
            ValueError: if the label is longer than ``T``.
        """
        fields = self.lines[index].strip().split()
        image_path, label_tokens = fields[0], fields[1:]

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"could not read image: {image_path}")

        # Image preprocessing: augmentation (train) or plain resize (eval).
        if self.train:
            image = self.get_random_data(image)
        else:
            image = cv2.resize(image, (self.img_width, 32))

        # Convert the label to a fixed-length int32 vector padded with -1.
        label = np.array([int(tok) for tok in label_tokens], dtype=np.int32)
        if len(label) > self.T:
            raise ValueError(
                f"label of length {len(label)} exceeds max length {self.T}: {image_path}"
            )
        label_max = np.full(self.T, -1, dtype=np.int32)
        label_max[:len(label)] = label

        # Normalize to [0, 1] and add a leading channel axis -> (1, H, W).
        image = (image / 255.0).astype(np.float32)
        image = np.expand_dims(image, axis=0)
        return torch.from_numpy(image), torch.from_numpy(label_max)

    def __len__(self):
        return len(self.lines)

    def _build_augmenter(self):
        """Create the random augmentation pipeline (last step resizes to 32 x img_width)."""
        return iaa.Sequential([
            iaa.Multiply((0.8, 1.3)),  # random brightness scaling
            iaa.GaussianBlur(sigma=(0, 1.0)),  # blur, sigma sampled from [0, 1.0]
            iaa.Crop(percent=(0, 0.05)),  # random crop of 0-5%
            iaa.Affine(
                scale=(0.95, 1.05),  # slight scale jitter
                rotate=(-4, 4),  # small rotation
                cval=(100, 250),  # fill value for pixels exposed by the transform
                mode=ia.ALL,  # border mode sampled randomly
            ),
            iaa.Resize({"height": 32, "width": self.img_width}),
        ])

    def get_random_data(self, img):
        """Randomly augment ``img`` and resize it to height 32, width ``img_width``."""
        if self._augmenter is None:
            self._augmenter = self._build_augmenter()
        return self._augmenter.augment(image=img)


if __name__ == '__main__':
    batch_size = 16
    with open('train.txt', 'r') as f:
        lines = f.readlines()
    trainData = CRNNDataSet(lines=lines)
    trainLoader = DataLoader(dataset=trainData, batch_size=batch_size)
    for data, label in trainLoader:
        print(data.shape, label)
4.训练模型
训练代码如下所示（CRNN 网络结构从 model.py 导入，训练前需准备 train.txt 和 val.txt 两个标注文件）：
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import CRNN
from mydataset import CRNNDataSet

# Output alphabet: class index i of the network maps to CHAR_SET[i].
# The last entry (index n_class - 1) is the CTC blank.
CHAR_SET = list('abcdefghijklmnopqrstuvwxyz0123456789') + [' ']
n_class = len(CHAR_SET)  # 37 = 26 letters + 10 digits + blank
batch_size = 128


def decode(preds):
    """Greedy CTC decoding of one sequence of per-time-step class indices.

    Consecutive repeated indices are collapsed first, then blanks
    (index ``n_class - 1``) are removed, so "a, blank, a" decodes to "aa"
    while "a, a" decodes to "a".

    Args:
        preds: Iterable of integer class indices, one per time step.

    Returns:
        The decoded string.
    """
    blank = n_class - 1
    chars = []
    prev = None
    for idx in preds:
        if idx != blank and idx != prev:
            chars.append(CHAR_SET[idx])
        prev = idx
    return ''.join(chars)


def getAcc(preds, labs):
    """Return the fraction of a batch whose greedy decode exactly matches its label.

    Args:
        preds: Network output, time-major (T, batch, n_class) as fed to CTCLoss.
        labs: Integer tensor (batch, L) of class indices right-padded with -1.

    Returns:
        Sequence-level accuracy in [0, 1]; 0.0 for an empty batch.
    """
    labs = labs.cpu().detach().numpy()
    best_path = preds.detach().cpu().numpy().argmax(axis=-1).T  # (batch, T)
    decoded = [decode(seq) for seq in best_path]
    truths = [''.join(CHAR_SET[i] for i in lab[lab != -1]) for lab in labs]
    if not decoded:
        return 0.0
    hits = sum(pred == truth for pred, truth in zip(decoded, truths))
    # Divide by the actual batch size (the last batch may be smaller).
    return hits / len(decoded)


with open('train.txt', 'r') as f:
    train_lines = f.readlines()
with open('val.txt', 'r') as f:
    val_lines = f.readlines()
trainData = CRNNDataSet(lines=train_lines, train=True, img_width=200)
trainLoader = DataLoader(dataset=trainData, batch_size=batch_size, shuffle=True, num_workers=1)
valData = CRNNDataSet(lines=val_lines, train=False, img_width=200)
valLoader = DataLoader(dataset=valData, batch_size=batch_size, shuffle=False, num_workers=1)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = CRNN(imgHeight=32, nChannel=1, nClass=n_class, nHidden=256)
net = net.to(device)
# `blank` is the class index of the blank symbol: the last class, n_class - 1 (= 36).
loss_func = torch.nn.CTCLoss(blank=n_class - 1)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.5, 0.999))
# Learning-rate decay: multiply the LR by 0.99 after every epoch.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)

# Per-epoch histories used for the plots.
trainLoss = []
valLoss = []
trainAcc = []
valAcc = []


def _ctc_batch(features, label):
    """Run the network on one batch and return ``(out, ctc_loss)``."""
    features = features.to(device)
    # Flatten the padded (batch, T) label matrix into the concatenated 1-D
    # target sequence CTCLoss expects, dropping the -1 padding; boolean
    # indexing is row-major, so each sample's characters stay contiguous.
    targets = label[label != -1].to(device)
    out = net(features)  # time-major (T, batch, n_class), as CTCLoss requires
    log_probs = out.log_softmax(2)
    input_lengths = torch.IntTensor([out.size(0)] * out.size(1))
    target_lengths = (label != -1).sum(dim=-1)
    loss = loss_func(log_probs, targets, input_lengths, target_lengths)
    return out, loss


if __name__ == '__main__':
    Epoch = 100  # number of training epochs
    epoch_step = len(trainLoader)  # batches per epoch (last one may be partial)
    for epoch in range(1, Epoch + 1):
        # ---- training ----
        net.train()
        train_total_loss = 0.0
        train_total_acc = 0.0
        train_steps = 0
        with tqdm(total=epoch_step, desc=f'Epoch {epoch}/{Epoch}', mininterval=0.3) as pbar:
            for train_steps, (features, label) in enumerate(trainLoader, 1):
                out, loss = _ctc_batch(features, label)
                acc = getAcc(out, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Use .item(): accumulating the loss tensor itself would chain
                # every batch's autograd history for the whole epoch.
                train_total_loss += loss.item()
                train_total_acc += acc
                pbar.set_postfix(loss=train_total_loss / train_steps,
                                 acc=train_total_acc / train_steps)
                pbar.update(1)
        steps = max(train_steps, 1)
        trainLoss.append(train_total_loss / steps)
        trainAcc.append(train_total_acc / steps)
        # Save the model weights after every epoch.
        torch.save(net.state_dict(), 'model.pth')

        # ---- validation ----
        net.eval()
        val_total_loss = 0.0
        val_total_acc = 0.0
        val_steps = 0
        with torch.no_grad():
            for val_steps, (features, label) in enumerate(valLoader, 1):
                out, loss = _ctc_batch(features, label)
                val_total_loss += loss.item()
                val_total_acc += getAcc(out, label)
        steps = max(val_steps, 1)  # guard against an empty validation set
        valLoss.append(val_total_loss / steps)
        valAcc.append(val_total_acc / steps)
        lr_scheduler.step()

    # Plot the loss and accuracy curves.
    plt.figure()
    plt.plot(trainLoss, 'r')
    plt.plot(valLoss, 'b')
    plt.title('Training and validation loss')
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(["Loss", "Validation Loss"])
    plt.savefig('loss.png')

    plt.figure()
    plt.plot(trainAcc, 'r')
    plt.plot(valAcc, 'b')
    plt.title('Training and validation acc')
    plt.xlabel("Epochs")
    plt.ylabel("Acc")
    plt.legend(["Acc", "Validation Acc"])
    plt.savefig('acc.png')