diff --git a/trak/gradient_computers.py b/trak/gradient_computers.py index 0ce234e..b000e8b 100644 --- a/trak/gradient_computers.py +++ b/trak/gradient_computers.py @@ -93,7 +93,7 @@ def __init__( grad_dim: int, dtype: torch.dtype, device: torch.device, - grad_wrt: Optional[list[str]] = None, + grad_wrt: Optional[Iterable[str]] = None, ) -> None: """Initializes attributes, and loads model parameters. diff --git a/trak/traker.py b/trak/traker.py index 5a4cbbd..58f98f1 100644 --- a/trak/traker.py +++ b/trak/traker.py @@ -122,7 +122,7 @@ def __init__( for V100 GPUs, 40 for H100 GPUs. Defaults to 32. projector_seed (int): Random seed used by the projector. Defaults to 0. - grad_wrt (Optional[List[str]], optional): + grad_wrt (Optional[Iterable[str]], optional): If not None, the gradients will be computed only with respect to the parameters specified in this list. The list should contain the names of the parameters to compute gradients with respect to,