Skip to content

Commit

Permalink
fix grads type check in iterative gradient computers
Browse files Browse the repository at this point in the history
Co-authored-by: TheaperDeng <junweid2.illinois.edu>
  • Loading branch information
kristian-georgiev committed Jan 17, 2024
1 parent e5cbe1e commit ead7aa4
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 19 deletions.
26 changes: 25 additions & 1 deletion tests/test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,30 @@ def test_custom_model_output(tmp_path, cpu_proj):
)


def test_iterative_gradient_computer(tmp_path, cpu_proj):
from trak.gradient_computers import IterativeGradientComputer
from trak.projectors import NoOpProjector

model = resnet18()
N = 5
batch = ch.randn(N, 3, 32, 32), ch.randint(low=0, high=10, size=(N,))
traker = TRAKer(
model=model,
task="iterative_image_classification",
save_dir=tmp_path,
train_set_size=N,
logging_level=logging.DEBUG,
device="cpu",
use_half_precision=False,
projector=NoOpProjector(),
proj_dim=0,
gradient_computer=IterativeGradientComputer,
)
ckpt = model.state_dict()
traker.load_checkpoint(ckpt, model_id=0)
traker.featurize(batch, num_samples=N)


def test_grad_wrt_last_layer(tmp_path):
model = resnet18().eval()
N = 5
Expand Down Expand Up @@ -402,7 +426,7 @@ def test_grad_wrt_last_layer(tmp_path):
def test_grad_wrt_last_layer_cuda(tmp_path):
model = resnet18().cuda().eval()
N = 5
batch = ch.randn(N, 3, 32, 32).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda()
batch = ch.randn(N, 3, 4, 4).cuda(), ch.randint(low=0, high=10, size=(N,)).cuda()
traker = TRAKer(
model=model,
task="image_classification",
Expand Down
17 changes: 12 additions & 5 deletions trak/gradient_computers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor:
batch of data
Returns:
Tensor:
gradients of the model output function of each sample in the
batch with respect to the model's parameters.
dict[Tensor]:
A dictionary where each key is a parameter name and the value is
the gradient tensor for that parameter.
"""
# taking the gradient wrt weights (second argument of get_output, hence argnums=1)
Expand Down Expand Up @@ -183,6 +183,9 @@ def compute_loss_grad(self, batch: Iterable[Tensor]) -> Tensor:
batch (Iterable[Tensor]):
batch of data
Returns:
Tensor:
The gradient of the loss with respect to the model output.
"""
return self.modelout_fn.get_out_to_loss_grad(
self.model, self.func_weights, self.func_buffers, batch
Expand Down Expand Up @@ -229,7 +232,7 @@ def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor:
batch_size = batch[0].shape[0]
grads = ch.zeros(batch_size, self.grad_dim).to(batch[0].device)

margin = self.modelout_fn.get_output(self.model, *batch)
margin = self.modelout_fn.get_output(self.model, None, None, *batch)
for ind in range(batch_size):
grads[ind] = parameters_to_vector(
ch.autograd.grad(margin[ind], self.model_params, retain_graph=True)
Expand All @@ -254,5 +257,9 @@ def compute_loss_grad(self, batch: Iterable[Tensor]) -> Tensor:
Args:
batch (Iterable[Tensor]):
batch of data
Returns:
Tensor:
The gradient of the loss with respect to the model output.
"""
return self.modelout_fn.get_out_to_loss_grad(self.model, batch)
return self.modelout_fn.get_out_to_loss_grad(self.model, None, None, batch)
98 changes: 98 additions & 0 deletions trak/modelout_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- :class:`.ImageClassificationModelOutput`
- :class:`.CLIPModelOutput`
- :class:`.TextClassificationModelOutput`
- :class:`.IterativeImageClassificationModelOutput`
These classes implement methods that transform input batches to the desired
model output (e.g. logits, loss, etc). See Sections 2 & 3 of `our paper
Expand Down Expand Up @@ -444,8 +445,105 @@ def get_out_to_loss_grad(
return (1 - ps).clone().detach().unsqueeze(-1)


class IterativeImageClassificationModelOutput(AbstractModelOutput):
"""Margin for (multiclass) image classification. See Section 3.3 of `our
paper <https://arxiv.org/abs/2303.14186>`_ for more details.
"""

def __init__(self, temperature: float = 1.0) -> None:
"""
Args:
temperature (float, optional): Temperature to use inside the
softmax for the out-to-loss function. Defaults to 1.
"""
super().__init__()
self.softmax = ch.nn.Softmax(-1)
self.loss_temperature = temperature

@staticmethod
def get_output(
model: Module,
weights: Iterable[Tensor],
buffers: Iterable[Tensor],
images: Tensor,
labels: Tensor,
) -> Tensor:
"""For a given input :math:`z=(x, y)` and model parameters :math:`\\theta`,
let :math:`p(z, \\theta)` be the softmax probability of the correct class.
This method implements the model output function
.. math::
\\log(\\frac{p(z, \\theta)}{1 - p(z, \\theta)}).
It uses functional models from torch.func (previously functorch) to make
the per-sample gradient computations (much) faster. For more details on
what functional models are, and how to use them, please refer to
https://pytorch.org/docs/stable/func.html and
https://pytorch.org/functorch/stable/notebooks/per_sample_grads.html.
Args:
model (torch.nn.Module):
torch model
weights (Iterable[Tensor]):
functorch model weights (added se we don't break abstraction)
buffers (Iterable[Tensor]):
functorch model buffers (added se we don't break abstraction)
images (Tensor):
input images
labels (Tensor):
input labels
Returns:
Tensor:
model output for the given image-label pair :math:`z`
"""
logits = model(images)
bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
logits_correct = logits[bindex, labels]

cloned_logits = logits.clone()
# remove the logits of the correct labels from the sum
# in logsumexp by setting to -ch.inf
cloned_logits[bindex, labels] = ch.tensor(
-ch.inf, device=logits.device, dtype=logits.dtype
)

margins = logits_correct - cloned_logits.logsumexp(dim=-1)
return margins

def get_out_to_loss_grad(
self, model, weights, buffers, batch: Iterable[Tensor]
) -> Tensor:
"""Computes the (reweighting term Q in the paper)
Args:
model (torch.nn.Module):
torch model
weights (Iterable[Tensor]):
functorch model weights
buffers (Iterable[Tensor]):
functorch model buffers
batch (Iterable[Tensor]):
input batch
Returns:
Tensor:
out-to-loss (reweighting term) for the input batch
"""
images, labels = batch
logits = model(images)
# here we are directly implementing the gradient instead of relying on autodiff to do
# that for us
ps = self.softmax(logits / self.loss_temperature)[
ch.arange(logits.size(0)), labels
]
return (1 - ps).clone().detach().unsqueeze(-1)


TASK_TO_MODELOUT = {
"image_classification": ImageClassificationModelOutput,
"clip": CLIPModelOutput,
"text_classification": TextClassificationModelOutput,
"iterative_image_classification": IterativeImageClassificationModelOutput,
}
19 changes: 11 additions & 8 deletions trak/projectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def project(self, grads: Tensor, model_id: int) -> Tensor:
Returns:
Tensor: the (non-)projected gradients
"""
return vectorize(grads, device=self.device)
if isinstance(grads, dict):
grads = vectorize(grads, device=self.device)
return grads

def free_memory(self):
"""A no-op method."""
Expand Down Expand Up @@ -190,7 +192,9 @@ def generate_sketch_matrix(self):
raise KeyError(f"Projection type {self.proj_type} not recognized.")

def project(self, grads: Tensor, model_id: int) -> Tensor:
grads = vectorize(grads, device=self.device)
if isinstance(grads, dict):
grads = vectorize(grads, device=self.device)

grads = grads.to(dtype=self.dtype)
if model_id != self.model_id:
self.model_id = model_id
Expand Down Expand Up @@ -254,7 +258,7 @@ def free_memory(self):
def get_generator_states(self):
self.generator_states = []
self.seeds = []
self.jl_size = self.grad_dim * self.block_size
self.jl_size = self.grad_dim * self.block_size

for i in range(self.num_blocks):
s = self.seed + int(1e3) * i + int(1e5) * self.model_id
Expand Down Expand Up @@ -283,7 +287,8 @@ def generate_sketch_matrix(self, generator_state):
raise KeyError(f"Projection type {self.proj_type} not recognized.")

def project(self, grads: Tensor, model_id: int) -> Tensor:
grads = vectorize(grads, device=self.device)
if isinstance(grads, dict):
grads = vectorize(grads, device=self.device)
grads = grads.to(dtype=self.dtype)
sketch = ch.zeros(
size=(grads.size(0), self.proj_dim), dtype=self.dtype, device=self.device
Expand Down Expand Up @@ -380,10 +385,10 @@ def project(
self,
grads: Union[dict, Tensor],
model_id: int,
is_grads_dict: bool = True,
) -> Tensor:
if is_grads_dict:
if isinstance(grads, dict):
grads = vectorize(grads, device=self.device)

batch_size = grads.shape[0]

effective_batch_size = 32
Expand Down Expand Up @@ -486,7 +491,6 @@ def project(self, grads, model_id):
self.projector_per_chunk[projector_index].project(
self.ch_input[:, :pointer].contiguous(),
model_id=model_id,
is_grads_dict=False,
)
)
# reset counter
Expand All @@ -506,7 +510,6 @@ def project(self, grads, model_id):
self.projector_per_chunk[projector_index].project(
self.ch_input[:actual_bs, :pointer].contiguous(),
model_id=model_id,
is_grads_dict=False,
)
)

Expand Down
8 changes: 3 additions & 5 deletions trak/traker.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,9 @@ def __init__(
model to use for TRAK
task (Union[AbstractModelOutput, str]):
Type of model that TRAK will be ran on. Accepts either one of
the following strings:
- :code:`image_classification`
- :code:`text_classification`
- :code:`clip`
or an instance of some implementation of the abstract class
the following strings: 1) :code:`image_classification` 2)
:code:`text_classification` 3) :code:`clip` or an instance of
some implementation of the abstract class
:class:`.AbstractModelOutput`.
train_set_size (int):
Size of the train set that TRAK is featurizing
Expand Down

0 comments on commit ead7aa4

Please sign in to comment.