Skip to content

Commit

Permalink
Merge pull request #190 from boutproject/reloading-geometry
Browse files Browse the repository at this point in the history
Allow geometry to be changed when reloading an xBOUT-saved Dataset
  • Loading branch information
johnomotani authored Apr 8, 2021
2 parents fa3ab8b + b2cc9ef commit 98fbb5d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
21 changes: 20 additions & 1 deletion xbout/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,26 @@ def attrs_remove_section(obj, section):
ds = _add_options(ds, inputfilepath)

# If geometry was set, apply geometry again
if "geometry" in ds.attrs:
if geometry is not None:
if "geometry" != ds.attrs.get("geometry", None):
warn(
f'open_boutdataset() called with geometry="{geometry}", but we are '
f"reloading a Dataset that was saved after being loaded with "
f'geometry="{ds.attrs.get("geometry", None)}". Applying '
f'geometry="{geometry}" from the argument.'
)
if gridfilepath is not None:
grid = _open_grid(
gridfilepath,
chunks=chunks,
keep_xboundaries=keep_xboundaries,
keep_yboundaries=keep_yboundaries,
mxg=ds.metadata["MXG"],
)
else:
grid = None
ds = geometries.apply_geometry(ds, geometry, grid=grid)
elif "geometry" in ds.attrs:
ds = geometries.apply_geometry(ds, ds.attrs["geometry"])
else:
ds = geometries.apply_geometry(ds, None)
Expand Down
35 changes: 23 additions & 12 deletions xbout/tests/test_boutdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,20 +1378,12 @@ def test_save_all(self, tmpdir_factory, bout_xyt_example_files):

@pytest.mark.parametrize("geometry", [None, "toroidal"])
def test_reload_all(self, tmpdir_factory, bout_xyt_example_files, geometry):
if geometry is not None:
grid = "grid"
else:
grid = None

# Create data
path = bout_xyt_example_files(
tmpdir_factory, nxpe=4, nype=5, nt=1, grid=grid, write_to_disk=True
tmpdir_factory, nxpe=4, nype=5, nt=1, grid="grid", write_to_disk=True
)

if grid is not None:
gridpath = str(Path(path).parent) + "/grid.nc"
else:
gridpath = None
gridpath = str(Path(path).parent) + "/grid.nc"

# Load it as a boutdataset
if geometry is None:
Expand All @@ -1400,14 +1392,14 @@ def test_reload_all(self, tmpdir_factory, bout_xyt_example_files, geometry):
datapath=path,
inputfilepath=None,
geometry=geometry,
gridfilepath=gridpath,
gridfilepath=None if geometry is None else gridpath,
)
else:
original = open_boutdataset(
datapath=path,
inputfilepath=None,
geometry=geometry,
gridfilepath=gridpath,
gridfilepath=None if geometry is None else gridpath,
)

# Save it to a netCDF file
Expand All @@ -1419,6 +1411,25 @@ def test_reload_all(self, tmpdir_factory, bout_xyt_example_files, geometry):

xrt.assert_identical(original.load(), recovered.load())

# Check if we can load with a different geometry argument
for reload_geometry in [None, "toroidal"]:
if reload_geometry is None or geometry == reload_geometry:
recovered = open_boutdataset(
savepath,
geometry=reload_geometry,
gridfilepath=None if reload_geometry is None else gridpath,
)
xrt.assert_identical(original.load(), recovered.load())
else:
# Expect a warning because we change the geometry
print("here", gridpath)
with pytest.warns(UserWarning):
recovered = open_boutdataset(
savepath, geometry=reload_geometry, gridfilepath=gridpath
)
# Datasets won't be exactly the same because different geometry was
# applied

@pytest.mark.parametrize("save_dtype", [np.float64, np.float32])
@pytest.mark.parametrize(
"separate_vars", [False, pytest.param(True, marks=pytest.mark.long)]
Expand Down

0 comments on commit 98fbb5d

Please sign in to comment.