Refine DDG-DA (#1472)
* Run ddg-da successfully

* Support include valid; More parameters

* Support L2 reg & visualization

* Blackformat

* Enable fill_method

* Support specify handler & optim dataset

* Fix Pylint
you-n-g authored Apr 7, 2023
1 parent 40de672 commit 32c3070
Showing 17 changed files with 457 additions and 39 deletions.
107 changes: 107 additions & 0 deletions examples/benchmarks_dynamic/DDG-DA/vis_data.py
@@ -0,0 +1,107 @@
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(color_codes=True)
plt.rcParams["font.sans-serif"] = "SimHei"
plt.rcParams["axes.unicode_minus"] = False
from tqdm.auto import tqdm

# tqdm.pandas() # for progress_apply
# %matplotlib inline
# %load_ext autoreload


# # Meta Input

# +
with open("./internal_data_s20.pkl", "rb") as f:
data = pickle.load(f)

data.data_ic_df.columns.names = ["start_date", "end_date"]

data_sim = data.data_ic_df.droplevel(axis=1, level="end_date")

data_sim.index.name = "test datetime"
# -

plt.figure(figsize=(40, 20))
sns.heatmap(data_sim)

plt.figure(figsize=(40, 20))
sns.heatmap(data_sim.rolling(20).mean())

# # Meta Model

from qlib import auto_init

auto_init()
from qlib.workflow import R

exp = R.get_exp(experiment_name="DDG-DA")
meta_rec = exp.list_recorders(rtype="list", max_results=1)[0]
meta_m = meta_rec.load_object("model")

pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].plot()

pd.DataFrame(meta_m.tn.twm.linear.weight.detach().numpy()).T[0].rolling(5).mean().plot()

# # Meta Output

# +
with open("./tasks_s20.pkl", "rb") as f:
tasks = pickle.load(f)

task_df = {}
for t in tasks:
test_seg = t["dataset"]["kwargs"]["segments"]["test"]
if None not in test_seg:
# The last rolling is skipped.
task_df[test_seg] = t["reweighter"].time_weight
task_df = pd.concat(task_df)

task_df.index.names = ["OS_start", "OS_end", "IS_start", "IS_end"]
task_df = task_df.droplevel(["OS_end", "IS_end"])
task_df = task_df.unstack("OS_start")
# -

plt.figure(figsize=(40, 20))
sns.heatmap(task_df.T)

plt.figure(figsize=(40, 20))
sns.heatmap(task_df.rolling(10).mean().T)

# # Sub Models
#
# NOTE:
# - this section assumes that the model is a linear model!
# - Other models do not support this analysis

exp = R.get_exp(experiment_name="rolling_ds")


def show_linear_weight(exp):
coef_df = {}
for r in exp.list_recorders("list"):
t = r.load_object("task")
if None in t["dataset"]["kwargs"]["segments"]["test"]:
continue
m = r.load_object("params.pkl")
coef_df[t["dataset"]["kwargs"]["segments"]["test"]] = pd.Series(m.coef_)

coef_df = pd.concat(coef_df)

coef_df.index.names = ["test_start", "test_end", "coef_idx"]

coef_df = coef_df.droplevel("test_end").unstack("coef_idx").T

plt.figure(figsize=(40, 20))
sns.heatmap(coef_df)
plt.show()


show_linear_weight(R.get_exp(experiment_name="rolling_ds"))

show_linear_weight(R.get_exp(experiment_name="rolling_models"))
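The script above reads artifacts produced by the DDG-DA workflow: internal_data_s20.pkl, tasks_s20.pkl, and the "DDG-DA" / "rolling_ds" / "rolling_models" experiment records. Below is a minimal sketch (not part of this commit) of producing those inputs before running the visualization, assuming the working directory is examples/benchmarks_dynamic/DDG-DA and the Qlib data bundle is available; the alpha value is illustrative.

# Sketch only: produce the inputs that vis_data.py expects, then run the script.
from qlib import auto_init
from qlib.tests.data import GetData

GetData().qlib_data(exists_skip=True)  # download the bundled data if it is missing
auto_init()

from workflow import DDGDA  # the class refined in this commit (see the diff below)

DDGDA(sim_task_model="linear", forecast_model="linear", alpha=0.01).run_all()
# afterwards run vis_data.py directly; its cell markers suggest it can also be
# opened as a notebook via jupytext.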
73 changes: 58 additions & 15 deletions examples/benchmarks_dynamic/DDG-DA/workflow.py
@@ -10,8 +10,10 @@
import fire
import sys
import pickle
from typing import Optional
from qlib import auto_init
from qlib.model.trainer import TrainerR
from qlib.typehint import Literal
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.tests.data import GetData
@@ -30,18 +32,53 @@ class DDGDA:
- `rm -r mlruns`
"""

def __init__(self, sim_task_model="linear", forecast_model="linear"):
def __init__(
self,
sim_task_model: Literal["linear", "gbdt"] = "linear",
forecast_model: Literal["linear", "gbdt"] = "linear",
h_path: Optional[str] = None,
test_end: Optional[str] = None,
train_start: Optional[str] = None,
meta_1st_train_end: Optional[str] = None,
task_ext_conf: Optional[dict] = None,
alpha: float = 0.0,
proxy_hd: str = "handler_proxy.pkl",
):
"""
Parameters
----------
train_start: Optional[str]
the start datetime for the data. It is used as the training start time (for both tasks & meta learning)
test_end: Optional[str]
the end datetime for the data. It is used as the test end time
meta_1st_train_end: Optional[str]
the training end datetime of the first meta task
alpha: float
the strength of the L2 regularization (ridge penalty) for the meta model
The `alpha` is only passed to MetaModelDS (it is not passed to sim_task_model currently)
"""
self.step = 20
# NOTE:
# the horizon must be consistent with the horizon used in the base task template
self.horizon = 20
self.meta_exp_name = "DDG-DA"
self.sim_task_model = sim_task_model # The model to capture the distribution of data.
self.forecast_model = forecast_model # downstream forecasting models' type
self.rb_kwargs = {
"h_path": h_path,
"test_end": test_end,
"train_start": train_start,
"task_ext_conf": task_ext_conf,
}
self.alpha = alpha
self.meta_1st_train_end = meta_1st_train_end
self.proxy_hd = proxy_hd

def get_feature_importance(self):
# this must be LightGBM, because we need to get the feature importance from it
rb = RollingBenchmark(model_type="gbdt")
rb = RollingBenchmark(model_type="gbdt", **self.rb_kwargs)
task = rb.basic_task()

with R.start(experiment_name="feature_importance"):
@@ -69,7 +106,7 @@ def dump_data_for_proxy_model(self):
fi = self.get_feature_importance()
col_selected = fi.nlargest(topk)

rb = RollingBenchmark(model_type=self.sim_task_model)
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
task = rb.basic_task()
dataset = init_instance_by_config(task["dataset"])
prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
@@ -96,7 +133,7 @@ def dump_data_for_proxy_model(self):
"kwargs": {"config": DIRNAME / "fea_label_df.pkl"},
}
)
handler.to_pickle(DIRNAME / "handler_proxy.pkl", dump_all=True)
handler.to_pickle(DIRNAME / self.proxy_hd, dump_all=True)

@property
def _internal_data_path(self):
@@ -108,7 +145,7 @@ def dump_meta_ipt(self):
This function will dump the input data for the meta model
"""
# According to the experiments, the choice of the model type is very important for achieving good results
rb = RollingBenchmark(model_type=self.sim_task_model)
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
sim_task = rb.basic_task()

if self.sim_task_model == "gbdt":
@@ -122,24 +159,27 @@ def dump_meta_ipt(self):
with self._internal_data_path.open("wb") as f:
pickle.dump(internal_data, f)

def train_meta_model(self):
def train_meta_model(self, fill_method="max"):
"""
Train a meta model based on a simplified linear proxy model.
"""

# 1) leverage the simplified proxy forecasting model to train meta model.
# - Only the dataset part is important, in current version of meta model will integrate the
rb = RollingBenchmark(model_type=self.sim_task_model)
rb = RollingBenchmark(model_type=self.sim_task_model, **self.rb_kwargs)
sim_task = rb.basic_task()
train_start = self.rb_kwargs.get("train_start", "2008-01-01")
train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end
test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
proxy_forecast_model_task = {
# "model": "qlib.contrib.model.linear.LinearModel",
"dataset": {
"class": "qlib.data.dataset.DatasetH",
"kwargs": {
"handler": f"file://{(DIRNAME / 'handler_proxy.pkl').absolute()}",
"handler": f"file://{(DIRNAME / self.proxy_hd).absolute()}",
"segments": {
"train": ("2008-01-01", "2010-12-31"),
"test": ("2011-01-01", sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
"train": (train_start, train_end),
"test": (test_start, sim_task["dataset"]["kwargs"]["segments"]["test"][1]),
},
},
},
@@ -156,7 +196,7 @@ def train_meta_model(self):
segments=0.62, # keep test period consistent with the dataset yaml
trunc_days=1 + self.horizon,
hist_step_n=30,
fill_method="max",
fill_method=fill_method,
rolling_ext_days=0,
)
# NOTE:
@@ -165,12 +205,15 @@ def train_meta_model(self):
# So the misalignment will not affect the effectiveness of the method.
with self._internal_data_path.open("rb") as f:
internal_data = pickle.load(f)

md = MetaDatasetDS(exp_name=internal_data, **kwargs)

# 3) train and logging meta model
with R.start(experiment_name=self.meta_exp_name):
R.log_params(**kwargs)
mm = MetaModelDS(step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43)
mm = MetaModelDS(
step=self.step, hist_step_n=kwargs["hist_step_n"], lr=0.001, max_epoch=100, seed=43, alpha=self.alpha
)
mm.fit(md)
R.save_objects(model=mm)

@@ -203,7 +246,7 @@ def meta_inference(self):
hist_step_n = int(param["hist_step_n"])
fill_method = param.get("fill_method", "max")

rb = RollingBenchmark(model_type=self.forecast_model)
rb = RollingBenchmark(model_type=self.forecast_model, **self.rb_kwargs)
task_l = rb.create_rolling_tasks()

# 2.2) create meta dataset for final dataset
@@ -233,13 +276,13 @@ def train_and_eval_tasks(self):
"""
with self._task_path.open("rb") as f:
tasks = pickle.load(f)
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model)
rb = RollingBenchmark(rolling_exp="rolling_ds", model_type=self.forecast_model, **self.rb_kwargs)
rb.train_rolling_tasks(tasks)
rb.ens_rolling()
rb.update_rolling_rec()

def run_all(self):
# 1) file: handler_proxy.pkl
# 1) file: handler_proxy.pkl (self.proxy_hd)
self.dump_data_for_proxy_model()
# 2)
# file: internal_data_s20.pkl
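With the constructor refined above, the data range, first meta-training end, L2 regularization, and proxy-handler name can be set without editing the file; `h_path` and `task_ext_conf` are forwarded to RollingBenchmark (see the baseline diff below). A hedged sketch of programmatic use follows; all values are illustrative, not part of this commit.

# Sketch of the parameter surface added in this commit; values are illustrative.
from qlib import auto_init

auto_init()

from workflow import DDGDA  # assumes the examples/benchmarks_dynamic/DDG-DA directory

d = DDGDA(
    sim_task_model="gbdt",             # model used to capture the data distribution
    forecast_model="linear",           # downstream rolling forecast model
    train_start="2010-01-01",          # start datetime of the data
    test_end="2018-12-31",             # end datetime of the data
    meta_1st_train_end="2012-12-31",   # training end of the first meta task
    alpha=0.01,                        # L2 regularization forwarded to MetaModelDS
    proxy_hd="handler_proxy_alt.pkl",  # custom file name for the dumped proxy handler
)
d.run_all()  # or run individual steps, e.g. d.train_meta_model(fill_method="max")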
53 changes: 51 additions & 2 deletions examples/benchmarks_dynamic/baseline/rolling_benchmark.py
@@ -1,13 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Optional
from qlib.model.ens.ensemble import RollingEnsemble
from qlib.utils import init_instance_by_config
import fire
import yaml
import pandas as pd
from qlib import auto_init
from pathlib import Path
from tqdm.auto import tqdm
from qlib.model.trainer import TrainerR
from qlib.log import get_module_logger
from qlib.utils.data import update_config
from qlib.workflow import R
from qlib.tests.data import GetData

@@ -25,11 +29,40 @@ class RollingBenchmark:
"""

def __init__(self, rolling_exp="rolling_models", model_type="linear") -> None:
def __init__(
self,
rolling_exp: str = "rolling_models",
model_type: str = "linear",
h_path: Optional[str] = None,
train_start: Optional[str] = None,
test_end: Optional[str] = None,
task_ext_conf: Optional[dict] = None,
) -> None:
"""
Parameters
----------
rolling_exp : str
The name of the experiment for the rolling models
model_type : str
The type of the model to be trained (e.g. "linear" or "gbdt").
h_path : Optional[str]
the path of a pre-dumped data handler
test_end : Optional[str]
the test end datetime for the data. It is typically used together with the handler
train_start : Optional[str]
the train start datetime for the data. It is typically used together with the handler.
task_ext_conf : Optional[dict]
extra options to update the task config
"""
self.step = 20
self.horizon = 20
self.rolling_exp = rolling_exp
self.model_type = model_type
self.h_path = h_path
self.train_start = train_start
self.test_end = test_end
self.logger = get_module_logger("RollingBenchmark")
self.task_ext_conf = task_ext_conf

def basic_task(self):
"""For fast training rolling"""
@@ -42,6 +75,10 @@ def basic_task(self):
h_path = DIRNAME / "linear_alpha158_handler_horizon{}.pkl".format(self.horizon)
else:
raise AssertionError("Model type is not supported!")

if self.h_path is not None:
h_path = Path(self.h_path)

with conf_path.open("r") as f:
conf = yaml.safe_load(f)

@@ -52,13 +89,25 @@ def basic_task(self):

task = conf["task"]

if self.task_ext_conf is not None:
task = update_config(task, self.task_ext_conf)

if not h_path.exists():
h_conf = task["dataset"]["kwargs"]["handler"]
h = init_instance_by_config(h_conf)
h.to_pickle(h_path, dump_all=True)

task["dataset"]["kwargs"]["handler"] = f"file://{h_path}"
task["record"] = ["qlib.workflow.record_temp.SignalRecord"]

if self.train_start is not None:
seg = task["dataset"]["kwargs"]["segments"]["train"]
task["dataset"]["kwargs"]["segments"]["train"] = pd.Timestamp(self.train_start), seg[1]

if self.test_end is not None:
seg = task["dataset"]["kwargs"]["segments"]["test"]
task["dataset"]["kwargs"]["segments"]["test"] = seg[0], pd.Timestamp(self.test_end)
self.logger.info(task)
return task

def create_rolling_tasks(self):
@@ -93,7 +142,7 @@ def update_rolling_rec(self):
"""
Evaluate the combined rolling results
"""
for rid, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
for _, rec in R.list_recorders(experiment_name=self.COMB_EXP).items():
for rt_cls in SigAnaRecord, PortAnaRecord:
rt = rt_cls(recorder=rec, skip_existing=True)
rt.generate()

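The baseline RollingBenchmark gains matching overrides, so the DDG-DA workflow and the plain rolling baselines can share one handler and data range. A hedged usage sketch follows; the handler path and the `task_ext_conf` patch are illustrative, and the patch is merged into the yaml task config by `qlib.utils.data.update_config` before the handler and segment overrides are applied.

# Sketch of the options added to RollingBenchmark in this commit; values are illustrative.
from qlib import auto_init

auto_init()

from rolling_benchmark import RollingBenchmark  # assumes examples/benchmarks_dynamic/baseline

rb = RollingBenchmark(
    rolling_exp="rolling_models",
    model_type="gbdt",
    h_path="./alpha158_gbdt_handler.pkl",  # hypothetical pre-dumped handler, reused instead of rebuilt
    train_start="2010-01-01",              # replaces the start of the train segment
    test_end="2018-12-31",                 # replaces the end of the test segment
    task_ext_conf={"model": {"kwargs": {"num_leaves": 64}}},  # illustrative patch for the GBDT task
)
task = rb.basic_task()             # one task with the overrides applied (dumps the handler if missing)
tasks = rb.create_rolling_tasks()  # expand it into step-20 rolling tasks
rb.train_rolling_tasks(tasks)
rb.ens_rolling()                   # combine the per-roll recorders
rb.update_rolling_rec()            # signal & portfolio analysis on the combined record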