Commit d8fc9ae

RL Training pipeline on 5-min data (#1415)

* Workflow runnable
* CI
* Slight changes to make the workflow runnable. The changes of handler/provider should be reverted before merging.
* Train experiment successful
* Refine handler & provider
* CI issues
* Resolve PR comments
* Resolve PR comments
* CI issues
* Fix test issue
* Black

1 parent d876466 · commit d8fc9ae

File tree

9 files changed · +155 −59 lines


qlib/contrib/data/highfreq_handler.py

Lines changed: 32 additions & 29 deletions
@@ -113,8 +113,11 @@ def __init__(
         fit_end_time=None,
         drop_raw=True,
         day_length=240,
+        freq="1min",
+        columns=["$open", "$high", "$low", "$close", "$vwap"],
     ):
         self.day_length = day_length
+        self.columns = columns

         infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
         learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
@@ -124,7 +127,7 @@ def __init__(
             "kwargs": {
                 "config": self.get_feature_config(),
                 "swap_level": False,
-                "freq": "1min",
+                "freq": freq,
             },
         }
         super().__init__(
@@ -160,19 +163,13 @@ def get_normalized_price_feature(price_field, shift=0):
             )
             return feature_ops

-        fields += [get_normalized_price_feature("$open", 0)]
-        fields += [get_normalized_price_feature("$high", 0)]
-        fields += [get_normalized_price_feature("$low", 0)]
-        fields += [get_normalized_price_feature("$close", 0)]
-        fields += [get_normalized_price_feature("$vwap", 0)]
-        names += ["$open", "$high", "$low", "$close", "$vwap"]
+        for column_name in self.columns:
+            fields.append(get_normalized_price_feature(column_name, 0))
+            names.append(column_name)

-        fields += [get_normalized_price_feature("$open", self.day_length)]
-        fields += [get_normalized_price_feature("$high", self.day_length)]
-        fields += [get_normalized_price_feature("$low", self.day_length)]
-        fields += [get_normalized_price_feature("$close", self.day_length)]
-        fields += [get_normalized_price_feature("$vwap", self.day_length)]
-        names += ["$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1"]
+        for column_name in self.columns:
+            fields.append(get_normalized_price_feature(column_name, self.day_length))
+            names.append(column_name + "_1")

         # calculate and fill nan with 0
         fields += [
@@ -258,14 +255,17 @@ def __init__(
         start_time=None,
         end_time=None,
         day_length=240,
+        freq="1min",
+        columns=["$close", "$vwap", "$volume"],
     ):
         self.day_length = day_length
+        self.columns = set(columns)
         data_loader = {
             "class": "QlibDataLoader",
             "kwargs": {
                 "config": self.get_feature_config(),
                 "swap_level": False,
-                "freq": "1min",
+                "freq": freq,
             },
         }
         super().__init__(
@@ -279,21 +279,24 @@ def get_feature_config(self):
         fields = []
         names = []

-        template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
-        template_fillnan = "FFillNan({0})"
-        template_if = "If(IsNull({1}), {0}, {1})"
-        fields += [
-            template_paused.format(template_fillnan.format("$close")),
-        ]
-        names += ["$close0"]
-
-        fields += [
-            template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")),
-        ]
-        names += ["$vwap0"]
-
-        fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
-        names += ["$volume0"]
+        if "$close" in self.columns:
+            template_paused = f"Cut({{0}}, {self.day_length * 2}, None)"
+            template_fillnan = "FFillNan({0})"
+            template_if = "If(IsNull({1}), {0}, {1})"
+            fields += [
+                template_paused.format(template_fillnan.format("$close")),
+            ]
+            names += ["$close0"]
+
+        if "$vwap" in self.columns:
+            fields += [
+                template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")),
+            ]
+            names += ["$vwap0"]
+
+        if "$volume" in self.columns:
+            fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))]
+            names += ["$volume0"]

         return fields, names
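
Taken together, the handler constructors are now frequency- and column-agnostic: freq is forwarded to the QlibDataLoader, and columns drives both the normalized features and their day-shifted "_1" twins. One caveat visible above: the backtest handler defines its Cut/FFillNan/If templates inside the "$close" branch, so "$close" must be included for the "$vwap" branch to work. A minimal usage sketch, not part of the commit; the class names and the 48-bars-per-day figure are assumptions for a 5-min CN data setup:

    # Hypothetical usage; assumes the two __init__ methods above belong to the
    # handler classes in qlib.contrib.data.highfreq_handler and that a 5-min
    # qlib data source is initialized elsewhere.
    from qlib.contrib.data.highfreq_handler import HighFreqHandler, HighFreqBacktestHandler

    handler = HighFreqHandler(
        fit_start_time="2020-01-01",
        fit_end_time="2020-06-30",
        day_length=48,                # 240 one-minute bars -> 48 five-minute bars
        freq="5min",                  # forwarded to QlibDataLoader
        columns=["$open", "$close"],  # yields "$open"/"$close" plus "$open_1"/"$close_1"
    )

    backtest_handler = HighFreqBacktestHandler(
        day_length=48,
        freq="5min",
        columns=["$close", "$volume"],  # "$vwap" omitted -> no "$vwap0" field
    )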

qlib/contrib/data/highfreq_provider.py

Lines changed: 6 additions & 4 deletions
@@ -28,6 +28,7 @@ def __init__(
         feature_conf: dict,
         label_conf: Optional[dict] = None,
         backtest_conf: dict = None,
+        freq: str = "1min",
         **kwargs,
     ) -> None:
         self.start_time = start_time
@@ -42,6 +43,7 @@ def __init__(
         self.backtest_conf = backtest_conf
         self.qlib_conf = qlib_conf
         self.logger = get_module_logger("HighFreqProvider")
+        self.freq = freq

     def get_pre_datasets(self):
         """Generate the training, validation and test datasets for prediction
@@ -116,8 +118,8 @@ def _prepare_calender_cache(self):
         # This code used the copy-on-write feature of Linux
         # to avoid calculating the calendar multiple times in the subprocess.
         # This code may accelerate, but may be not useful on Windows and Mac Os
-        Cal.calendar(freq="1min")
-        get_calendar_day(freq="1min")
+        Cal.calendar(freq=self.freq)
+        get_calendar_day(freq=self.freq)

     def _gen_dataframe(self, config, datasets=["train", "valid", "test"]):
         try:
@@ -240,7 +242,7 @@ def _gen_day_dataset(self, config, conf_type):
         with open(path + "tmp_dataset.pkl", "rb") as f:
             new_dataset = pkl.load(f)

-        time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq="1min")[::240]
+        time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq=self.freq)[::240]

         def generate_dataset(times):
             if os.path.isfile(path + times.strftime("%Y-%m-%d") + ".pkl"):
@@ -283,7 +285,7 @@ def _gen_stock_dataset(self, config, conf_type):

         instruments = D.instruments(market="all")
         stock_list = D.list_instruments(
-            instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq="1min", as_list=True
+            instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq=self.freq, as_list=True
         )

         def generate_dataset(stock):
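
The provider now threads a single freq through calendar caching, the day-dataset calendar, and instrument listing. One thing the diff leaves untouched: the [::240] stride in _gen_day_dataset still assumes 240 bars per trading day, which only holds for 1-min CN data. A sketch of the per-day stride arithmetic; it assumes qlib is initialized with a 5-min provider, and the 48-bar figure is an assumption for CN market hours:

    from qlib.data import D

    # On 5-min data a CN trading day has 48 bars, so a day-level stride over
    # the calendar would be [::48]; [::240] would step five days at a time.
    bars_per_day = 48
    day_starts = D.calendar(start_time="2020-01-01", end_time="2020-12-31", freq="5min")[::bars_per_day]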

qlib/rl/contrib/train_onpolicy.py

Lines changed: 50 additions & 18 deletions
@@ -8,6 +8,7 @@

 import numpy as np
 import pandas as pd
+import qlib
 import torch
 import yaml
 from qlib.backtest import Order
@@ -17,7 +18,9 @@
 from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
 from qlib.rl.order_execution import SingleAssetOrderExecutionSimple
 from qlib.rl.reward import Reward
-from qlib.rl.trainer import Checkpoint, train
+from qlib.rl.trainer import Checkpoint, backtest, train
+from qlib.rl.trainer.callbacks import Callback, EarlyStopping, MetricsWriter
+from qlib.rl.utils.log import CsvWriter
 from qlib.utils import init_instance_by_config
 from tianshou.policy import BasePolicy
 from torch import nn
@@ -98,40 +101,54 @@ def train_and_test(
     action_interpreter: ActionInterpreter,
     policy: BasePolicy,
     reward: Reward,
+    run_backtest: bool,
 ) -> None:
+    qlib.init()
+
     order_root_path = Path(data_config["source"]["order_dir"])

+    data_granularity = simulator_config.get("data_granularity", 1)
+
     def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
         return SingleAssetOrderExecutionSimple(
             order=order,
             data_dir=Path(data_config["source"]["data_dir"]),
             ticks_per_step=simulator_config["time_per_step"],
+            data_granularity=data_granularity,
             deal_price_type=data_config["source"].get("deal_price_column", "close"),
             vol_threshold=simulator_config["vol_limit"],
         )

-    train_dataset = LazyLoadDataset(
-        order_file_path=order_root_path / "train",
-        data_dir=Path(data_config["source"]["data_dir"]),
-        default_start_time_index=data_config["source"]["default_start_time"],
-        default_end_time_index=data_config["source"]["default_end_time"],
-    )
-    valid_dataset = LazyLoadDataset(
-        order_file_path=order_root_path / "valid",
-        data_dir=Path(data_config["source"]["data_dir"]),
-        default_start_time_index=data_config["source"]["default_start_time"],
-        default_end_time_index=data_config["source"]["default_end_time"],
-    )
+    assert data_config["source"]["default_start_time_index"] % data_granularity == 0
+    assert data_config["source"]["default_end_time_index"] % data_granularity == 0
+
+    train_dataset, valid_dataset, test_dataset = [
+        LazyLoadDataset(
+            order_file_path=order_root_path / tag,
+            data_dir=Path(data_config["source"]["data_dir"]),
+            default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity,
+            default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity,
+        )
+        for tag in ("train", "valid", "test")
+    ]

-    callbacks = []
     if "checkpoint_path" in trainer_config:
+        callbacks: List[Callback] = []
+        callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])))
         callbacks.append(
             Checkpoint(
-                dirpath=Path(trainer_config["checkpoint_path"]),
-                every_n_iters=trainer_config["checkpoint_every_n_iters"],
+                dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints",
+                every_n_iters=trainer_config.get("checkpoint_every_n_iters", 1),
                 save_latest="copy",
             ),
         )
+    if "earlystop_patience" in trainer_config:
+        callbacks.append(
+            EarlyStopping(
+                patience=trainer_config["earlystop_patience"],
+                monitor="val/pa",
+            )
+        )

     trainer_kwargs = {
         "max_iters": trainer_config["max_epoch"],
@@ -160,8 +177,21 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple:
         vessel_kwargs=vessel_kwargs,
     )

+    if run_backtest:
+        backtest(
+            simulator_fn=_simulator_factory_simple,
+            state_interpreter=state_interpreter,
+            action_interpreter=action_interpreter,
+            initial_states=test_dataset,
+            policy=policy,
+            logger=CsvWriter(Path(trainer_config["checkpoint_path"])),
+            reward=reward,
+            finite_env_type=trainer_kwargs["finite_env_type"],
+            concurrency=trainer_kwargs["concurrency"],
+        )
+

-def main(config: dict) -> None:
+def main(config: dict, run_backtest: bool) -> None:
     if "seed" in config["runtime"]:
         seed_everything(config["runtime"]["seed"])

@@ -200,6 +230,7 @@ def main(config: dict) -> None:
         state_interpreter=state_interpreter,
         policy=policy,
         reward=reward,
+        run_backtest=run_backtest,
     )


@@ -211,9 +242,10 @@ def main(config: dict) -> None:

     parser = argparse.ArgumentParser()
     parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
+    parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow after training is finished")
     args = parser.parse_args()

     with open(args.config_path, "r") as input_stream:
         config = yaml.safe_load(input_stream)

-    main(config)
+    main(config, run_backtest=args.run_backtest)
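
All of the new behavior is config-driven. A hedged sketch of the keys this file reads, written as the Python dict that yaml.safe_load would return; the key names are taken from the code above, while the values and the top-level grouping are illustrative assumptions for a 5-min setup:

    config = {
        "simulator": {
            "data_granularity": 5,  # ticks between consecutive data rows (default 1)
            "time_per_step": 30,    # must be a multiple of data_granularity
            "vol_limit": None,
        },
        "data": {
            "source": {
                "order_dir": "./orders",
                "data_dir": "./data",
                "default_start_time_index": 0,   # both indices must be divisible
                "default_end_time_index": 240,   # by data_granularity (see asserts)
                "deal_price_column": "close",
            },
        },
        "trainer": {
            "max_epoch": 100,
            "checkpoint_path": "./outputs",  # enables MetricsWriter + Checkpoint
            "earlystop_patience": 10,        # enables EarlyStopping on "val/pa"
        },
        "runtime": {"seed": 42},
    }

With such a config, python train_onpolicy.py --config_path config.yml --run_backtest trains the policy and then backtests it on the test orders.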

qlib/rl/data/pickle_styled.py

Lines changed: 11 additions & 1 deletion
@@ -83,7 +83,16 @@ def _find_pickle(filename_without_suffix: Path) -> Path:

 @lru_cache(maxsize=10)  # 10 * 40M = 400MB
 def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
-    return pd.read_pickle(_find_pickle(filename_without_suffix))
+    df = pd.read_pickle(_find_pickle(filename_without_suffix))
+    index_cols = df.index.names
+
+    df = df.reset_index()
+    for date_col_name in ["date", "datetime"]:
+        if date_col_name in df:
+            df[date_col_name] = pd.to_datetime(df[date_col_name])
+    df = df.set_index(index_cols)
+
+    return df


 class SimpleIntradayBacktestData(BaseIntradayBacktestData):
@@ -161,6 +170,7 @@ def __init__(
         time_index: pd.Index,
     ) -> None:
         proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id)
+
         # We have to infer the names here because,
         # unfortunately they are not included in the original data.
         cnames = _infer_processed_data_column_names(feature_dim)
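
_read_pickle now re-parses any "date"/"datetime" index level through pd.to_datetime, which protects downstream time slicing when the pickled frame stored those levels as strings. A self-contained illustration of the round trip (not from the commit):

    import pandas as pd

    df = pd.DataFrame(
        {"close": [1.0, 2.0]},
        index=pd.MultiIndex.from_tuples(
            [("SH600000", "2020-01-02 09:30:00"), ("SH600000", "2020-01-02 09:35:00")],
            names=["instrument", "datetime"],
        ),
    )
    index_cols = df.index.names
    df = df.reset_index()
    for date_col_name in ["date", "datetime"]:
        if date_col_name in df:  # membership checks column names
            df[date_col_name] = pd.to_datetime(df[date_col_name])
    df = df.set_index(index_cols)
    assert df.index.get_level_values("datetime").dtype == "datetime64[ns]"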

qlib/rl/order_execution/reward.py

Lines changed: 5 additions & 2 deletions
@@ -21,10 +21,13 @@ class PAPenaltyReward(Reward[SAOEState]):
     ----------
     penalty
         The penalty for large volume in a short time.
+    scale
+        The weight used to scale up or down the reward.
     """

-    def __init__(self, penalty: float = 100.0):
+    def __init__(self, penalty: float = 100.0, scale: float = 1.0) -> None:
         self.penalty = penalty
+        self.scale = scale

     def reward(self, simulator_state: SAOEState) -> float:
         whole_order = simulator_state.order.amount
@@ -43,4 +46,4 @@ def reward(self, simulator_state: SAOEState) -> float:

         self.log("reward/pa", pa)
         self.log("reward/penalty", penalty)
-        return reward
+        return reward * self.scale
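
scale multiplies the reward only after both terms are combined, so it changes the magnitude of the learning signal without shifting the trade-off between price advantage and the volume penalty, which is useful when rewards on coarser data come out on a different scale. A usage sketch with illustrative numbers:

    from qlib.rl.order_execution.reward import PAPenaltyReward

    # If a step's base reward were 1.5, scale=0.1 turns it into 0.15 while the
    # relative weight of pa versus the volume penalty stays the same.
    reward_fn = PAPenaltyReward(penalty=100.0, scale=0.1)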

qlib/rl/order_execution/simulator_simple.py

Lines changed: 8 additions & 1 deletion
@@ -36,6 +36,8 @@ class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]):
     ----------
     order
         The seed to start an SAOE simulator is an order.
+    data_granularity
+        Number of ticks between consecutive data entries.
     ticks_per_step
         How many ticks per step.
     data_dir
@@ -71,14 +73,17 @@ def __init__(
         self,
         order: Order,
         data_dir: Path,
+        data_granularity: int = 1,
         ticks_per_step: int = 30,
         deal_price_type: DealPriceType = "close",
         vol_threshold: Optional[float] = None,
     ) -> None:
         super().__init__(initial=order)

+        assert ticks_per_step % data_granularity == 0
+
         self.order = order
-        self.ticks_per_step: int = ticks_per_step
+        self.ticks_per_step: int = ticks_per_step // data_granularity
         self.deal_price_type = deal_price_type
         self.vol_threshold = vol_threshold
         self.data_dir = data_dir
@@ -132,6 +137,8 @@ def step(self, amount: float) -> None:
         ticks_position = self.position - np.cumsum(exec_vol)

         self.position -= exec_vol.sum()
+        if abs(self.position) < 1e-6:
+            self.position = 0.0
         if self.position < -EPS or (exec_vol < -EPS).any():
             raise ValueError(f"Execution volume is invalid: {exec_vol} (position = {self.position})")
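
After dividing by the granularity, self.ticks_per_step counts data rows rather than minutes, and the assert rejects step sizes that do not align with the data. A sketch of the arithmetic; the order object and the on-disk data layout are assumed to exist:

    from pathlib import Path
    from qlib.rl.order_execution import SingleAssetOrderExecutionSimple

    # 30-minute decision step on 5-min bars: each step consumes 30 // 5 = 6 rows.
    # ticks_per_step=30 with data_granularity=7 would trip the assert above.
    sim = SingleAssetOrderExecutionSimple(
        order=order,              # an existing qlib.backtest.Order (assumed)
        data_dir=Path("./data"),  # pickle-styled intraday data (assumed)
        data_granularity=5,
        ticks_per_step=30,
    )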

qlib/rl/trainer/__init__.py

Lines changed: 11 additions & 2 deletions
@@ -4,8 +4,17 @@

 """Train, test, inference utilities."""

 from .api import backtest, train
-from .callbacks import Checkpoint, EarlyStopping
+from .callbacks import Checkpoint, EarlyStopping, MetricsWriter
 from .trainer import Trainer
 from .vessel import TrainingVessel, TrainingVesselBase

-__all__ = ["Trainer", "TrainingVessel", "TrainingVesselBase", "Checkpoint", "EarlyStopping", "train", "backtest"]
+__all__ = [
+    "Trainer",
+    "TrainingVessel",
+    "TrainingVesselBase",
+    "Checkpoint",
+    "EarlyStopping",
+    "MetricsWriter",
+    "train",
+    "backtest",
+]
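
With MetricsWriter re-exported, callers can import the new callback from the package root alongside Checkpoint and EarlyStopping:

    from qlib.rl.trainer import MetricsWriter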
