Skip to content

Commit

Permalink
finished utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ouslan committed Jan 2, 2025
1 parent d3d5665 commit c22b552
Showing 1 changed file with 99 additions and 5 deletions.
104 changes: 99 additions & 5 deletions src/bayesgam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def check_array(

return array

def check_y(y:np.ndarray, link:object, dist:str, min_samples:int=1, verbose:bool=True) -> np.ndarray:
def check_y(y:np.ndarray, link, dist:str, min_samples:int=1, verbose:bool=True) -> np.ndarray:
"""
tool to ensure that the targets:
- are in the domain of the link function
Expand Down Expand Up @@ -175,6 +175,94 @@ def check_y(y:np.ndarray, link:object, dist:str, min_samples:int=1, verbose:bool
)
return y

# TODO: Depricate the categorical data leave the user to code the data
def check_X(
X,
n_feats=None,
min_samples=1,
edge_knots=None,
dtypes=None,
features=None,
verbose=True,
) -> np.ndarray:
"""
tool to ensure that X:
- is 2 dimensional
- contains float-compatible data-types
- has at least min_samples
- has n_feats
- has categorical features in the right range
- is finite
Parameters
----------
X : array-like
n_feats : int. default: None
represents number of features that X should have.
not enforced if n_feats is None.
min_samples : int, default: 1
edge_knots : list of arrays, default: None
dtypes : list of strings, default: None
features : list of ints,
which features are considered by the model
verbose : bool, default: True
whether to print warnings
Returns
-------
X : array with ndims == 2 containing validated X-data
"""
# check all features are there
if bool(features):
features = flatten(features)
max_feat = max(flatten(features))

if n_feats is None:
n_feats = max_feat

n_feats = max(n_feats, max_feat)

# basic diagnostics
X = check_array(
X,
force_2d=True,
n_feats=n_feats,
min_samples=min_samples,
name="X data",
verbose=verbose,
)

# check our categorical data has no new categories
if (edge_knots is not None) and (dtypes is not None) and (features is not None):
# get a flattened list of tuples
edge_knots = flatten(edge_knots)[::-1]
dtypes = flatten(dtypes)
if len(edge_knots) % 2 == 0: # sanity check
raise ValueError("Fail Sanity Check")
# form pairs
n = len(edge_knots) // 2
edge_knots = [(edge_knots.pop(), edge_knots.pop()) for _ in range(n)]

# check each categorical term
for i, ek in enumerate(edge_knots):
dt = dtypes[i]
feature = features[i]
x = X[:, feature]

if dt == "categorical":
dt_min = ek[0]
dt_max = ek[-1]
if (np.unique(x) < dt_min).any() or (np.unique(x) > dt_max).any():
dt_min += 0.5
dt_max -= 0.5
raise ValueError(
"X data is out of domain for categorical "
f"feature {i}. Expected data on [{dt_min}, {dt_max}], "
f"but found data on [{x.min()}, {x.max()}]"
)

return X

def check_X_y(X:np.ndarray, y:np.ndarray) -> None:
"""
tool to ensure input and output data have the same number of samples
Expand Down Expand Up @@ -209,7 +297,12 @@ def check_lengths(*arrays) -> None:
if len(array) != first_length:
raise ValueError(f"Array at index {i} has a different length ({len(array)}). Expected length: {first_length}.")

def check_param(param, param_name:str, dtype, constraint=None, iterable:bool=True, max_depth:int=2):
def check_param(
param:np.ndarray,
param_name:str,
dtype, constraint=None,
iterable:bool=True,
max_depth:int=2) -> np.ndarray:
"""
checks the dtype of a parameter,
and whether it satisfies a numerical contraint
Expand All @@ -229,7 +322,7 @@ def check_param(param, param_name:str, dtype, constraint=None, iterable:bool=Tru
only used if iterable == True
Returns
-------
list of validated and converted parameter(s)
Array of validated and converted parameter(s)
"""
msg = []
msg.append(param_name + " must be " + dtype)
Expand Down Expand Up @@ -331,6 +424,7 @@ def gen_edge_knots(data:np.ndarray, dtype:str, verbose:bool=True) -> np.ndarray:
)
return knots

# TODO: Check there might be some unbounded values
def b_spline_basis(
x,
edge_knots,
Expand All @@ -339,7 +433,7 @@ def b_spline_basis(
sparse=True,
periodic=True,
verbose=True,
):
) -> sp.sparse.csc_matrix:
"""
Generate B-spline basis functions using vectorized De Boor recursion.
The basis functions extrapolate linearly past the end-knots.
Expand Down Expand Up @@ -707,4 +801,4 @@ def __call__(self, data_list):
if self.ul:
formatted_rows.insert(1, self.row(self.ul))

return "\n".join(formatted_rows)
return "\n".join(formatted_rows)

0 comments on commit c22b552

Please sign in to comment.