Skip to content

Commit

Permalink
Update VELOVI::train (#1219)
Browse files Browse the repository at this point in the history
Remove argument `use_gpu` and rely on
`accelerator` and `devices`, instead.
  • Loading branch information
WeilerP authored Mar 17, 2024
1 parent 0a5e5c8 commit 8997c03
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions scvelo/tools/_vi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def train(
max_epochs: Optional[int] = 500,
lr: float = 1e-2,
weight_decay: float = 1e-2,
use_gpu: Optional[Union[str, int, bool]] = None,
accelerator: str = "auto",
devices: Union[int, List[int], str] = "auto",
train_size: float = 0.9,
Expand All @@ -149,9 +148,14 @@ def train(
Learning rate for optimization
weight_decay
Weight decay for optimization
use_gpu
Use default GPU if available (if None or True), or index of GPU to use (if int),
or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
accelerator
Supports passing different accelerator types `("cpu", "gpu", "tpu", "ipu", "hpu", "mps, "auto")` as well as
custom accelerator instances.
devices
The devices to use. Can be set to a non-negative index (`int` or `str`), a sequence of device indices
(`list` or comma-separated `str`), the value `-1` to indicate all available devices, or `"auto"` for
automatic selection based on the chosen `accelerator`. If set to `"auto"` and `accelerator` is not
determined to be `"cpu"`, then `devices` will be set to the first available device.
train_size
Size of training set in the range [0.0, 1.0].
validation_size
Expand Down Expand Up @@ -195,7 +199,8 @@ def train(
training_plan=training_plan,
data_splitter=data_splitter,
max_epochs=max_epochs,
use_gpu=use_gpu,
accelerator=accelerator,
devices=devices,
**trainer_kwargs,
)
return runner()
Expand Down

0 comments on commit 8997c03

Please sign in to comment.