本项目用于基于盘口快照数据训练 XGBoost 模型,预测股票价格方向。
标签含义:
label_5label_10label_20label_40label_60
模型输出三分类结果:
0:下跌1:基本不变2:上涨
在项目根目录执行:
pip install -r requirements.txt数据目录中应放置 CSV 文件,文件名格式类似:
snapshot_sym1_date60_am.csv
snapshot_sym1_date60_pm.csv
脚本会匹配:
snapshot_sym*_date*_*.csv
CSV 中需要包含训练所需字段,例如:
symdatetimen_closen_midpriceamount_deltan_bid1到n_bid5n_ask1到n_ask5n_bsize1到n_bsize5n_asize1到n_asize5label_5、label_10、label_20、label_40、label_60
训练单个标签:
python train.py --data_dir ./your_data --model_save_dir ./Experiments --cache_dir ./cache --label label_5训练全部标签:
python train.py --data_dir ./your_data --model_save_dir ./Experiments --cache_dir ./cache --label all常用参数:
--data_dir:CSV 数据目录--model_save_dir:模型保存目录,默认./Experiments--cache_dir:缓存目录--label:训练标签,可选label_5、label_10、label_20、label_40、label_60、all--window:历史窗口大小,默认100--test_size:验证集比例,默认0.2--batch_size:缓存分块大小,默认5000--use_cache:使用已有缓存,不重新处理原始 CSV--resume:从最近 checkpoint 继续训练
使用已有缓存训练:
python train.py --data_dir ./your_data --model_save_dir ./Experiments --cache_dir ./cache --label label_5 --use_cache断点续训:
python train.py --data_dir ./your_data --model_save_dir ./Experiments --cache_dir ./cache --label label_5 --resume训练完成后,模型会保存为:
./Experiments/model_label_5.pth
checkpoint 会保存在:
./Experiments/checkpoints/
python test.py --model_path ./Experiments/model_label_5.pth --data_dir ./your_data --cache_dir ./cache --label label_5 --threshold 0.88该脚本会读取验证集缓存,输出不同置信度阈值下的信号比例、准确率、召回率和 PnL 统计。
python analyze_features.py --model_path ./Experiments/model_label_5.pth --data_dir ./your_data输出模型 Top 20 特征重要性,并保存图片到:
./Experiments/feature_importance_analysis.png
python count_max_pnl.py --label label_5 --cache_dir ./cache --top_percent 5只分析训练集:
python count_max_pnl.py --label label_5 --cache_dir ./cache --train_only只分析验证集:
python count_max_pnl.py --label label_5 --cache_dir ./cache --val_only该脚本基于缓存中的真实标签和未来价格变化,统计收益分布及前若干比例交易的 PnL。
python draw_candle.py --data_dir ./your_data --model_path ./Experiments/model_label_5.pth --threshold 0.88 --output_dir ./candle_plots常用参数:
--threshold:信号置信度阈值,默认0.88--interval_seconds:K 线周期,默认30--output_dir:图片输出目录,默认./candle_plots--window:特征窗口大小,默认100
图片会按股票代码保存到:
./candle_plots/sym*/date*.png
- 准备符合命名规则的 CSV 数据。
- 执行
train.py生成缓存并训练模型。 - 执行
test.py查看模型效果和阈值表现。 - 按需要执行
analyze_features.py、count_max_pnl.py或draw_candle.py做进一步分析。