Skip to content

Commit

Permalink
Mock Fabric os.environ variables to be sure to have a SingleDeviceStr…
Browse files Browse the repository at this point in the history
…ategy Fabric object (#250)
  • Loading branch information
belerico authored Apr 2, 2024
1 parent 9f557c6 commit 875166a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
11 changes: 10 additions & 1 deletion sheeprl/utils/fabric.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest import mock

from lightning.fabric import Fabric
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.strategies import SingleDeviceStrategy, SingleDeviceXLAStrategy
Expand All @@ -23,4 +25,11 @@ def get_single_device_fabric(fabric: Fabric) -> Fabric:
checkpoint_io=None,
precision=fabric._precision,
)
return Fabric(strategy=strategy)
with mock.patch.dict("os.environ") as mocked_os_environ:
mocked_os_environ.pop("LT_DEVICES", None)
mocked_os_environ.pop("LT_STRATEGY", None)
mocked_os_environ.pop("LT_NUM_NODES", None)
mocked_os_environ.pop("LT_PRECISION", None)
mocked_os_environ.pop("LT_ACCELERATOR", None)
fabric = Fabric(strategy=strategy)
return fabric
13 changes: 13 additions & 0 deletions tests/test_utils/test_fabric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from lightning import Fabric
from lightning.fabric.strategies import SingleDeviceStrategy

from sheeprl.utils.fabric import get_single_device_fabric


def test_get_single_device_fabric():
fabric = Fabric(devices=2, accelerator="cpu", precision=16)
single_device_fabric = get_single_device_fabric(fabric)
assert single_device_fabric.device == fabric.device
assert single_device_fabric._precision == fabric._precision
assert single_device_fabric.accelerator == fabric.accelerator
assert isinstance(single_device_fabric.strategy, SingleDeviceStrategy)

0 comments on commit 875166a

Please sign in to comment.