From 61d4f079587215244c31a43968e4caf4aa28beab Mon Sep 17 00:00:00 2001 From: Kristian Georgiev Date: Thu, 2 Nov 2023 15:24:57 -0400 Subject: [PATCH] make type hints compatible with python 3.8 --- trak/gradient_computers.py | 2 +- trak/traker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/trak/gradient_computers.py b/trak/gradient_computers.py index 0ce234e..b000e8b 100644 --- a/trak/gradient_computers.py +++ b/trak/gradient_computers.py @@ -93,7 +93,7 @@ def __init__( grad_dim: int, dtype: torch.dtype, device: torch.device, - grad_wrt: Optional[list[str]] = None, + grad_wrt: Optional[Iterable[str]] = None, ) -> None: """Initializes attributes, and loads model parameters. diff --git a/trak/traker.py b/trak/traker.py index 5a4cbbd..58f98f1 100644 --- a/trak/traker.py +++ b/trak/traker.py @@ -122,7 +122,7 @@ def __init__( for V100 GPUs, 40 for H100 GPUs. Defaults to 32. projector_seed (int): Random seed used by the projector. Defaults to 0. - grad_wrt (Optional[List[str]], optional): + grad_wrt (Optional[Iterable[str]], optional): If not None, the gradients will be computed only with respect to the parameters specified in this list. The list should contain the names of the parameters to compute gradients with respect to,