 
 import numpy as np
 import pandas as pd
+import qlib
 import torch
 import yaml
 from qlib.backtest import Order
 from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
 from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
 from qlib.rl.reward import Reward
-from qlib.rl.trainer import Checkpoint, train
+from qlib.rl.trainer import Checkpoint, backtest, train
+from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
+from qlib.rl.utils.log import CsvWriter
 from qlib.utils import init_instance_by_config
 from tianshou.policy import BasePolicy
 from torch import nn
@@ -98,40 +101,54 @@ def train_and_test(
     action_interpreter: ActionInterpreter,
     policy: BasePolicy,
     reward: Reward,
+    run_backtest: bool,
 ) -> None:
+    qlib.init()
+
     order_root_path = Path(data_config["source"]["order_dir"])
 
+    data_granularity = simulator_config.get("data_granularity", 1)
+
     def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
         return SingleAssetOrderExecutionSimple(
             order=order,
             data_dir=Path(data_config["source"]["data_dir"]),
             ticks_per_step=simulator_config["time_per_step"],
+            data_granularity=data_granularity,
             deal_price_type=data_config["source"].get("deal_price_column", "close"),
             vol_threshold=simulator_config["vol_limit"],
         )
 
-    train_dataset = LazyLoadDataset(
-        order_file_path=order_root_path / "train",
-        data_dir=Path(data_config["source"]["data_dir"]),
-        default_start_time_index=data_config["source"]["default_start_time"],
-        default_end_time_index=data_config["source"]["default_end_time"],
-    )
-    valid_dataset = LazyLoadDataset(
-        order_file_path=order_root_path / "valid",
-        data_dir=Path(data_config["source"]["data_dir"]),
-        default_start_time_index=data_config["source"]["default_start_time"],
-        default_end_time_index=data_config["source"]["default_end_time"],
-    )
+    assert data_config["source"]["default_start_time_index"] % data_granularity == 0
+    assert data_config["source"]["default_end_time_index"] % data_granularity == 0
+
+    train_dataset, valid_dataset, test_dataset = [
+        LazyLoadDataset(
+            order_file_path=order_root_path / tag,
+            data_dir=Path(data_config["source"]["data_dir"]),
+            default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
+            default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
+        )
+        for tag in ("train", "valid", "test")
+    ]
 
-    callbacks = []
     if "checkpoint_path" in trainer_config:
+        callbacks: List[Callback] = []
+        callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
         callbacks.append(
             Checkpoint(
-                dirpath=Path(trainer_config["checkpoint_path"]),
-                every_n_iters=trainer_config["checkpoint_every_n_iters"],
+                dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
+                every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
                 save_latest="copy",
             ),
         )
+    if "earlystop_patience" in trainer_config:
+        callbacks.append(
+            EarlyStopping(
+                patience=trainer_config["earlystop_patience"],
+                monitor="val/pa",
+            )
+        )
 
     trainer_kwargs = {
         "max_iters": trainer_config["max_epoch"],
@@ -160,8 +177,21 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
         vessel_kwargs=vessel_kwargs,
     )
 
+    if run_backtest:
+        backtest(
+            simulator_fn=_simulator_factory_simple,
+            state_interpreter=state_interpreter,
+            action_interpreter=action_interpreter,
+            initial_states=test_dataset,
+            policy=policy,
+            logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
+            reward=reward,
+            finite_env_type=trainer_kwargs["finite_env_type"],
+            concurrency=trainer_kwargs["concurrency"],
+        )
+
 
-def main(config: dict) -> None:
+def main(config: dict, run_backtest: bool) -> None:
     if "seed" in config["runtime"]:
         seed_everything(config["runtime"]["seed"])
 
@@ -200,6 +230,7 @@ def main(config: dict) -> None:
         state_interpreter=state_interpreter,
         policy=policy,
         reward=reward,
+        run_backtest=run_backtest,
     )
 
 
@@ -211,9 +242,10 @@ def main(config: dict) -> None:
 
     parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
+    parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished")
     args = parser.parse_args()
 
     with open(args.config_path, "r") as input_stream:
         config = yaml.safe_load(input_stream)
 
-    main(config)
+    main(config, run_backtest=args.run_backtest)
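For context, the configuration keys that this change reads are sketched below. This is an illustrative reconstruction from the keys dereferenced in the diff, written as the Python dicts that yaml.safe_load would produce; every value and path is an assumption, not taken from the actual example configs.

# Illustrative only: shapes of the config sections used above.
# Key names come from the diff; all values and paths are hypothetical.
data_config = {
    "source": {
        "order_dir": "./orders",            # hypothetical path to order files (train/valid/test subdirs)
        "data_dir": "./backtest_data",      # hypothetical path to per-stock intraday data
        "default_start_time_index": 0,      # must be divisible by data_granularity
        "default_end_time_index": 240,      # must be divisible by data_granularity
        "deal_price_column": "close",       # optional; defaults to "close"
    },
}
simulator_config = {
    "time_per_step": 30,
    "vol_limit": None,
    "data_granularity": 5,                  # optional; defaults to 1 when omitted
}
trainer_config = {
    "max_epoch": 100,
    "checkpoint_path": "./outputs",         # presence enables the MetricsWriter/Checkpoint callbacks
    "checkpoint_every_n_iters": 1,          # optional; defaults to 1
    "earlystop_patience": 5,                # optional; enables EarlyStopping on the "val/pa" metric
}

With a config along these lines, the new backtest pass is opted into from the command line via the added flag, e.g. `python <script>.py --config_path <config>.yml --run_backtest`, where the script and config names are placeholders.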