[added feature] regularization term for inv(xtx) calculation (#63)
* Updated score_computers.py for lambda_reg

* Updated traker.py to include a lambda_reg term in arguments

* minor fixes

---------

Co-authored-by: Kristian Georgiev <[email protected]>
heale04 and kristian-georgiev authored Jan 17, 2024
1 parent ead7aa4 commit f491f57
Showing 2 changed files with 19 additions and 3 deletions.
10 changes: 9 additions & 1 deletion trak/score_computers.py
@@ -122,6 +122,7 @@ def __init__(
device: torch.device,
CUDA_MAX_DIM_SIZE: int = 20_000,
logging_level=logging.INFO,
lambda_reg: float = 0.0,
) -> None:
"""
Args:
@@ -132,11 +133,14 @@ def __init__(
Size of block for block-wise matmuls. Defaults to 100_000.
logging_level (logging level, optional):
Logging level for the logger. Defaults to logging.info.
lambda_reg (float):
regularization term added to the diagonal of xtx before inversion (l2/ridge regularization). Defaults to 0.0.
"""
super().__init__(dtype, device)
self.CUDA_MAX_DIM_SIZE = CUDA_MAX_DIM_SIZE
self.logger = logging.getLogger("ScoreComputer")
self.logger.setLevel(logging_level)
self.lambda_reg = lambda_reg

def get_xtx(self, grads: Tensor) -> Tensor:
self.proj_dim = grads.shape[1]
@@ -152,7 +156,11 @@ def get_xtx(self, grads: Tensor) -> Tensor:

def get_x_xtx_inv(self, grads: Tensor, xtx: Tensor) -> Tensor:
blocks = ch.split(grads, split_size_or_sections=self.CUDA_MAX_DIM_SIZE, dim=0)
xtx_inv = ch.linalg.inv(xtx.to(ch.float32))

xtx_reg = xtx + self.lambda_reg * ch.eye(
xtx.size(dim=0), device=xtx.device, dtype=xtx.dtype
)
xtx_inv = ch.linalg.inv(xtx_reg.to(ch.float32))

# center X^TX inverse a bit to avoid numerical issues when going to float16
xtx_inv /= xtx_inv.abs().mean()
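For reference, here is a minimal standalone sketch of the computation this hunk introduces: the Gram matrix of projected gradients gets `lambda_reg` added to its diagonal before inversion, which keeps the inverse well-defined even when `xtx` is (near-)singular. The function name, tensor shapes, and regularization value below are illustrative, not part of the repository.

```python
import torch as ch


def regularized_inverse(xtx: ch.Tensor, lambda_reg: float = 0.0) -> ch.Tensor:
    """Invert xtx after adding lambda_reg to its diagonal (ridge / Tikhonov)."""
    xtx_reg = xtx + lambda_reg * ch.eye(
        xtx.size(dim=0), device=xtx.device, dtype=xtx.dtype
    )
    return ch.linalg.inv(xtx_reg.to(ch.float32))


# A rank-deficient Gram matrix is not invertible without regularization,
# but becomes invertible once lambda_reg > 0 is added to the diagonal.
x = ch.randn(100, 4)
x[:, 3] = x[:, 0]          # duplicate a column -> xtx is singular
xtx = x.T @ x
inv = regularized_inverse(xtx, lambda_reg=1e-3)
print(inv.shape)           # torch.Size([4, 4])
```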
12 changes: 10 additions & 2 deletions trak/traker.py
@@ -64,6 +64,7 @@ def __init__(
proj_max_batch_size: int = 32,
projector_seed: int = 0,
grad_wrt: Optional[Iterable[str]] = None,
lambda_reg: float = 0.0,
) -> None:
"""
@@ -127,7 +128,10 @@ def __init__(
as they appear in the model's state dictionary. If None,
gradients are taken with respect to all model parameters.
Defaults to None.
lambda_reg (float):
The :math:`\ell_2` (ridge) regularization penalty added to the
:math:`X^\top X` term in score computers when computing the matrix
inverse :math:`(X^\top X)^{-1}`. Defaults to 0.
"""

self.model = model
@@ -136,6 +140,7 @@ def __init__(
self.device = device
self.dtype = ch.float16 if use_half_precision else ch.float32
self.grad_wrt = grad_wrt
self.lambda_reg = lambda_reg

logging.basicConfig()
self.logger = logging.getLogger("TRAK")
@@ -181,7 +186,10 @@ def __init__(
if score_computer is None:
score_computer = BasicScoreComputer
self.score_computer = score_computer(
dtype=self.dtype, device=self.device, logging_level=logging_level
dtype=self.dtype,
device=self.device,
logging_level=logging_level,
lambda_reg=self.lambda_reg,
)

metadata = {
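A usage sketch of the new argument from the user's side, assuming the TRAKer constructor otherwise keeps its documented model/task/train_set_size interface; the stand-in model and the value 1e-2 are placeholders, not recommendations from this commit.

```python
import torch.nn as nn
from trak import TRAKer

# Stand-in model; any checkpointed torch.nn.Module would do.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

traker = TRAKer(
    model=model,
    task="image_classification",
    train_set_size=50_000,
    lambda_reg=1e-2,  # forwarded to the score computer and added to X^T X before inversion
)
```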
