【人工智能笔记】第五节:基于TensorFlow 2( 四 )


6.训练及预测代码:
def main():# 输入数据维度,(年 月 日+(开盘 最高 最低 收盘)*3)input_num = 15# 预测数据维度output_num = 12batch_size = 20# 历史数据长度history_size = 30# 预测数据长度target_size = 5# 创建模型print('创建模型')gupiao_model = GuPiaoModel(output_num)# 加载数据print('加载数据')gupiao_loader = GuPiaoLoader()df_sh = gupiao_loader.load_one('./data/gupiao_data/999999.SH.csv')df_sz = gupiao_loader.load_one('./data/gupiao_data/399001.SZ.csv')df_target = gupiao_loader.load_one('./data/gupiao_data/XXXXXX.SH.csv')print('训练前预测')x, time_step = gupiao_loader.get_data_to_predict(df_sh, df_sz, df_target, history_size, target_size)print('x', x.shape, 'time_step', time_step.shape)y = gupiao_model.predict_jit(x, time_step, target_size)print('y', y.shape)# 显示预测值gupiao_loader.show_image(x[0,:,:], y[0,:,:])# 开始训练print('开始训练')gupiao_model.fit_generator(gupiao_loader.data_generator(df_sh, df_sz, df_target, batch_size, history_size, target_size),steps_per_epoch=int(len(df_target)/2),epochs=20, auto_save=True)# 预测print('预测')y = gupiao_model.predict_jit(x, time_step, target_size)# 显示预测值gupiao_loader.show_image(x[0,:,:], y[0,:,:])if __name__ == '__main__':main()