Skip to content

Commit

Permalink
add function to check for existence of gzipped file during test
Browse files Browse the repository at this point in the history
  • Loading branch information
esoteric-ephemera committed May 22, 2024
1 parent 6aaecf4 commit 101939c
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 7 deletions.
23 changes: 23 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os

import pytest
from shutil import copyfileobj

TEST_DIR = os.path.dirname(__file__)
TEST_FILES = f"{TEST_DIR}/files"
Expand All @@ -13,3 +14,25 @@
def _patch_get_potential_energy(monkeypatch) -> None:
"""Monkeypatch the multiprocessing.cpu_count() function to always return 64."""
monkeypatch.setattr(multiprocessing, "cpu_count", lambda: 64)

def get_gzip_or_unzipped(test_file_or_dir : str) -> str:
"""
Return the file or its unzipped version, depending on which one exists.
Running pytest in CI seems to unzip the test files prior to testing.
To get around this behavior, we return the whichever file exists,
its gzipped version or the unzipped version.
Args:
test_file_or_dir (str) : the name of the test file or directory
Returns:
The file with or without a .gz/.GZ extension if any exist, or the
unmodified path to the directory.
"""
if os.path.isdir(test_file_or_dir):
return test_file_or_dir

for file_to_test in [test_file_or_dir, test_file_or_dir.split(".gz")[0], test_file_or_dir.split(".GZ")[0]]:
if os.path.isfile(file_to_test):
return file_to_test
raise FileNotFoundError(f"Cannot find {test_file_or_dir}")
4 changes: 2 additions & 2 deletions tests/vasp/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
VaspErrorHandler,
WalltimeHandler,
)
from tests.conftest import TEST_FILES
from tests.conftest import TEST_FILES, get_gzip_or_unzipped

__author__ = "Shyue Ping Ong, Stephen Dacek, Janosh Riebesell"
__copyright__ = "Copyright 2012, The Materials Project"
Expand All @@ -51,7 +51,7 @@ def _clear_tracked_cache() -> None:

def copy_tmp_files(tmp_path: str, *file_paths: str) -> None:
for file_path in file_paths:
src_path = f"{TEST_FILES}/{file_path}"
src_path = get_gzip_or_unzipped(f"{TEST_FILES}/{file_path}")
dst_path = f"{tmp_path}/{os.path.basename(file_path)}"
if os.path.isdir(src_path):
shutil.copytree(src_path, dst_path)
Expand Down
12 changes: 7 additions & 5 deletions tests/vasp/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from custodian.utils import tracked_lru_cache
from custodian.vasp.io import load_outcar, load_vasprun
from tests.conftest import TEST_FILES
from tests.conftest import TEST_FILES, get_gzip_or_unzipped


@pytest.fixture(autouse=True)
Expand All @@ -15,18 +15,20 @@ def _clear_tracked_cache() -> None:

class TestIO:
def test_load_outcar(self) -> None:
outcar = load_outcar(f"{TEST_FILES}/io/OUTCAR.gz")
outcar_file = get_gzip_or_unzipped(f"{TEST_FILES}/io/OUTCAR.gz")
outcar = load_outcar(outcar_file)
assert outcar is not None
outcar2 = load_outcar(f"{TEST_FILES}/io/OUTCAR.gz")
outcar2 = load_outcar(outcar_file)

assert outcar is outcar2

assert len(tracked_lru_cache.cached_functions) == 1

def test_load_vasprun(self) -> None:
vr = load_vasprun(f"{TEST_FILES}/io/vasprun.xml.gz")
vasprun_file = get_gzip_or_unzipped(f"{TEST_FILES}/io/vasprun.xml.gz")
vr = load_vasprun(vasprun_file)
assert vr is not None
vr2 = load_vasprun(f"{TEST_FILES}/io/vasprun.xml.gz")
vr2 = load_vasprun(vasprun_file)

assert vr is vr2

Expand Down

0 comments on commit 101939c

Please sign in to comment.