diff --git a/impyute/imputation/ts/locf.py b/impyute/imputation/ts/locf.py index 42521cf..fc5f449 100644 --- a/impyute/imputation/ts/locf.py +++ b/impyute/imputation/ts/locf.py @@ -3,7 +3,7 @@ from impyute.util import find_null from impyute.util import checks from impyute.util import preprocess - +from impyute.util.errors import BadInputError @preprocess @checks def locf(data, axis=0): @@ -33,6 +33,8 @@ def locf(data, axis=0): data = np.transpose(data) elif axis == 1: pass + else: + raise BadInputError("Error: Axis value is invalid, please use either 0 (row format) or 1 (column format)") null_xy = find_null(data) for x_i, y_i in null_xy: diff --git a/test/imputation/ts/test_locf.py b/test/imputation/ts/test_locf.py index 4c3b798..600d73c 100644 --- a/test/imputation/ts/test_locf.py +++ b/test/imputation/ts/test_locf.py @@ -2,6 +2,7 @@ import numpy as np import impyute as impy from impyute.util.testing import return_na_check +from impyute.util.errors import BadInputError SHAPE = (5, 5) @@ -32,4 +33,11 @@ def test_na_at_i_end(test_data): data = test_data(SHAPE, last_i, 3) actual = impy.locf(data, axis=1) data[last_i, 3] = data[last_i - 1, 3] - assert np.array_equal(actual, data) \ No newline at end of file + assert np.array_equal(actual, data) + + +def test_out_of_bounds(test_data): + """Check out of bounds error, should throw BadInputError for any axis outside [0,1]""" + data = test_data(SHAPE) + with np.testing.assert_raises(BadInputError): + impy.locf(data, axis=3)