diff --git a/examples/rl_order_execution/exp_configs/backtest_opds.yml b/examples/rl_order_execution/exp_configs/backtest_opds.yml index 1cd767f2ba..068a39a5e0 100755 --- a/examples/rl_order_execution/exp_configs/backtest_opds.yml +++ b/examples/rl_order_execution/exp_configs/backtest_opds.yml @@ -1,6 +1,7 @@ order_file: ./data/orders/test_orders.pkl start_time: "9:30" end_time: "14:54" +data_granularity: "5min" qlib: provider_uri_5min: ./data/bin/ exchange: diff --git a/examples/rl_order_execution/exp_configs/backtest_ppo.yml b/examples/rl_order_execution/exp_configs/backtest_ppo.yml index 7932b91497..e3e3b907e7 100755 --- a/examples/rl_order_execution/exp_configs/backtest_ppo.yml +++ b/examples/rl_order_execution/exp_configs/backtest_ppo.yml @@ -1,6 +1,7 @@ order_file: ./data/orders/test_orders.pkl start_time: "9:30" end_time: "14:54" +data_granularity: "5min" qlib: provider_uri_5min: ./data/bin/ exchange: diff --git a/examples/rl_order_execution/exp_configs/backtest_twap.yml b/examples/rl_order_execution/exp_configs/backtest_twap.yml index 99efd30c24..e6230c2e6c 100755 --- a/examples/rl_order_execution/exp_configs/backtest_twap.yml +++ b/examples/rl_order_execution/exp_configs/backtest_twap.yml @@ -1,6 +1,7 @@ order_file: ./data/orders/test_orders.pkl start_time: "9:30" end_time: "14:54" +data_granularity: "5min" qlib: provider_uri_5min: ./data/bin/ exchange: diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index ee0942877f..60602c10d3 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -30,12 +30,13 @@ def _get_multi_level_executor_config( strategy_config: dict, cash_limit: float | None = None, generate_report: bool = False, + data_granularity: str = "1min", ) -> dict: executor_config = { "class": "SimulatorExecutor", "module_path": "qlib.backtest.executor", "kwargs": { - "time_per_step": "5min", # FIXME: move this into config + "time_per_step": data_granularity, "verbose": False, "trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL, "generate_report": generate_report, @@ -176,13 +177,14 @@ def single_with_simulator( strategy_config=backtest_config["strategies"], cash_limit=cash_limit, generate_report=generate_report, + data_granularity=backtest_config["data_granularity"], ) exchange_config = copy.deepcopy(backtest_config["exchange"]) exchange_config.update( { "codes": stocks, - "freq": "5min", # FIXME: move this into config + "freq": backtest_config["data_granularity"], } ) @@ -197,7 +199,7 @@ def single_with_simulator( reports.append(simulator.report_dict) decisions += simulator.decisions - indicator_1day_objs = [report["indicator"]["1day"][1] for report in reports] + indicator_1day_objs = [report["indicator_dict"]["1day"][1] for report in reports] indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()} records = _convert_indicator_to_dataframe(indicator_info) assert records is None or not np.isnan(records["ffr"]).any() @@ -270,13 +272,14 @@ def single_with_collect_data_loop( strategy_config=backtest_config["strategies"], cash_limit=cash_limit, generate_report=generate_report, + data_granularity=backtest_config["data_granularity"], ) exchange_config = copy.deepcopy(backtest_config["exchange"]) exchange_config.update( { "codes": stocks, - "freq": "5min", # FIXME: move this into config + "freq": backtest_config["data_granularity"], } ) diff --git a/qlib/rl/contrib/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py index a6409f828d..2255c7414a 100644 --- a/qlib/rl/contrib/naive_config_parser.py +++ b/qlib/rl/contrib/naive_config_parser.py @@ -100,6 +100,7 @@ def get_backtest_config_fromfile(path: str) -> dict: "multiplier": 1.0, "output_dir": "outputs_backtest/", "generate_report": False, + "data_granularity": "1min", } backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default) diff --git a/qlib/rl/contrib/train_onpolicy.py b/qlib/rl/contrib/train_onpolicy.py index 204c933eff..cd5d0e55ef 100644 --- a/qlib/rl/contrib/train_onpolicy.py +++ b/qlib/rl/contrib/train_onpolicy.py @@ -5,6 +5,7 @@ import argparse import os import random +import sys import warnings from pathlib import Path from typing import cast, List, Optional @@ -208,6 +209,9 @@ def main(config: dict, run_training: bool, run_backtest: bool) -> None: if "seed" in config["runtime"]: seed_everything(config["runtime"]["seed"]) + for extra_module_path in config["env"].get("extra_module_paths", []): + sys.path.append(extra_module_path) + state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"]) action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"]) reward: Reward = init_instance_by_config(config["reward"])