From 6772ea36fef07cce5e8af004e11810f35f3c4e2d Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 1 May 2017 08:49:45 -0500 Subject: [PATCH] ENH: Allow add_intercept for unknown dims --- dask_glm/tests/test_utils.py | 12 ++++++++++++ dask_glm/utils.py | 14 +++++++------- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/dask_glm/tests/test_utils.py b/dask_glm/tests/test_utils.py index 97b952b..858de12 100644 --- a/dask_glm/tests/test_utils.py +++ b/dask_glm/tests/test_utils.py @@ -30,6 +30,18 @@ def test_add_intercept_dask(): assert_eq(result, expected) +def test_add_intercept_unknown(): + dd = pytest.importorskip('dask.dataframe') + X = dd.from_array(da.from_array(np.zeros((4, 4)), chunks=(2, 4))).values + result = utils.add_intercept(X) + expected = da.from_array(np.array([ + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + ], dtype=X.dtype), chunks=2) + assert_eq(result, expected) + def test_sparse(): sparse = pytest.importorskip('sparse') from sparse.utils import assert_eq diff --git a/dask_glm/utils.py b/dask_glm/utils.py index e3a9919..2cf7bff 100644 --- a/dask_glm/utils.py +++ b/dask_glm/utils.py @@ -117,16 +117,16 @@ def add_intercept(X): return np.concatenate([X, np.ones((X.shape[0], 1))], axis=1) +def _add_intercept_block(x): + o = np.ones((len(x), 1), dtype=x.dtype) + return np.concatenate([x, o], axis=1) + + @dispatch(da.Array) def add_intercept(X): - if np.isnan(np.sum(X.shape)): - raise NotImplementedError("Can not add intercept to array with " - "unknown chunk shape") j, k = X.chunks - o = da.ones((X.shape[0], 1), chunks=(j, 1)) - # TODO: Needed this `.rechunk` for the solver to work - # Is this OK / correct? - X_i = da.concatenate([X, o], axis=1).rechunk((j, (k[0] + 1,))) + k2 = (__builtins__['sum'](k) + 1,) + X_i = X.map_blocks(_add_intercept_block, dtype=X.dtype, chunks=(j, k2)) return X_i