diff --git a/examples/rl_order_execution/README.md b/examples/rl_order_execution/README.md
index 197b1605f3..00ed0757a6 100644
--- a/examples/rl_order_execution/README.md
+++ b/examples/rl_order_execution/README.md
@@ -14,9 +14,10 @@ python -m qlib.run.get_data qlib_data qlib_data --target_dir ./data/bin --region
 
 To run codes in this example, we need data in pickle format. To achieve this, run following commands (might need a few minutes to finish):
 
+[//]: # (TODO: Instead of dumping dataframe with different format (like `_gen_dataset` and `_gen_day_dataset` in `qlib/contrib/data/highfreq_provider.py`), we encourage to implement different subclass of `Dataset` and `DataHandler`. This will keep the workflow cleaner and interfaces more consistent, and move all the complexity to the subclass.)
+
 ```
 python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
-python scripts/collect_pickle_dataframe.py
 python scripts/gen_training_orders.py
 python scripts/merge_orders.py
 ```
@@ -27,8 +28,7 @@ When finished, the structure under `data/` should be:
 data
 ├── bin
 ├── orders
-├── pickle
-└── pickle_dataframe
+└── pickle
 ```
 
 ## Training
diff --git a/examples/rl_order_execution/exp_configs/backtest_opds.yml b/examples/rl_order_execution/exp_configs/backtest_opds.yml
index c1c9b929ac..1cd767f2ba 100755
--- a/examples/rl_order_execution/exp_configs/backtest_opds.yml
+++ b/examples/rl_order_execution/exp_configs/backtest_opds.yml
@@ -3,15 +3,6 @@ start_time: "9:30"
 end_time: "14:54"
 qlib:
   provider_uri_5min: ./data/bin/
-  feature_root_dir: ./data/pickle/
-  feature_columns_today: [
-    "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
-    "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
-  ]
-  feature_columns_yesterday: [
-    "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
-    "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
-  ]
 exchange:
   limit_threshold: null
   deal_price: ["$close", "$close"]
@@ -45,10 +36,12 @@ strategies:
         data_ticks: 48
         max_step: 8
         processed_data_provider:
-          class: PickleProcessedDataProvider
+          class: HandlerProcessedDataProvider
           kwargs:
-            data_dir: ./data/pickle_dataframe/feature
-          module_path: qlib.rl.data.pickle_styled
+            data_dir: ./data/pickle/
+            feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
+            feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
+          module_path: qlib.rl.data.native
       module_path: qlib.rl.order_execution.interpreter
     module_path: qlib.rl.order_execution.strategy
   30min:
diff --git a/examples/rl_order_execution/exp_configs/backtest_ppo.yml b/examples/rl_order_execution/exp_configs/backtest_ppo.yml
index 1298626b5e..7932b91497 100755
--- a/examples/rl_order_execution/exp_configs/backtest_ppo.yml
+++ b/examples/rl_order_execution/exp_configs/backtest_ppo.yml
@@ -3,15 +3,6 @@ start_time: "9:30"
 end_time: "14:54"
 qlib:
   provider_uri_5min: ./data/bin/
-  feature_root_dir: ./data/pickle/
-  feature_columns_today: [
-    "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
-    "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
-  ]
-  feature_columns_yesterday: [
-    "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
-    "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
-  ]
 exchange:
   limit_threshold: null
   deal_price: ["$close", "$close"]
@@ -45,10 +36,12 @@ strategies:
         data_ticks: 48
         max_step: 8
         processed_data_provider:
-          class: PickleProcessedDataProvider
+          class: HandlerProcessedDataProvider
           kwargs:
-            data_dir: ./data/pickle_dataframe/feature
-          module_path: qlib.rl.data.pickle_styled
+            data_dir: ./data/pickle/
+            feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
+            feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
+          module_path: qlib.rl.data.native
       module_path: qlib.rl.order_execution.interpreter
     module_path: qlib.rl.order_execution.strategy
  30min:
diff --git a/examples/rl_order_execution/exp_configs/backtest_twap.yml b/examples/rl_order_execution/exp_configs/backtest_twap.yml
index a797e3fd84..99efd30c24 100755
--- a/examples/rl_order_execution/exp_configs/backtest_twap.yml
+++ b/examples/rl_order_execution/exp_configs/backtest_twap.yml
@@ -3,15 +3,6 @@ start_time: "9:30"
 end_time: "14:54"
 qlib:
   provider_uri_5min: ./data/bin/
-  feature_root_dir: ./data/pickle/
-  feature_columns_today: [
-    "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
-    "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
-  ]
-  feature_columns_yesterday: [
-    "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
-    "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
-  ]
 exchange:
   limit_threshold: null
   deal_price: ["$close", "$close"]
diff --git a/examples/rl_order_execution/exp_configs/train_opds.yml b/examples/rl_order_execution/exp_configs/train_opds.yml
index c69896474c..9be2618bee 100755
--- a/examples/rl_order_execution/exp_configs/train_opds.yml
+++ b/examples/rl_order_execution/exp_configs/train_opds.yml
@@ -3,8 +3,8 @@ simulator:
   time_per_step: 30
   vol_limit: null
 env:
-  concurrency: 48
-  parallel_mode: shmem
+  concurrency: 32
+  parallel_mode: dummy
 action_interpreter:
   class: CategoricalActionInterpreter
   kwargs:
@@ -18,10 +18,13 @@ state_interpreter:
     data_ticks: 48  # 48 = 240 min / 5 min
     max_step: 8
     processed_data_provider:
-      class: PickleProcessedDataProvider
-      module_path: qlib.rl.data.pickle_styled
+      class: HandlerProcessedDataProvider
       kwargs:
-        data_dir: ./data/pickle_dataframe/feature
+        data_dir: ./data/pickle/
+        feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
+        feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
+        backtest: false
+      module_path: qlib.rl.data.native
   module_path: qlib.rl.order_execution.interpreter
 reward:
   class: PAPenaltyReward
@@ -32,7 +35,9 @@ reward:
 data:
   source:
     order_dir: ./data/orders
-    data_dir: ./data/pickle_dataframe/backtest
+    feature_root_dir: ./data/pickle/
+    feature_columns_today: ["$close0", "$volume0"]
+    feature_columns_yesterday: []
     total_time: 240
     default_start_time_index: 0
     default_end_time_index: 235
diff --git a/examples/rl_order_execution/exp_configs/train_ppo.yml b/examples/rl_order_execution/exp_configs/train_ppo.yml
index d0b2722384..5d0eeea277 100755
--- a/examples/rl_order_execution/exp_configs/train_ppo.yml
+++ b/examples/rl_order_execution/exp_configs/train_ppo.yml
@@ -3,8 +3,8 @@ simulator:
   time_per_step: 30
   vol_limit: null
 env:
-  concurrency: 48
-  parallel_mode: shmem
+  concurrency: 32
+  parallel_mode: dummy
 action_interpreter:
   class: CategoricalActionInterpreter
   kwargs:
@@ -18,10 +18,13 @@ state_interpreter:
     data_ticks: 48  # 48 = 240 min / 5 min
     max_step: 8
     processed_data_provider:
-      class: PickleProcessedDataProvider
-      module_path: qlib.rl.data.pickle_styled
+      class: HandlerProcessedDataProvider
       kwargs:
-        data_dir: ./data/pickle_dataframe/feature
+        data_dir: ./data/pickle/
+        feature_columns_today: ["$high", "$low", "$open", "$close", "$volume"]
+        feature_columns_yesterday: ["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"]
+        backtest: false
+      module_path: qlib.rl.data.native
   module_path: qlib.rl.order_execution.interpreter
 reward:
   class: PPOReward
@@ -33,7 +36,9 @@ reward:
 data:
   source:
     order_dir: ./data/orders
-    data_dir: ./data/pickle_dataframe/backtest
+    feature_root_dir: ./data/pickle/
+    feature_columns_today: ["$close0", "$volume0"]
+    feature_columns_yesterday: []
     total_time: 240
     default_start_time_index: 0
     default_end_time_index: 235
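The `processed_data_provider` blocks above all follow the same pattern: `class` plus `module_path` select the provider type, and `kwargs` are forwarded to its constructor. A minimal sketch of what the new YAML resolves to in Python, assuming the `./data/pickle/` layout produced by the scripts in the README:

```
from qlib.rl.data.native import HandlerProcessedDataProvider

# Mirrors the `processed_data_provider` section of the configs above.
provider = HandlerProcessedDataProvider(
    data_dir="./data/pickle/",
    feature_columns_today=["$high", "$low", "$open", "$close", "$volume"],
    feature_columns_yesterday=["$high_1", "$low_1", "$open_1", "$close_1", "$volume_1"],
    backtest=False,  # the train configs set this explicitly; it defaults to False
)
```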
order.to_pickle(path / f"{stock}.pkl.target") + return True np.random.seed(1234) file_list = sorted(os.listdir(DATA_PATH)) stocks = [f.replace(".pkl", "") for f in file_list] -stocks = sorted(np.random.choice(stocks, size=100, replace=False)) -for stock in tqdm(stocks): - generate_order(stock, 0, 240 // 5 - 1) +np.random.shuffle(stocks) + +cnt = 0 +for stock in stocks: + if generate_order(stock, 0, 240 // 5 - 1): + cnt += 1 + if cnt == 100: + break diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index 6fafa94282..ee0942877f 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -154,12 +154,7 @@ def single_with_simulator( ------- If generate_report is True, return execution records and the generated report. Otherwise, return only records. """ - if split == "stock": - stock_id = orders.iloc[0].instrument - init_qlib(backtest_config["qlib"], part=stock_id) - else: - day = orders.iloc[0].datetime - init_qlib(backtest_config["qlib"], part=day) + init_qlib(backtest_config["qlib"]) stocks = orders.instrument.unique().tolist() @@ -253,12 +248,7 @@ def single_with_collect_data_loop( If generate_report is True, return execution records and the generated report. Otherwise, return only records. """ - if split == "stock": - stock_id = orders.iloc[0].instrument - init_qlib(backtest_config["qlib"], part=stock_id) - else: - day = orders.iloc[0].datetime - init_qlib(backtest_config["qlib"], part=day) + init_qlib(backtest_config["qlib"]) trade_start_time = orders["datetime"].min() trade_end_time = orders["datetime"].max() diff --git a/qlib/rl/contrib/train_onpolicy.py b/qlib/rl/contrib/train_onpolicy.py index d131ff244b..204c933eff 100644 --- a/qlib/rl/contrib/train_onpolicy.py +++ b/qlib/rl/contrib/train_onpolicy.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations + import argparse import os import random @@ -9,13 +11,12 @@ import numpy as np import pandas as pd -import qlib import torch import yaml from qlib.backtest import Order from qlib.backtest.decision import OrderDir from qlib.constant import ONE_MIN -from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data +from qlib.rl.data.native import load_handler_intraday_processed_data from qlib.rl.interpreter import ActionInterpreter, StateInterpreter from qlib.rl.order_execution import SingleAssetOrderExecutionSimple from qlib.rl.reward import Reward @@ -49,19 +50,17 @@ def _read_orders(order_dir: Path) -> pd.DataFrame: class LazyLoadDataset(Dataset): def __init__( self, + data_dir: str, order_file_path: Path, - data_dir: Path, default_start_time_index: int, default_end_time_index: int, ) -> None: self._default_start_time_index = default_start_time_index self._default_end_time_index = default_end_time_index - self._order_file_path = order_file_path self._order_df = _read_orders(order_file_path).reset_index() - - self._data_dir = data_dir self._ticks_index: Optional[pd.DatetimeIndex] = None + self._data_dir = Path(data_dir) def __len__(self) -> int: return len(self._order_df) @@ -74,12 +73,17 @@ def __getitem__(self, index: int) -> Order: # TODO: We only load ticks index once based on the assumption that ticks index of different dates # TODO: in one experiment are all the same. If that assumption is not hold, we need to load ticks index # TODO: of all dates. 
- backtest_data = load_simple_intraday_backtest_data( + + data = load_handler_intraday_processed_data( data_dir=self._data_dir, stock_id=row["instrument"], date=date, + feature_columns_today=[], + feature_columns_yesterday=[], + backtest=True, + index_only=True, ) - self._ticks_index = [t - date for t in backtest_data.get_time_index()] + self._ticks_index = [t - date for t in data.today.index] order = Order( stock_id=row["instrument"], @@ -104,8 +108,6 @@ def train_and_test( run_training: bool, run_backtest: bool, ) -> None: - qlib.init() - order_root_path = Path(data_config["source"]["order_dir"]) data_granularity = simulator_config.get("data_granularity", 1) @@ -113,10 +115,11 @@ def train_and_test( def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: return SingleAssetOrderExecutionSimple( order=order, - data_dir=Path(data_config["source"]["data_dir"]), - ticks_per_step=simulator_config["time_per_step"], + data_dir=data_config["source"]["feature_root_dir"], + feature_columns_today=data_config["source"]["feature_columns_today"], + feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"], data_granularity=data_granularity, - deal_price_type=data_config["source"].get("deal_price_column", "close"), + ticks_per_step=simulator_config["time_per_step"], vol_threshold=simulator_config["vol_limit"], ) @@ -126,8 +129,8 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: if run_training: train_dataset, valid_dataset = [ LazyLoadDataset( + data_dir=data_config["source"]["feature_root_dir"], order_file_path=order_root_path / tag, - data_dir=Path(data_config["source"]["data_dir"]), default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, ) @@ -178,8 +181,8 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: if run_backtest: test_dataset = LazyLoadDataset( + data_dir=data_config["source"]["feature_root_dir"], order_file_path=order_root_path / "test", - data_dir=Path(data_config["source"]["data_dir"]), default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, ) diff --git a/qlib/rl/data/integration.py b/qlib/rl/data/integration.py index 58311367f4..e123b6c8cf 100644 --- a/qlib/rl/data/integration.py +++ b/qlib/rl/data/integration.py @@ -8,48 +8,14 @@ from __future__ import annotations -import pickle from pathlib import Path -from typing import List -import cachetools -import numpy as np -import pandas as pd import qlib from qlib.constant import REG_CN from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select -from qlib.data.dataset import DatasetH -dataset = None - -class DataWrapper: - def __init__( - self, - feature_dataset: DatasetH, - backtest_dataset: DatasetH, - columns_today: List[str], - columns_yesterday: List[str], - _internal: bool = False, - ): - assert _internal, "Init function of data wrapper is for internal use only." 
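`LazyLoadDataset` now takes its tick index from handler-processed data instead of per-stock backtest pickles. A minimal sketch of that call, with a hypothetical instrument and date:

```
import pandas as pd
from pathlib import Path
from qlib.rl.data.native import load_handler_intraday_processed_data

date = pd.Timestamp("2020-01-02")
# index_only=True keeps the intraday index but loads no feature columns.
data = load_handler_intraday_processed_data(
    data_dir=Path("./data/pickle/"),
    stock_id="SH600000",  # hypothetical
    date=date,
    feature_columns_today=[],
    feature_columns_yesterday=[],
    backtest=True,
    index_only=True,
)
ticks_index = [t - date for t in data.today.index]
```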
diff --git a/qlib/rl/data/integration.py b/qlib/rl/data/integration.py
index 58311367f4..e123b6c8cf 100644
--- a/qlib/rl/data/integration.py
+++ b/qlib/rl/data/integration.py
@@ -8,48 +8,14 @@
 
 from __future__ import annotations
 
-import pickle
 from pathlib import Path
-from typing import List
 
-import cachetools
-import numpy as np
-import pandas as pd
 import qlib
 from qlib.constant import REG_CN
 from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select
-from qlib.data.dataset import DatasetH
 
-dataset = None
-
-
-class DataWrapper:
-    def __init__(
-        self,
-        feature_dataset: DatasetH,
-        backtest_dataset: DatasetH,
-        columns_today: List[str],
-        columns_yesterday: List[str],
-        _internal: bool = False,
-    ):
-        assert _internal, "Init function of data wrapper is for internal use only."
-
-        self.feature_dataset = feature_dataset
-        self.backtest_dataset = backtest_dataset
-        self.columns_today = columns_today
-        self.columns_yesterday = columns_yesterday
-
-    @cachetools.cached(  # type: ignore
-        cache=cachetools.LRUCache(100),
-        key=lambda _, stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest),
-    )
-    def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
-        start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
-        dataset = self.backtest_dataset if backtest else self.feature_dataset
-        return dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
-
-
-def init_qlib(qlib_config: dict, part: str | None = None) -> None:
+def init_qlib(qlib_config: dict) -> None:
     """Initialize necessary resource to launch the workflow, including data direction, feature columns, etc..
 
     Parameters
@@ -72,12 +38,8 @@ def init_qlib(qlib_config: dict, part: str | None = None) -> None:
             "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1",
         ],
     }
-    part
-        Identifying which part (stock / date) to load.
     """
 
-    global dataset  # pylint: disable=W0603
-
     def _convert_to_path(path: str | Path) -> Path:
         return path if isinstance(path, Path) else Path(path)
 
@@ -118,47 +80,3 @@ def _convert_to_path(path: str | Path) -> Path:
         redis_port=-1,
         clear_mem_cache=False,  # init_qlib will be called for multiple times. Keep the cache for improving performance
     )
-
-    if part == "skip":
-        return
-
-    # this won't work if it's put outside in case of multiprocessing
-    from qlib.data import D  # noqa pylint: disable=C0415,W0611
-
-    if part is None:
-        feature_path = Path(qlib_config["feature_root_dir"]) / "feature.pkl"
-        backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest.pkl"
-    else:
-        feature_path = Path(qlib_config["feature_root_dir"]) / "feature" / (part + ".pkl")
-        backtest_path = Path(qlib_config["feature_root_dir"]) / "backtest" / (part + ".pkl")
-
-    with feature_path.open("rb") as f:
-        feature_dataset = pickle.load(f)
-    with backtest_path.open("rb") as f:
-        backtest_dataset = pickle.load(f)
-
-    dataset = DataWrapper(
-        feature_dataset,
-        backtest_dataset,
-        qlib_config["feature_columns_today"],
-        qlib_config["feature_columns_yesterday"],
-        _internal=True,
-    )
-
-
-def fetch_features(stock_id: str, date: pd.Timestamp, yesterday: bool = False, backtest: bool = False) -> pd.DataFrame:
-    assert dataset is not None, "You must call init_qlib() before doing this."
-
-    if backtest:
-        fields = ["$close", "$volume"]
-    else:
-        fields = dataset.columns_yesterday if yesterday else dataset.columns_today
-
-    data = dataset.get(stock_id, date, backtest)
-    if data is None or len(data) == 0:
-        # create a fake index, but RL doesn't care about index
-        data = pd.DataFrame(0.0, index=np.arange(240), columns=fields, dtype=np.float32)  # FIXME: hardcode here
-    else:
-        data = data.rename(columns={c: c.rstrip("0") for c in data.columns})
-        data = data[fields]
-    return data
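With the `DataWrapper`/`fetch_features` machinery gone, `init_qlib` is pure environment setup and no longer needs a `part` argument. A minimal sketch of the new call, assuming the `qlib` section of the backtest configs above is sufficient:

```
from qlib.rl.data.integration import init_qlib

# Only provider configuration remains; feature columns now live with the
# data providers in qlib/rl/data/native.py.
init_qlib({"provider_uri_5min": "./data/bin/"})
```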
diff --git a/qlib/rl/data/native.py b/qlib/rl/data/native.py
index f09d909bc8..ceb5408829 100644
--- a/qlib/rl/data/native.py
+++ b/qlib/rl/data/native.py
@@ -2,17 +2,29 @@
 # Licensed under the MIT License.
 
 from __future__ import annotations
 
-from typing import cast
+from pathlib import Path
+from typing import cast, List
 
 import cachetools
 import pandas as pd
+import pickle
+import os
 
 from qlib.backtest import Exchange, Order
 from qlib.backtest.decision import TradeRange, TradeRangeByTime
-from qlib.rl.order_execution.utils import get_ticks_slice
-
+from qlib.constant import EPS_T
 from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
-from .integration import fetch_features
+
+
+def get_ticks_slice(
+    ticks_index: pd.DatetimeIndex,
+    start: pd.Timestamp,
+    end: pd.Timestamp,
+    include_end: bool = False,
+) -> pd.DatetimeIndex:
+    if not include_end:
+        end = end - EPS_T
+    return ticks_index[ticks_index.slice_indexer(start, end)]
 
 
 class IntradayBacktestData(BaseIntradayBacktestData):
@@ -71,6 +83,31 @@ def get_time_index(self) -> pd.DatetimeIndex:
         return pd.DatetimeIndex([e[1] for e in list(self._exchange.quote_df.index)])
 
 
+class DataframeIntradayBacktestData(BaseIntradayBacktestData):
+    """Backtest data from dataframe"""
+
+    def __init__(self, df: pd.DataFrame, price_column: str = "$close0", volume_column: str = "$volume0") -> None:
+        self.df = df
+        self.price_column = price_column
+        self.volume_column = volume_column
+
+    def __repr__(self) -> str:
+        with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
+            return f"{self.__class__.__name__}({self.df})"
+
+    def __len__(self) -> int:
+        return len(self.df)
+
+    def get_deal_price(self) -> pd.Series:
+        return self.df[self.price_column]
+
+    def get_volume(self) -> pd.Series:
+        return self.df[self.volume_column]
+
+    def get_time_index(self) -> pd.DatetimeIndex:
+        return cast(pd.DatetimeIndex, self.df.index)
+
+
 @cachetools.cached(  # type: ignore
     cache=cachetools.LRUCache(100),
     key=lambda order, _, __: order.key_by_day,
@@ -103,13 +140,18 @@ def load_backtest_data(
     return backtest_data
 
 
-class NTIntradayProcessedData(BaseIntradayProcessedData):
-    """Subclass of IntradayProcessedData. Used to handle NT style data."""
+class HandlerIntradayProcessedData(BaseIntradayProcessedData):
+    """Subclass of IntradayProcessedData. Used to handle handler (bin format) style data."""
 
     def __init__(
         self,
+        data_dir: Path,
         stock_id: str,
         date: pd.Timestamp,
+        feature_columns_today: List[str],
+        feature_columns_yesterday: List[str],
+        backtest: bool = False,
+        index_only: bool = False,
     ) -> None:
         def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
             df = df.reset_index()
@@ -117,8 +159,18 @@ def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
             df = df.drop(columns=["instrument"])
             return df.set_index(["datetime"])
 
-        self.today = _drop_stock_id(fetch_features(stock_id, date))
-        self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))
+        path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl")
+        start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
+        with open(path, "rb") as fstream:
+            dataset = pickle.load(fstream)
+        data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None)
+
+        if index_only:
+            self.today = _drop_stock_id(data[[]])
+            self.yesterday = _drop_stock_id(data[[]])
+        else:
+            self.today = _drop_stock_id(data[feature_columns_today])
+            self.yesterday = _drop_stock_id(data[feature_columns_yesterday])
 
     def __repr__(self) -> str:
         with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
@@ -127,12 +179,42 @@ def __repr__(self) -> str:
 
 @cachetools.cached(  # type: ignore
     cache=cachetools.LRUCache(100),  # 100 * 50K = 5MB
+    key=lambda data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only: (
+        stock_id,
+        date,
+        backtest,
+        index_only,
+    ),
 )
-def load_nt_intraday_processed_data(stock_id: str, date: pd.Timestamp) -> NTIntradayProcessedData:
-    return NTIntradayProcessedData(stock_id, date)
+def load_handler_intraday_processed_data(
+    data_dir: Path,
+    stock_id: str,
+    date: pd.Timestamp,
+    feature_columns_today: List[str],
+    feature_columns_yesterday: List[str],
+    backtest: bool = False,
+    index_only: bool = False,
+) -> HandlerIntradayProcessedData:
+    return HandlerIntradayProcessedData(
+        data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only
+    )
 
 
-class NTProcessedDataProvider(ProcessedDataProvider):
+class HandlerProcessedDataProvider(ProcessedDataProvider):
+    def __init__(
+        self,
+        data_dir: str,
+        feature_columns_today: List[str],
+        feature_columns_yesterday: List[str],
+        backtest: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.data_dir = Path(data_dir)
+        self.feature_columns_today = feature_columns_today
+        self.feature_columns_yesterday = feature_columns_yesterday
+        self.backtest = backtest
+
     def get_data(
         self,
         stock_id: str,
@@ -140,4 +222,12 @@ def get_data(
         feature_dim: int,
         time_index: pd.Index,
     ) -> BaseIntradayProcessedData:
-        return load_nt_intraday_processed_data(stock_id, date)
+        return load_handler_intraday_processed_data(
+            self.data_dir,
+            stock_id,
+            date,
+            self.feature_columns_today,
+            self.feature_columns_yesterday,
+            backtest=self.backtest,
+            index_only=False,
+        )
diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py
index 3f21c08550..4905b026a2 100644
--- a/qlib/rl/data/pickle_styled.py
+++ b/qlib/rl/data/pickle_styled.py
@@ -158,8 +158,8 @@ def get_time_index(self) -> pd.DatetimeIndex:
         return cast(pd.DatetimeIndex, self.data.index)
 
 
-class IntradayProcessedData(BaseIntradayProcessedData):
-    """Subclass of IntradayProcessedData. Used to handle Dataset Handler style data."""
+class PickleIntradayProcessedData(BaseIntradayProcessedData):
+    """Subclass of IntradayProcessedData. Used to handle pickle-styled data."""
 
     def __init__(
         self,
@@ -217,14 +217,14 @@ def load_simple_intraday_backtest_data(
     cache=cachetools.LRUCache(100),  # 100 * 50K = 5MB
     key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
 )
-def load_pickled_intraday_processed_data(
+def load_pickle_intraday_processed_data(
     data_dir: Path,
     stock_id: str,
     date: pd.Timestamp,
     feature_dim: int,
     time_index: pd.Index,
 ) -> BaseIntradayProcessedData:
-    return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
+    return PickleIntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
 
 
 class PickleProcessedDataProvider(ProcessedDataProvider):
@@ -240,7 +240,7 @@ def get_data(
         feature_dim: int,
         time_index: pd.Index,
     ) -> BaseIntradayProcessedData:
-        return load_pickled_intraday_processed_data(
+        return load_pickle_intraday_processed_data(
             data_dir=self._data_dir,
             stock_id=stock_id,
             date=date,
diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py
index ab6b463761..1417e2ab4a 100644
--- a/qlib/rl/order_execution/simulator_qlib.py
+++ b/qlib/rl/order_execution/simulator_qlib.py
@@ -67,7 +67,7 @@ def reset(
         cash_limit: Optional[float] = None,
     ) -> None:
         if qlib_config is not None:
-            init_qlib(qlib_config, part="skip")
+            init_qlib(qlib_config)
 
         strategy, self._executor = get_strategy_executor(
             start_time=order.date,
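`DataframeIntradayBacktestData` (added in `native.py` above) adapts any intraday DataFrame to the backtest-data interface purely via column names. A self-contained sketch with synthetic data:

```
import pandas as pd
from qlib.rl.data.native import DataframeIntradayBacktestData

# Synthetic one-day 5-minute frame using the default column names.
idx = pd.date_range("2020-01-02 09:30", periods=4, freq="5min")
df = pd.DataFrame(
    {"$close0": [10.0, 10.1, 10.2, 10.1], "$volume0": [500.0, 600.0, 550.0, 700.0]},
    index=idx,
)

data = DataframeIntradayBacktestData(df)  # default price/volume columns
assert len(data) == 4
assert data.get_deal_price().iloc[-1] == 10.1
assert data.get_time_index().equals(idx)
```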
""" @@ -73,9 +79,10 @@ def __init__( self, order: Order, data_dir: Path, + feature_columns_today: List[str] = [], + feature_columns_yesterday: List[str] = [], data_granularity: int = 1, ticks_per_step: int = 30, - deal_price_type: DealPriceType = "close", vol_threshold: Optional[float] = None, ) -> None: super().__init__(initial=order) @@ -83,18 +90,13 @@ def __init__( assert ticks_per_step % data_granularity == 0 self.order = order + self.data_dir = data_dir + self.feature_columns_today = feature_columns_today + self.feature_columns_yesterday = feature_columns_yesterday self.ticks_per_step: int = ticks_per_step // data_granularity - self.deal_price_type = deal_price_type self.vol_threshold = vol_threshold - self.data_dir = data_dir - self.backtest_data = load_simple_intraday_backtest_data( - self.data_dir, - order.stock_id, - pd.Timestamp(order.start_time.date()), - self.deal_price_type, - order.direction, - ) + self.backtest_data = self.get_backtest_data() self.ticks_index = self.backtest_data.get_time_index() # Get time index available for trading @@ -118,6 +120,30 @@ def __init__( self.market_vol: Optional[np.ndarray] = None self.market_vol_limit: Optional[np.ndarray] = None + def get_backtest_data(self) -> BaseIntradayBacktestData: + try: + data = load_handler_intraday_processed_data( + data_dir=self.data_dir, + stock_id=self.order.stock_id, + date=pd.Timestamp(self.order.start_time.date()), + feature_columns_today=self.feature_columns_today, + feature_columns_yesterday=self.feature_columns_yesterday, + backtest=True, + index_only=False, + ) + return DataframeIntradayBacktestData(data.today) + except (AttributeError, FileNotFoundError): + # TODO: For compatibility with older versions of test scripts (tests/rl/test_saoe_simple.py) + # TODO: In the future, we should modify the data format used by the test script, + # TODO: and then delete this branch. + return load_simple_intraday_backtest_data( + self.data_dir / "backtest", + self.order.stock_id, + pd.Timestamp(self.order.start_time.date()), + "close", + self.order.direction, + ) + def step(self, amount: float) -> None: """Execute one step or SAOE. 
diff --git a/qlib/rl/order_execution/utils.py b/qlib/rl/order_execution/utils.py
index 43517fe744..5a4fb78ff9 100644
--- a/qlib/rl/order_execution/utils.py
+++ b/qlib/rl/order_execution/utils.py
@@ -10,18 +10,7 @@
 
 from qlib.backtest.decision import OrderDir
 from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
-from qlib.constant import EPS_T, float_or_ndarray
-
-
-def get_ticks_slice(
-    ticks_index: pd.DatetimeIndex,
-    start: pd.Timestamp,
-    end: pd.Timestamp,
-    include_end: bool = False,
-) -> pd.DatetimeIndex:
-    if not include_end:
-        end = end - EPS_T
-    return ticks_index[ticks_index.slice_indexer(start, end)]
+from qlib.constant import float_or_ndarray
 
 
 def dataframe_append(df: pd.DataFrame, other: Any) -> pd.DataFrame:
diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py
index 32d6b4d6e4..d1711bb289 100644
--- a/tests/rl/test_saoe_simple.py
+++ b/tests/rl/test_saoe_simple.py
@@ -31,7 +30,6 @@
 ORDER_DIR = DATA_DIR / "order" / "valid_bidir"
 
 CN_DATA_DIR = DATA_ROOT_DIR / "cn"
-CN_BACKTEST_DATA_DIR = CN_DATA_DIR / "backtest"
 CN_FEATURE_DATA_DIR = CN_DATA_DIR / "processed"
 CN_ORDER_DIR = CN_DATA_DIR / "order" / "test"
 CN_POLICY_WEIGHTS_DIR = CN_DATA_DIR / "weights"
@@ -49,7 +48,7 @@ def test_pickle_data_inspect():
 
 def test_simulator_first_step():
     order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
-    simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
+    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
     state = simulator.get_state()
     assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00")
     assert state.position == 30.0
@@ -83,7 +82,7 @@ def test_simulator_first_step():
 
 def test_simulator_stop_twap():
     order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
-    simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
+    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
     for _ in range(13):
         simulator.step(1.0)
 
@@ -106,10 +105,10 @@ def test_simulator_stop_early():
     order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
 
     with pytest.raises(ValueError):
-        simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
+        simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
         simulator.step(2.0)
 
-    simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
+    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
     simulator.step(1.0)
 
     with pytest.raises(AssertionError):
@@ -119,7 +118,7 @@ def test_simulator_start_middle():
     order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
-    simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
+    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
     assert len(simulator.ticks_for_order) == 330
     assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
     simulator.step(2.0)
@@ -138,7 +137,7 @@ def test_interpreter():
     order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
-    simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
+    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
     assert len(simulator.ticks_for_order) == 330
     assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00")
@@ -219,7 +218,7 @@ def test_network_sanity():
     # we won't check the correctness of networks here
     order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
-    simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR)
+    simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
     assert len(simulator.ticks_for_order) == 390
 
     class EmulateEnvWrapper(NamedTuple):
@@ -259,7 +258,7 @@ def test_twap_strategy(finite_env_type):
     csv_writer = CsvWriter(Path(__file__).parent / ".output")
 
     backtest(
-        partial(SingleAssetOrderExecutionSimple, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30),
+        partial(SingleAssetOrderExecutionSimple, data_dir=DATA_DIR, ticks_per_step=30),
         state_interp,
         action_interp,
         orders,
@@ -290,7 +289,7 @@ def test_cn_ppo_strategy():
     csv_writer = CsvWriter(Path(__file__).parent / ".output")
 
     backtest(
-        partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
+        partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
         state_interp,
         action_interp,
         orders,
@@ -319,7 +318,7 @@ def test_ppo_train():
     policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
 
     train(
-        partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30),
+        partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
         state_interp,
        action_interp,
        orders,