Skip to content

Commit

Permalink
Resolve RL FIXMES (#1503)
Browse files (browse the repository at this point in the history)
* Solve several small FIXMEs left in RL

* Add TODO in example

* Minor bugfix

* black
(Loading branch information)
lihuoran authored May 17, 2023
1 parent 7234308 commit 8d60a6a
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 4 deletions.
1 change: 1 addition & 0 deletions examples/rl_order_execution/exp_configs/backtest_opds.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
data_granularity: "5min"
qlib:
provider_uri_5min: ./data/bin/
exchange:
Expand Down
1 change: 1 addition & 0 deletions examples/rl_order_execution/exp_configs/backtest_ppo.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
data_granularity: "5min"
qlib:
provider_uri_5min: ./data/bin/
exchange:
Expand Down
1 change: 1 addition & 0 deletions examples/rl_order_execution/exp_configs/backtest_twap.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
order_file: ./data/orders/test_orders.pkl
start_time: "9:30"
end_time: "14:54"
data_granularity: "5min"
qlib:
provider_uri_5min: ./data/bin/
exchange:
Expand Down
11 changes: 7 additions & 4 deletions qlib/rl/contrib/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def _get_multi_level_executor_config(
strategy_config: dict,
cash_limit: float | None = None,
generate_report: bool = False,
+ data_granularity: str = "1min",
) -> dict:
executor_config = {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
- "time_per_step": "5min", # FIXME: move this into config
+ "time_per_step": data_granularity,
"verbose": False,
"trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL,
"generate_report": generate_report,
Expand Down Expand Up @@ -176,13 +177,14 @@ def single_with_simulator(
strategy_config=backtest_config["strategies"],
cash_limit=cash_limit,
generate_report=generate_report,
+ data_granularity=backtest_config["data_granularity"],
)

exchange_config = copy.deepcopy(backtest_config["exchange"])
exchange_config.update(
{
"codes": stocks,
- "freq": "5min", # FIXME: move this into config
+ "freq": backtest_config["data_granularity"],
}
)

Expand All @@ -197,7 +199,7 @@ def single_with_simulator(
reports.append(simulator.report_dict)
decisions += simulator.decisions

- indicator_1day_objs = [report["indicator"]["1day"][1] for report in reports]
+ indicator_1day_objs = [report["indicator_dict"]["1day"][1] for report in reports]
indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()}
records = _convert_indicator_to_dataframe(indicator_info)
assert records is None or not np.isnan(records["ffr"]).any()
Expand Down Expand Up @@ -270,13 +272,14 @@ def single_with_collect_data_loop(
strategy_config=backtest_config["strategies"],
cash_limit=cash_limit,
generate_report=generate_report,
+ data_granularity=backtest_config["data_granularity"],
)

exchange_config = copy.deepcopy(backtest_config["exchange"])
exchange_config.update(
{
"codes": stocks,
- "freq": "5min", # FIXME: move this into config
+ "freq": backtest_config["data_granularity"],
}
)

Expand Down
1 change: 1 addition & 0 deletions qlib/rl/contrib/naive_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def get_backtest_config_fromfile(path: str) -> dict:
"multiplier": 1.0,
"output_dir": "outputs_backtest/",
"generate_report": False,
"data_granularity": "1min",
}
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)

Expand Down
4 changes: 4 additions & 0 deletions qlib/rl/contrib/train_onpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import argparse
import os
import random
import sys
import warnings
from pathlib import Path
from typing import cast, List, Optional
Expand Down Expand Up @@ -208,6 +209,9 @@ def main(config: dict, run_training: bool, run_backtest: bool) -> None:
if "seed" in config["runtime"]:
seed_everything(config["runtime"]["seed"])

for extra_module_path in config["env"].get("extra_module_paths", []):
sys.path.append(extra_module_path)

state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"])
action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"])
reward: Reward = init_instance_by_config(config["reward"])
Expand Down

0 comments on commit 8d60a6a

Please sign in to comment.