前面的文章已经介绍,将短线个股挖掘问题转化为深度学习处理的分类问题,并且已经完成训练,将训练得到的模型保存到本地。本文将记录如何使用Keras加载模型并进行预测的过程。
首先,找到训练模型保存的目录,加载模型:
# 加载模型 loaded_model = keras.models.load_model('./model/{}'.format(stk_code))然后,读入数据,将数据转化为字典类型作为预测所使用的输入字典,键为特征的索引,值为tensor。我们使用了220个特征,索引值依次为0至219。
# 读入数据 data_file = './baostock/prediction_data_pre/{}.csv'.format(stk_code) in_df = pd.read_csv(data_file) # 预测用的输入字典 temp_dict = {} # 将数据导入输入字典 for i in range(in_df.shape[1]): temp_dict[i] = in_df['{}'.format(i)].tolist() input_dict = {name: tf.convert_to_tensor(value) for name, value in temp_dict.items()}接着,调用模型的predict方法进行预测,将预测结果保存到列表results中。
# 进行预测 predictions = loaded_model.predict(input_dict) results = [] for i in range(in_df.shape[0]): results.append(predictions[i][0])然后,我们在未来用于回测的数据后添加一列predict_result,并保存到本地。这样backtrader就可以通过加载本地文件,完成基于深度学习的回测。
# 输出到文件 data_file = './baostock/data_ext/{}.csv'.format(stk_code) out_df = pd.read_csv(data_file) out_df = out_df[(out_df['date'] > '2017-12-31') & (out_df['date'] <= '2020-06-30')] out_df['predict_result'] = results out_df.to_csv('./baostock/predict_results/{}res.csv'.format(stk_code), index = False)最后,还是记得在每只股票完成预测后,清理内存,以防内存被刷爆。
# 清理内存 backend.clear_session()以上就完成了加载本地模型进行预测的过程,完整代码如下。下一篇文章将记录如果使用预测结果,进行多股回测。
import tensorflow as tf import numpy as np import pandas as pd import os from tensorflow import keras from tensorflow.keras import layers from tensorflow.keras import backend stk_code_file = './stk_data/dp_stock_list.csv' stk_list = pd.read_csv(stk_code_file)['code'].tolist() for stk_code in stk_list: print('processing {} ...'.format(stk_code)) # 加载模型 loaded_model = keras.models.load_model('./model/{}'.format(stk_code)) # 读入数据 data_file = './baostock/prediction_data_pre/{}.csv'.format(stk_code) in_df = pd.read_csv(data_file) # 预测用的输入字典 temp_dict = {} # 将数据导入输入字典 for i in range(in_df.shape[1]): temp_dict[i] = in_df['{}'.format(i)].tolist() input_dict = {name: tf.convert_to_tensor(value) for name, value in temp_dict.items()} # 进行预测 predictions = loaded_model.predict(input_dict) results = [] for i in range(in_df.shape[0]): results.append(predictions[i][0]) # 输出到文件 data_file = './baostock/data_ext/{}.csv'.format(stk_code) out_df = pd.read_csv(data_file) out_df = out_df[(out_df['date'] > '2017-12-31') & (out_df['date'] <= '2020-06-30')] out_df['predict_result'] = results out_df.to_csv('./baostock/predict_results/{}res.csv'.format(stk_code), index = False) # 清理内存 backend.clear_session()欢迎大家关注、点赞、转发、留言,感谢支持! 为了便于相互交流学习,已建微信群,感兴趣的读者请加微信。