From ead7aa4bde0e7b9c00971543e65b5597583e616d Mon Sep 17 00:00:00 2001 From: Kristian Georgiev Date: Wed, 17 Jan 2024 13:02:25 -0500 Subject: [PATCH] fix grads type check in iterative gradient computers Co-authored-by: TheaperDeng --- tests/test_class.py | 26 +++++++++- trak/gradient_computers.py | 17 +++++-- trak/modelout_functions.py | 98 ++++++++++++++++++++++++++++++++++++++ trak/projectors.py | 19 ++++---- trak/traker.py | 8 ++-- 5 files changed, 149 insertions(+), 19 deletions(-) diff --git a/tests/test_class.py b/tests/test_class.py index b4bf2e3..9f4883d 100644 --- a/tests/test_class.py +++ b/tests/test_class.py @@ -374,6 +374,30 @@ def test_custom_model_output(tmp_path, cpu_proj): ) +def test_iterative_gradient_computer(tmp_path, cpu_proj): + from trak.gradient_computers import IterativeGradientComputer + from trak.projectors import NoOpProjector + + model = resnet18() + N = 5 + batch = ch.randn(N, 3, 32, 32), ch.randint(low=0, high=10, size=(N,)) + traker = TRAKer( + model=model, + task="iterative_image_classification", + save_dir=tmp_path, + train_set_size=N, + logging_level=logging.DEBUG, + device="cpu", + use_half_precision=False, + projector=NoOpProjector(), + proj_dim=0, + gradient_computer=IterativeGradientComputer, + ) + ckpt = model.state_dict() + traker.load_checkpoint(ckpt, model_id=0) + traker.featurize(batch, num_samples=N) + + def test_grad_wrt_last_layer(tmp_path): model = resnet18().eval() N = 5 @@ -402,7 +426,7 @@ def test_grad_wrt_last_layer(tmp_path): def test_grad_wrt_last_layer_cuda(tmp_path): model = resnet18().cuda().eval() N = 5 - batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() + batch = ch.randn(N, 3, 4, 4).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda() traker = TRAKer( model=model, task="image_classification", diff --git a/trak/gradient_computers.py b/trak/gradient_computers.py index b000e8b..f2aa0c0 100644 --- a/trak/gradient_computers.py +++ b/trak/gradient_computers.py @@ -136,9 +136,9 @@ def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor: batch of data Returns: - Tensor: - gradients of the model output function of each sample in the - batch with respect to the model's parameters. + dict[Tensor]: + A dictionary where each key is a parameter name and the value is + the gradient tensor for that parameter. """ # taking the gradient wrt weights (second argument of get_output, hence argnums=1) @@ -183,6 +183,9 @@ def compute_loss_grad(self, batch: Iterable[Tensor]) -> Tensor: batch (Iterable[Tensor]): batch of data + Returns: + Tensor: + The gradient of the loss with respect to the model output. """ return self.modelout_fn.get_out_to_loss_grad( self.model, self.func_weights, self.func_buffers, batch @@ -229,7 +232,7 @@ def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor: batch_size = batch[0].shape[0] grads = ch.zeros(batch_size, self.grad_dim).to(batch[0].device) - margin = self.modelout_fn.get_output(self.model, *batch) + margin = self.modelout_fn.get_output(self.model, None, None, *batch) for ind in range(batch_size): grads[ind] = parameters_to_vector( ch.autograd.grad(margin[ind], self.model_params, retain_graph=True) @@ -254,5 +257,9 @@ def compute_loss_grad(self, batch: Iterable[Tensor]) -> Tensor: Args: batch (Iterable[Tensor]): batch of data + + Returns: + Tensor: + The gradient of the loss with respect to the model output. """ - return self.modelout_fn.get_out_to_loss_grad(self.model, batch) + return self.modelout_fn.get_out_to_loss_grad(self.model, None, None, batch) diff --git a/trak/modelout_functions.py b/trak/modelout_functions.py index 47ef9b7..c0a7623 100644 --- a/trak/modelout_functions.py +++ b/trak/modelout_functions.py @@ -5,6 +5,7 @@ - :class:`.ImageClassificationModelOutput` - :class:`.CLIPModelOutput` - :class:`.TextClassificationModelOutput` +- :class:`.IterativeImageClassificationModelOutput` These classes implement methods that transform input batches to the desired model output (e.g. logits, loss, etc). See Sections 2 & 3 of `our paper @@ -444,8 +445,105 @@ def get_out_to_loss_grad( return (1 - ps).clone().detach().unsqueeze(-1) +class IterativeImageClassificationModelOutput(AbstractModelOutput): + """Margin for (multiclass) image classification. See Section 3.3 of `our + paper `_ for more details. + """ + + def __init__(self, temperature: float = 1.0) -> None: + """ + Args: + temperature (float, optional): Temperature to use inside the + softmax for the out-to-loss function. Defaults to 1. + """ + super().__init__() + self.softmax = ch.nn.Softmax(-1) + self.loss_temperature = temperature + + @staticmethod + def get_output( + model: Module, + weights: Iterable[Tensor], + buffers: Iterable[Tensor], + images: Tensor, + labels: Tensor, + ) -> Tensor: + """For a given input :math:`z=(x, y)` and model parameters :math:`\\theta`, + let :math:`p(z, \\theta)` be the softmax probability of the correct class. + This method implements the model output function + + .. math:: + + \\log(\\frac{p(z, \\theta)}{1 - p(z, \\theta)}). + + It uses functional models from torch.func (previously functorch) to make + the per-sample gradient computations (much) faster. For more details on + what functional models are, and how to use them, please refer to + https://pytorch.org/docs/stable/func.html and + https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html. + + Args: + model (torch.nn.Module): + torch model + weights (Iterable[Tensor]): + functorch model weights (added se we don't break abstraction) + buffers (Iterable[Tensor]): + functorch model buffers (added se we don't break abstraction) + images (Tensor): + input images + labels (Tensor): + input labels + + Returns: + Tensor: + model output for the given image-label pair :math:`z` + """ + logits = model(images) + bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False) + logits_correct = logits[bindex, labels] + + cloned_logits = logits.clone() + # remove the logits of the correct labels from the sum + # in logsumexp by setting to -ch.inf + cloned_logits[bindex, labels] = ch.tensor( + -ch.inf, device=logits.device, dtype=logits.dtype + ) + + margins = logits_correct - cloned_logits.logsumexp(dim=-1) + return margins + + def get_out_to_loss_grad( + self, model, weights, buffers, batch: Iterable[Tensor] + ) -> Tensor: + """Computes the (reweighting term Q in the paper) + + Args: + model (torch.nn.Module): + torch model + weights (Iterable[Tensor]): + functorch model weights + buffers (Iterable[Tensor]): + functorch model buffers + batch (Iterable[Tensor]): + input batch + + Returns: + Tensor: + out-to-loss (reweighting term) for the input batch + """ + images, labels = batch + logits = model(images) + # here we are directly implementing the gradient instead of relying on autodiff to do + # that for us + ps = self.softmax(logits / self.loss_temperature)[ + ch.arange(logits.size(0)), labels + ] + return (1 - ps).clone().detach().unsqueeze(-1) + + TASK_TO_MODELOUT = { "image_classification": ImageClassificationModelOutput, "clip": CLIPModelOutput, "text_classification": TextClassificationModelOutput, + "iterative_image_classification": IterativeImageClassificationModelOutput, } diff --git a/trak/projectors.py b/trak/projectors.py index 18ec010..b3a1e3a 100644 --- a/trak/projectors.py +++ b/trak/projectors.py @@ -115,7 +115,9 @@ def project(self, grads: Tensor, model_id: int) -> Tensor: Returns: Tensor: the (non-)projected gradients """ - return vectorize(grads, device=self.device) + if isinstance(grads, dict): + grads = vectorize(grads, device=self.device) + return grads def free_memory(self): """A no-op method.""" @@ -190,7 +192,9 @@ def generate_sketch_matrix(self): raise KeyError(f"Projection type {self.proj_type} not recognized.") def project(self, grads: Tensor, model_id: int) -> Tensor: - grads = vectorize(grads, device=self.device) + if isinstance(grads, dict): + grads = vectorize(grads, device=self.device) + grads = grads.to(dtype=self.dtype) if model_id != self.model_id: self.model_id = model_id @@ -254,7 +258,7 @@ def free_memory(self): def get_generator_states(self): self.generator_states = [] self.seeds = [] - self.jl_size = self.grad_dim * self.block_size + self.jl_size = self.grad_dim * self.block_size for i in range(self.num_blocks): s = self.seed + int(1e3) * i + int(1e5) * self.model_id @@ -283,7 +287,8 @@ def generate_sketch_matrix(self, generator_state): raise KeyError(f"Projection type {self.proj_type} not recognized.") def project(self, grads: Tensor, model_id: int) -> Tensor: - grads = vectorize(grads, device=self.device) + if isinstance(grads, dict): + grads = vectorize(grads, device=self.device) grads = grads.to(dtype=self.dtype) sketch = ch.zeros( size=(grads.size(0), self.proj_dim), dtype=self.dtype, device=self.device @@ -380,10 +385,10 @@ def project( self, grads: Union[dict, Tensor], model_id: int, - is_grads_dict: bool = True, ) -> Tensor: - if is_grads_dict: + if isinstance(grads, dict): grads = vectorize(grads, device=self.device) + batch_size = grads.shape[0] effective_batch_size = 32 @@ -486,7 +491,6 @@ def project(self, grads, model_id): self.projector_per_chunk[projector_index].project( self.ch_input[:, :pointer].contiguous(), model_id=model_id, - is_grads_dict=False, ) ) # reset counter @@ -506,7 +510,6 @@ def project(self, grads, model_id): self.projector_per_chunk[projector_index].project( self.ch_input[:actual_bs, :pointer].contiguous(), model_id=model_id, - is_grads_dict=False, ) ) diff --git a/trak/traker.py b/trak/traker.py index 76d319a..21e30f1 100644 --- a/trak/traker.py +++ b/trak/traker.py @@ -72,11 +72,9 @@ def __init__( model to use for TRAK task (Union[AbstractModelOutput, str]): Type of model that TRAK will be ran on. Accepts either one of - the following strings: - - :code:`image_classification` - - :code:`text_classification` - - :code:`clip` - or an instance of some implementation of the abstract class + the following strings: 1) :code:`image_classification` 2) + :code:`text_classification` 3) :code:`clip` or an instance of + some implementation of the abstract class :class:`.AbstractModelOutput`. train_set_size (int): Size of the train set that TRAK is featurizing