问题一: 机器学习的基本流程(20)

< best_loss:best_loss = torch.abs(loss).item()best_epoch = epochtorch.save(model.state_dict(), 'new_best_Deep_Ritz.mdl')# 连续绘图if epoch % 500 == 0:plt.ion()# 打开交互模式plt.close('all')model.load_state_dict(torch.load('new_best_Deep_Ritz.mdl'))with torch.no_grad():x1 = torch.linspace(-1, 1, 1001)x2 = torch.linspace(-1, 1, 1001)X, Y = torch.meshgrid(x1, x2)Z = torch.cat((Y.flatten()[:, None], Y.T.flatten()[:, None]), dim=1)Z = Z.to(device)pred = model(Z)plt.figure()pred = pred.cpu().numpy()pred = pred.reshape(1001, 1001)ax = plt.subplot(1, 1, 1)h = plt.imshow(pred, interpolation='nearest', cmap='rainbow',extent=[-1, 1, -1, 1],origin='lower', aspect='auto')plt.title("Training times:" + str(epoch))divider = make_axes_locatable(ax)cax = divider.append_axes("right", size="5%", pad=0.05)plt.colorbar(h, cax=cax)plt.savefig('./Training_process/Deep_Ritz_' + str(epoch) + '.png')if epoch == best_epoch:plt.savefig('Best_Deep_Ritz.png')# plt.show()# plt.pause(0.02)print('=' * 55)print('学习结束'.center(55))print('-' * 55)print('最优学习批次:', best_epoch, '最优误差:', best_loss)plt.close('all')plt.ioff()# 关闭交互模式plt.title('Error curve')plt.xlabel('loss vs. epoches')plt.ylabel('loss')plt.plot(range(0, epochs + 1), Loss_list, label='Loss')plt.savefig('Error_curve_Deep_Ritz.png')# plt.show()print('已生成"最优拟合结果图",请打开文件"Best_Deep_Ritz.png"查看')print('已生成"误差曲线图",请打开文件"Error_curve_Deep_Ritz.png"查看')print('-' * 55)print('准备绘制训练过程动态图')image2gif.image2gif('Deep_Ritz')print('=' * 55)
Open your !
参考书籍及文献李航.《统计学习方法》.清华大学出版社.周志华.《机器学习》.清华大学出版社.诸葛越.《百面机器学习算法工程师带你去面试》.人民邮电出版社(英) 塔里克 - 拉希德.《神经网络编程》.人民邮电出版社Deep(Part I)Deep(Part II)- A Deepfor文献解读-Deep (PINN)文献解读-物理信息深度学习(PINN)