Skip to content

Commit

Permalink
rewrote penalties.py #6
Browse files Browse the repository at this point in the history
  • Loading branch information
ouslan committed Jan 2, 2025
1 parent c22b552 commit 4903ce1
Showing 1 changed file with 61 additions and 21 deletions.
82 changes: 61 additions & 21 deletions src/bayesgam/penalties.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ def derivative(n:int, derivative:int=2, periodic:bool=False) -> sp.sparse.csc_ma
D[:, -2 * derivative : -derivative] += cols * (-1) ** derivative

# do symmetric operation on lower half of matrix
n_rows = int((n + 2 * derivative) / 2)
n_rows = (n + 2 * derivative) // 2
D[-n_rows:] = D[:n_rows][::-1, ::-1]

# keep only the center of the augmented matrix
D = D[derivative:-derivative, derivative:-derivative]
return D.dot(D.T).tocsc()

def apply_periodic_penalty(penalty_func, n:int, coef:np.ndarray, derivative:int=2):
def apply_periodic_penalty(penalty_func, n:int, coef:np.ndarray, derivative:int=2) -> sp.sparse.csc_matrix:
"""
Applies the penalty function for periodic features.
Expand All @@ -113,7 +113,7 @@ def apply_periodic_penalty(penalty_func, n:int, coef:np.ndarray, derivative:int=
"""
return penalty_func(n, coef, derivative=derivative, periodic=True)

def periodic(n:int, coef:np.ndarray, derivative:int=2, penalty_func=None):
def periodic(n:int, coef:np.ndarray, derivative:int=2, penalty_func=None) -> sp.sparce.csc_matrix:
"""
Wraps a penalty function to calculate the penalty for periodic features.
Expand Down Expand Up @@ -191,24 +191,7 @@ def monotonicity(n:int, coef:np.ndarray, increasing:bool=True) -> sp.sparse.csc_
D = sparse_diff(sp.sparse.identity(n).tocsc(), n=1) * mask
return D.dot(D.T).tocsc()

def monotonic_inc(n:int, coef:np.ndarray) -> sp.sparse.csc_matrix:
"""
Builds a penalty matrix for P-Splines with continuous features.
Penalizes violation of a monotonic increasing feature function.
Parameters
----------
n : int
number of splines
coef : array-like, coefficients of the feature function
Returns
-------
penalty matrix : sparse csc matrix of shape (n,n)
"""
return monotonicity(n, coef, increasing=True)

def convexity_(n, coef, convex=True):
def convexity(n:int, coef:np.ndarray, convex:bool=True) -> sp.sparse.csc_matrix:
"""
Builds a penalty matrix for P-Splines with continuous features.
Penalizes violation of convexity in the feature function.
Expand Down Expand Up @@ -241,6 +224,35 @@ def convexity_(n, coef, convex=True):
D = sparse_diff(sp.sparse.identity(n).tocsc(), n=2) * mask
return D.dot(D.T).tocsc()

def circular(n:int, coef:np.ndarray) -> sp.sparse.matrix:
"""
Builds a penalty matrix for P-Splines with continuous features.
Penalizes violation of a circular feature function.
Parameters
----------
n : int
number of splines
coef : unused
for compatibility with constraints
Returns
-------
penalty matrix : sparse csc matrix of shape (n,n)
"""
if n != len(coef.ravel()):
raise ValueError(f'dimension mismatch: expected n equals len(coef), but found n = {n}, coef.shape = {coef.shape}.')

if n==1:
# no first circular penalty for constant functions
return sp.sparse.csc_matrix(0.)

row = np.zeros(n)
row[0] = 1
row[-1] = -1
P = sp.sparse.vstack([row, sp.sparse.csc_matrix((n-2, n)), row[::-1]])
return P.tocsc()

def none(n:int) -> sp.sparse.csc_matrix:
"""
Build a matrix of zeros for features that should go unpenalized
Expand All @@ -258,3 +270,31 @@ def none(n:int) -> sp.sparse.csc_matrix:
"""
return sp.sparse.csc_matrix(np.zeros((n, n)))

def wrap_penalty(p, n, *args, fit_linear:bool, linear_penalty:float=0) -> sp.sparse.matrix:
"""
tool to account for unity penalty on the linear term of any feature.
example:
p = wrap_penalty(derivative, fit_linear=True)(n, coef)
Parameters
----------
p : callable.
penalty-matrix-generating function.
fit_linear : boolean.
whether the current feature has a linear term or not.
linear_penalty : float, default: 0.
penalty on the linear term
Returns
-------
wrapped_p : callable
modified penalty-matrix-generating function
"""
if fit_linear:
if n == 1:
return sp.sparse.block_diag([linear_penalty], format='csc')
else:
return sp.sparse.block_diag([linear_penalty, p(n - 1, *args)], format='csc')
else:
return p(n, *args)

0 comments on commit 4903ce1

Please sign in to comment.