MultiFab Fixture Cleanup via FabArray::clear
Using a context manager and calling clear ensures that we no longer
hold device memory once we reach `AMReX::Finalize`, even when an
exception is raised in a test. This avoids segfaults for failing tests.
ax3l committed Oct 21, 2023
1 parent 198cbb3 commit 7d53274
Showing 3 changed files with 48 additions and 40 deletions.
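For orientation, the cleanup guarantee reads like a yield fixture wrapped in try/finally. A minimal sketch, equivalent in effect to the context managers added in tests/conftest.py below (sketch only, not the commit's code; the fixture name is hypothetical):

import pytest
import amrex.space3d as amr


@pytest.fixture(scope="function")
def mfab_sketch(boxarr, distmap):
    mf = amr.MultiFab(boxarr, distmap, 1, 0)  # 1 component, 0 ghost cells
    mf.set_val(0.0, 0, 1)
    try:
        yield mf  # hand the MultiFab to the test
    finally:
        mf.clear()  # runs even if the test raised: no device memory
        del mf      # survives until AMReX::Finalize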
4 changes: 4 additions & 0 deletions src/Base/MultiFab.cpp
@@ -127,6 +127,10 @@ void init_MultiFab(py::module &m)
         ;
 
     py_FabArray_FArrayBox
+        // define
+        .def("clear", &FabArray<FArrayBox>::clear)
+        .def("ok", &FabArray<FArrayBox>::ok)
+
         //.def("array", py::overload_cast< const MFIter& >(&FabArray<FArrayBox>::array))
         //.def("const_array", &FabArray<FArrayBox>::const_array)
         .def("array", [](FabArray<FArrayBox> & fa, MFIter const & mfi)
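From Python, the new bindings allow explicit, early release of field data; a minimal sketch (that ok() reports False after clear() is an assumption based on the C++ FabArray semantics):

mf = amr.MultiFab(boxarr, distmap, 1, 0)
assert mf.ok()      # defined and allocated
mf.clear()          # release the underlying FArrayBox data now
assert not mf.ok()  # assumption: a cleared FabArray no longer reports ok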
58 changes: 35 additions & 23 deletions tests/conftest.py
@@ -85,47 +85,59 @@ def distmap(boxarr):
 
 
 @pytest.fixture(scope="function", params=list(itertools.product([1, 3], [0, 1])))
-def make_mfab(boxarr, distmap, request):
+def mfab(boxarr, distmap, request):
     """MultiFab that is either managed or device:
     The MultiFab object itself is not a fixture because we want to avoid caching
     it between amr.initialize/finalize calls of various tests.
     https://github.com/pytest-dev/pytest/discussions/10387
     https://github.com/pytest-dev/pytest/issues/5642#issuecomment-1279612764
     """
 
-    def create():
-        num_components = request.param[0]
-        num_ghost = request.param[1]
-        mfab = amr.MultiFab(boxarr, distmap, num_components, num_ghost)
-        mfab.set_val(0.0, 0, num_components)
-        return mfab
+    class MfabContextManager:
+        def __enter__(self):
+            num_components = request.param[0]
+            num_ghost = request.param[1]
+            self.mfab = amr.MultiFab(boxarr, distmap, num_components, num_ghost)
+            self.mfab.set_val(0.0, 0, num_components)
+            return self.mfab
 
-    return create
+        def __exit__(self, exc_type, exc_value, traceback):
+            self.mfab.clear()
+            del self.mfab
+
+    with MfabContextManager() as mfab:
+        yield mfab
 
 
 @pytest.mark.skipif(
     amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
 )
 @pytest.fixture(scope="function", params=list(itertools.product([1, 3], [0, 1])))
-def make_mfab_device(boxarr, distmap, request):
+def mfab_device(boxarr, distmap, request):
     """MultiFab that resides purely on the device:
     The MultiFab object itself is not a fixture because we want to avoid caching
     it between amr.initialize/finalize calls of various tests.
     https://github.com/pytest-dev/pytest/discussions/10387
     https://github.com/pytest-dev/pytest/issues/5642#issuecomment-1279612764
     """
 
-    def create():
-        num_components = request.param[0]
-        num_ghost = request.param[1]
-        mfab = amr.MultiFab(
-            boxarr,
-            distmap,
-            num_components,
-            num_ghost,
-            amr.MFInfo().set_arena(amr.The_Device_Arena()),
-        )
-        mfab.set_val(0.0, 0, num_components)
-        return mfab
-
-    return create
+    class MfabDeviceContextManager:
+        def __enter__(self):
+            num_components = request.param[0]
+            num_ghost = request.param[1]
+            self.mfab = amr.MultiFab(
+                boxarr,
+                distmap,
+                num_components,
+                num_ghost,
+                amr.MFInfo().set_arena(amr.The_Device_Arena()),
+            )
+            self.mfab.set_val(0.0, 0, num_components)
+            return self.mfab
 
+        def __exit__(self, exc_type, exc_value, traceback):
+            self.mfab.clear()
+            del self.mfab
+
+    with MfabDeviceContextManager() as mfab:
+        yield mfab
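With the fixtures yielding the MultiFab directly, tests accept mfab or mfab_device as a parameter instead of calling a factory. A minimal usage sketch (hypothetical test name; set_val and max are used with these signatures in the diffs below):

def test_mfab_values(mfab):
    mfab.set_val(1.0, 0, 1)    # fill component 0
    assert mfab.max(0) == 1.0  # reduce over component 0
    # teardown: the fixture's __exit__ calls clear() even if an assert fails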
26 changes: 9 additions & 17 deletions tests/test_multifab.py
@@ -8,8 +8,7 @@
 import amrex.space3d as amr
 
 
-def test_mfab_loop(make_mfab):
-    mfab = make_mfab()
+def test_mfab_loop(mfab):
     ngv = mfab.nGrowVect
     print(f"\n mfab={mfab}, mfab.nGrowVect={ngv}")
 
@@ -78,8 +77,7 @@ def test_mfab_loop(make_mfab):
     # TODO
 
 
-def test_mfab_simple(make_mfab):
-    mfab = make_mfab()
+def test_mfab_simple(mfab):
     assert mfab.is_all_cell_centered
     # assert(all(not mfab.is_nodal(i) for i in [-1, 0, 1, 2])) # -1??
     assert all(not mfab.is_nodal(i) for i in [0, 1, 2])
@@ -144,8 +142,7 @@ def test_mfab_ops(boxarr, distmap, nghost):
     np.testing.assert_allclose(dst.max(0), 150.0)
 
 
-def test_mfab_mfiter(make_mfab):
-    mfab = make_mfab()
+def test_mfab_mfiter(mfab):
     assert iter(mfab).is_valid
     assert iter(mfab).length == 8
 
@@ -159,8 +156,7 @@ def test_mfab_mfiter(make_mfab):
 @pytest.mark.skipif(
     amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
 )
-def test_mfab_ops_cuda_numba(make_mfab_device):
-    mfab_device = make_mfab_device()
+def test_mfab_ops_cuda_numba(mfab_device):
     # https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html
     from numba import cuda
 
@@ -195,8 +191,7 @@ def set_to_three(array):
 @pytest.mark.skipif(
     amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
 )
-def test_mfab_ops_cuda_cupy(make_mfab_device):
-    mfab_device = make_mfab_device()
+def test_mfab_ops_cuda_cupy(mfab_device):
     # https://docs.cupy.dev/en/stable/user_guide/interoperability.html
     import cupy as cp
     import cupyx.profiler
@@ -285,8 +280,7 @@ def set_to_seven(x):
 @pytest.mark.skipif(
     amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
 )
-def test_mfab_ops_cuda_pytorch(make_mfab_device):
-    mfab_device = make_mfab_device()
+def test_mfab_ops_cuda_pytorch(mfab_device):
     # https://docs.cupy.dev/en/stable/user_guide/interoperability.html#pytorch
     import torch
 
@@ -305,8 +299,8 @@ def test_mfab_ops_cuda_pytorch(make_mfab_device):
 @pytest.mark.skipif(
     amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
 )
-def test_mfab_ops_cuda_cuml(make_mfab_device):
-    mfab_device = make_mfab_device() # noqa
+def test_mfab_ops_cuda_cuml(mfab_device):
+    pass
     # https://github.com/rapidsai/cuml
     # https://github.com/rapidsai/cudf
     # maybe better for particles as a dataframe test
@@ -322,9 +316,7 @@ def test_mfab_ops_cuda_cuml(make_mfab_device):
 @pytest.mark.skipif(
     amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
 )
-def test_mfab_dtoh_copy(make_mfab_device):
-    mfab_device = make_mfab_device()
-
+def test_mfab_dtoh_copy(mfab_device):
     mfab_host = amr.MultiFab(
         mfab_device.box_array(),
         mfab_device.dm(),
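The CUDA-interop tests above hand each box's device data to kernels through the CUDA Array Interface; the test bodies themselves are collapsed in this view. A self-contained sketch of such a kernel with a hypothetical name, assuming a 4D (x, y, z, comp) device view of the MultiFab data:

from numba import cuda


@cuda.jit
def fill_value(array, value):
    # hypothetical kernel: one thread per (i, j, k) cell, loop over components
    i, j, k = cuda.grid(3)
    if i < array.shape[0] and j < array.shape[1] and k < array.shape[2]:
        for n in range(array.shape[3]):
            array[i, j, k, n] = value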
