Skip to content

Commit

Permalink
added checks for input shape in TV2DNorm
Browse files Browse the repository at this point in the history
  • Loading branch information
Salman Naqvi committed Oct 5, 2023
1 parent 1b7f583 commit 9d1d73a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ class TV2DNorm(Functional):
has_eval = True
has_prox = True

def __init__(self, dims, tau: float = 1.0):
def __init__(self, dims: Tuple[int, int], tau: float = 1.0):
r"""
Args:
tau: Parameter :math:`\tau` in the norm definition.
Expand All @@ -511,6 +511,7 @@ def __init__(self, dims, tau: float = 1.0):

def __call__(self, x: Union[Array, BlockArray]) -> float:
r"""Return the :math:`\ell_{TV}` norm of an array."""
assert x.shape == self.dims
y = 0
gradOp = FiniteDifference(self.dims, input_dtype=x.dtype, circular=True)
grads = gradOp @ x
Expand All @@ -532,6 +533,7 @@ def prox(
kwargs: Additional arguments that may be used by derived
classes.
"""
assert x.shape == self.dims
D = 2
K = 2*D
thresh = snp.sqrt(2) * K * self.tau * lam
Expand Down

0 comments on commit 9d1d73a

Please sign in to comment.