Skip to content

Commit

Permalink
test_mfab_dtoh_copy: Clear MFabs
Browse files Browse the repository at this point in the history
Clear out memory safely on runtime errors.
  • Loading branch information
ax3l committed Oct 21, 2023
1 parent 7d53274 commit 15a91d2
Showing 1 changed file with 49 additions and 39 deletions.
88 changes: 49 additions & 39 deletions tests/test_multifab.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,44 +317,54 @@ def test_mfab_ops_cuda_cuml(mfab_device):
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
)
def test_mfab_dtoh_copy(mfab_device):
mfab_host = amr.MultiFab(
mfab_device.box_array(),
mfab_device.dm(),
mfab_device.n_comp(),
mfab_device.n_grow_vect(),
amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
)
mfab_host.set_val(42.0)

amr.dtoh_memcpy(mfab_host, mfab_device)

# assert all are 0.0 on host
host_min = mfab_host.min(0)
host_max = mfab_host.max(0)
assert host_min == host_max
assert host_max == 0.0

dev_val = 11.0
mfab_host.set_val(dev_val)
amr.dtoh_memcpy(mfab_device, mfab_host)

# assert all are 11.0 on device
for n in range(mfab_device.n_comp()):
assert mfab_device.min(comp=n) == dev_val
assert mfab_device.max(comp=n) == dev_val

# numpy bindings (w/ copy)
local_boxes_host = mfab_device.to_numpy(copy=True)
assert max([np.max(box) for box in local_boxes_host]) == dev_val

# numpy bindings (w/ copy)
for mfi in mfab_device:
marr = mfab_device.array(mfi).to_numpy(copy=True)
assert np.min(marr) >= dev_val
assert np.max(marr) <= dev_val
class MfabPinnedContextManager:
def __enter__(self):
self.mfab = amr.MultiFab(
mfab_device.box_array(),
mfab_device.dm(),
mfab_device.n_comp(),
mfab_device.n_grow_vect(),
amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
)
return self.mfab

def __exit__(self, exc_type, exc_value, traceback):
self.mfab.clear()
del self.mfab

with MfabPinnedContextManager() as mfab_host:
mfab_host.set_val(42.0)

amr.dtoh_memcpy(mfab_host, mfab_device)

# assert all are 0.0 on host
host_min = mfab_host.min(0)
host_max = mfab_host.max(0)
assert host_min == host_max
assert host_max == 0.0

dev_val = 11.0
mfab_host.set_val(dev_val)
amr.htod_memcpy(mfab_device, mfab_host)

# assert all are 11.0 on device
for n in range(mfab_device.n_comp()):
assert mfab_device.min(comp=n) == dev_val
assert mfab_device.max(comp=n) == dev_val

# numpy bindings (w/ copy)
local_boxes_host = mfab_device.to_numpy(copy=True)
assert max([np.max(box) for box in local_boxes_host]) == dev_val
del local_boxes_host

# numpy bindings (w/ copy)
for mfi in mfab_device:
marr = mfab_device.array(mfi).to_numpy(copy=True)
assert np.min(marr) >= dev_val
assert np.max(marr) <= dev_val

# cupy bindings (w/o copy)
import cupy as cp
# cupy bindings (w/o copy)
import cupy as cp

local_boxes_device = mfab_device.to_cupy()
assert max([cp.max(box) for box in local_boxes_device]) == dev_val
local_boxes_device = mfab_device.to_cupy()
assert max([cp.max(box) for box in local_boxes_device]) == dev_val

0 comments on commit 15a91d2

Please sign in to comment.