Skip to content

Commit

Permalink
Merge pull request #97 from Nixtla/debug_mint
Browse files Browse the repository at this point in the history
MinTrace's protection to Schafer-Strimmer covariance, eliminated stat…
  • Loading branch information
kdgutier authored Oct 28, 2022
2 parents 4a84262 + 03a9e78 commit 437e5b1
Show file tree
Hide file tree
Showing 11 changed files with 301 additions and 134 deletions.
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ dependencies:
- numba
- pandas
- scikit-learn
- statsmodels
- pip
- pip:
- nbdev
Expand Down
2 changes: 2 additions & 0 deletions hierarchicalforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods._reconcile_fcst_proportions': ( 'methods.html#_reconcile_fcst_proportions',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.cov2corr': ( 'methods.html#cov2corr',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.crossprod': ( 'methods.html#crossprod',
'hierarchicalforecast/methods.py'),
'hierarchicalforecast.methods.is_strictly_hierarchical': ( 'methods.html#is_strictly_hierarchical',
Expand Down
45 changes: 30 additions & 15 deletions hierarchicalforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,35 +81,50 @@ def reconcile(self,
**Returns:**<br>
`y_tilde`: pd.DataFrame, with reconciled predictions.
"""
#----------------------------- Preliminary Wrangling/Protections -----------------------------#
# Check input's validity
if intervals_method not in ['normality', 'bootstrap', 'permbu']:
raise ValueError(f'Unkwon interval method: {intervals_method}')

if self.insample or (intervals_method in ['bootstrap', 'permbu']):
if Y_df is None:
raise Exception('you need to pass `Y_df`')

# Declare output names
drop_cols = ['ds', 'y'] if 'y' in Y_hat_df.columns else ['ds']
model_names = Y_hat_df.drop(columns=drop_cols, axis=1).columns.to_list()
# store pi names
pi_model_names = [name for name in model_names if ('-lo' in name or '-hi' in name)]
#remove prediction intervals
model_names = [name for name in model_names if name not in pi_model_names]

uids = Y_hat_df.index.unique()
# check if Y_hat_df has the same uids as S
if len(S.index.difference(uids)) > 0 or len(Y_hat_df.index.difference(S.index.unique())) > 0:
raise Exception('Summing matrix `S` and `Y_hat_df` do not have the same time series, please check.')
# same order of Y_hat_df to prevent errors

# Check Y_hat_df\S_df series difference
S_diff = len(S.index.difference(uids))
Y_hat_diff = len(Y_hat_df.index.difference(S.index.unique()))
if S_diff > 0 or Y_hat_diff > 0:
raise Exception(f'Check `S_df`, `Y_hat_df` series difference, S\Y_hat={S_diff}, Y_hat\S={Y_hat_diff}')

if Y_df is not None:
# Check Y_hat_df\Y_df series difference
Y_diff = len(Y_df.index.difference(uids))
Y_hat_diff = len(Y_hat_df.index.difference(Y_df.index.unique()))
if Y_diff > 0 or Y_hat_diff > 0:
raise Exception(f'Check `Y_hat_df`, `Y_df` series difference, Y_hat\Y={Y_hat_diff}, Y\Y_hat={Y_diff}')

# Same Y_hat_df/S_df/Y_df's unique_id order to prevent errors
S_ = S.loc[uids]


#---------------------------------------- Predictions ----------------------------------------#
# Initialize reconciler arguments
reconciler_args = dict(
S=S_.values.astype(np.float32),
idx_bottom=S_.index.get_indexer(S.columns),
tags={key: S_.index.get_indexer(val) for key, val in tags.items()}
)
# we need insample values if
# we are using a method that requires them
# or if we are performing boostrap
if self.insample or (intervals_method in ['bootstrap', 'permbu']):
if Y_df is None:
raise Exception('you need to pass `Y_df`')
# check if Y_hat_df has the same uids as Y_df
if len(Y_df.index.difference(uids)) > 0 or len(Y_hat_df.index.difference(Y_df.index.unique())) > 0:
raise Exception('Y_df` and `Y_hat_df` do not have the same time series, please check.')
if Y_df is not None:
reconciler_args['y_insample'] = Y_df.pivot(columns='ds', values='y').loc[uids].values.astype(np.float32)

fcsts = Y_hat_df.copy()
for reconcile_fn in self.reconcilers:
reconcile_fn_name = _build_fn_name(reconcile_fn)
Expand Down
64 changes: 55 additions & 9 deletions hierarchicalforecast/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import numpy as np
from numba import njit
from quadprog import solve_qp
from statsmodels.stats.moment_helpers import cov2corr

# %% ../nbs/methods.ipynb 5
def _reconcile(S: np.ndarray,
Expand Down Expand Up @@ -336,6 +335,25 @@ def crossprod(x):
return x.T @ x

# %% ../nbs/methods.ipynb 33
def cov2corr(cov, return_std=False):
""" convert covariance matrix to correlation matrix
**Parameters:**<br>
`cov`: array_like, 2d covariance matrix.<br>
`return_std`: bool=False, if True returned std.<br>
**Returns:**<br>
`corr`: ndarray (subclass) correlation matrix
"""
cov = np.asanyarray(cov)
std_ = np.sqrt(np.diag(cov))
corr = cov / np.outer(std_, std_)
if return_std:
return corr, std_
else:
return corr

# %% ../nbs/methods.ipynb 34
class MinTrace:
"""MinTrace Reconciliation Class.
Expand All @@ -350,6 +368,7 @@ class MinTrace:
**Parameters:**<br>
`method`: str, one of `ols`, `wls_struct`, `wls_var`, `mint_shrink`, `mint_cov`.<br>
`nonnegative`: bool, reconciled forecasts should be nonnegative?<br>
`mint_shr_ridge`: float, ridge numeric protection to MinTrace-shr covariance estimator.<br>
**References:**<br>
- [Wickramasuriya, S. L., Athanasopoulos, G., & Hyndman, R. J. (2019). \"Optimal forecast reconciliation for
Expand All @@ -361,10 +380,13 @@ class MinTrace:
"""
def __init__(self,
method: str,
nonnegative: bool = False):
nonnegative: bool = False,
mint_shr_ridge: Optional[float] = 2e-8):
self.method = method
self.nonnegative = nonnegative
self.insample = method in ['wls_var', 'mint_cov', 'mint_shrink']
if method == 'mint_shrink':
self.mint_shr_ridge = mint_shr_ridge

def reconcile(self,
S: np.ndarray,
Expand Down Expand Up @@ -398,27 +420,51 @@ def reconcile(self,
elif self.method == 'wls_struct':
W = np.diag(S @ np.ones((n_bottom,)))
elif self.method in res_methods:
#we need residuals with shape (obs, n_hiers)
# Residuals with shape (obs, n_hiers)
residuals = (y_insample - y_hat_insample).T
n, _ = residuals.shape

# Protection: against overfitted model
residuals_sum = np.sum(residuals, axis=0)
zero_residual_prc = np.abs(residuals_sum) < 1e-4
zero_residual_prc = np.mean(zero_residual_prc)
if zero_residual_prc > .98:
raise Exception(f'Insample residuals close to 0, zero_residual_prc={zero_residual_prc}. Check `Y_df`')

# Protection: cases where data is unavailable/nan
masked_res = np.ma.array(residuals, mask=np.isnan(residuals))
covm = np.ma.cov(masked_res, rowvar=False, allow_masked=True).data

if self.method == 'wls_var':
W = np.diag(np.diag(covm))
elif self.method == 'mint_cov':
W = covm
elif self.method == 'mint_shrink':
# Schäfer and Strimmer 2005, scale invariant shrinkage
# lasso or ridge might improve numerical stability but
# this version follows https://robjhyndman.com/papers/MinT.pdf
tar = np.diag(np.diag(covm))
corm = cov2corr(covm)
xs = np.divide(residuals, np.sqrt(np.diag(covm)))

# Protections: constant's correlation set to 0
# standardized residuals 0 where residual_std=0
corm, residual_std = cov2corr(covm, return_std=True)
corm = np.nan_to_num(corm, nan=0.0)
xs = np.divide(residuals, residual_std,
out=np.zeros_like(residuals), where=residual_std!=0)

xs = xs[~np.isnan(xs).any(axis=1), :]
v = (1 / (n * (n - 1))) * (crossprod(xs ** 2) - (1 / n) * (crossprod(xs) ** 2))
np.fill_diagonal(v, 0)

# Protection: constant's correlation set to 0
corapn = cov2corr(tar)
corapn = np.nan_to_num(corapn, nan=0.0)
d = (corm - corapn) ** 2
lmd = v.sum() / d.sum()
lmd = max(min(lmd, 1), 0)
W = lmd * tar + (1 - lmd) * covm

# Protection: final ridge diagonal protection
W = (lmd * tar + (1 - lmd) * covm) + self.mint_shr_ridge
else:
raise ValueError(f'Unkown reconciliation method {self.method}')

Expand Down Expand Up @@ -468,7 +514,7 @@ def reconcile(self,

__call__ = reconcile

# %% ../nbs/methods.ipynb 40
# %% ../nbs/methods.ipynb 42
class OptimalCombination(MinTrace):
"""Optimal Combination Reconciliation Class.
Expand Down Expand Up @@ -503,7 +549,7 @@ def __init__(self,
self.nonnegative = nonnegative
self.insample = False

# %% ../nbs/methods.ipynb 46
# %% ../nbs/methods.ipynb 48
@njit
def lasso(X: np.ndarray, y: np.ndarray,
lambda_reg: float, max_iters: int = 1_000,
Expand Down Expand Up @@ -535,7 +581,7 @@ def lasso(X: np.ndarray, y: np.ndarray,
#print(it)
return beta

# %% ../nbs/methods.ipynb 47
# %% ../nbs/methods.ipynb 49
class ERM:
"""Optimal Combination Reconciliation Class.
Expand Down
4 changes: 2 additions & 2 deletions hierarchicalforecast/probabilistic_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import numpy as np
from scipy.stats import norm
from sklearn.preprocessing import OneHotEncoder
from statsmodels.stats.moment_helpers import cov2corr

from .methods import is_strictly_hierarchical

from .methods import is_strictly_hierarchical, cov2corr

# %% ../nbs/probabilistic_methods.ipynb 6
class Normality:
Expand Down
Loading

0 comments on commit 437e5b1

Please sign in to comment.