Skip to content

Commit

Permalink
Converted utils.autoconvert and utils.get_subplot_dimension into inte…
Browse files Browse the repository at this point in the history
…rnal functions
  • Loading branch information
wehs7661 committed Aug 29, 2023
1 parent 27a8e24 commit 006b608
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions ensemble_md/analysis/analyze_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def plot_state_trajs(trajs, state_ranges, fig_name, dt=None, stride=1, title_pre
x_range = [np.min(x), np.max(x)]
y_range = [-0.2, np.max(trajs) + 0.2]
n_configs = len(trajs)
n_rows, n_cols = utils.get_subplot_dimension(n_configs)
n_rows, n_cols = utils._get_subplot_dimension(n_configs)
_, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(2.5 * n_cols, 2.5 * n_rows))
for i in range(n_configs):
plt.subplot(n_rows, n_cols, i + 1)
Expand Down Expand Up @@ -537,7 +537,7 @@ def plot_state_hist(trajs, state_ranges, fig_name, stack=True, figsize=None, pre
plt.tight_layout()
plt.savefig(f'{fig_name}', dpi=600)
else:
n_rows, n_cols = utils.get_subplot_dimension(n_configs)
n_rows, n_cols = utils._get_subplot_dimension(n_configs)
_, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(4 * n_cols, 3 * n_rows))
for i in range(n_configs):
plt.subplot(n_rows, n_cols, i + 1)
Expand Down
28 changes: 14 additions & 14 deletions ensemble_md/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,31 +55,31 @@ def test_format_time():

def test_autoconvert():
# Test non-string input
assert utils.autoconvert(42) == 42
assert utils._autoconvert(42) == 42

# Test string input that can be converted to int
assert utils.autoconvert("42") == 42
assert utils._autoconvert("42") == 42

# Test string input that can be converted to float
assert utils.autoconvert("3.14159") == 3.14159
assert utils._autoconvert("3.14159") == 3.14159

# Test string input that can be converted to a numpy array of ints
assert utils.autoconvert("1 2 3") == [1, 2, 3]
assert utils._autoconvert("1 2 3") == [1, 2, 3]

# Test string input that can be converted to a numpy array of floats
assert utils.autoconvert("1.0 2.0 3.0") == [1.0, 2.0, 3.0]
assert utils._autoconvert("1.0 2.0 3.0") == [1.0, 2.0, 3.0]


def test_get_subplot_dimension():
assert utils.get_subplot_dimension(1) == (1, 1)
assert utils.get_subplot_dimension(2) == (1, 2)
assert utils.get_subplot_dimension(3) == (2, 2)
assert utils.get_subplot_dimension(4) == (2, 2)
assert utils.get_subplot_dimension(5) == (2, 3)
assert utils.get_subplot_dimension(6) == (2, 3)
assert utils.get_subplot_dimension(7) == (3, 3)
assert utils.get_subplot_dimension(8) == (3, 3)
assert utils.get_subplot_dimension(9) == (3, 3)
assert utils._get_subplot_dimension(1) == (1, 1)
assert utils._get_subplot_dimension(2) == (1, 2)
assert utils._get_subplot_dimension(3) == (2, 2)
assert utils._get_subplot_dimension(4) == (2, 2)
assert utils._get_subplot_dimension(5) == (2, 3)
assert utils._get_subplot_dimension(6) == (2, 3)
assert utils._get_subplot_dimension(7) == (3, 3)
assert utils._get_subplot_dimension(8) == (3, 3)
assert utils._get_subplot_dimension(9) == (3, 3)


def test_weighted_mean():
Expand Down
2 changes: 1 addition & 1 deletion ensemble_md/utils/gmx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def __eq__(self, other):

def _transform(self, value):
if self.autoconvert:
return utils.autoconvert(value)
return utils._autoconvert(value)
else:
return value.rstrip()

Expand Down
4 changes: 2 additions & 2 deletions ensemble_md/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def format_time(t):
return t_str


def autoconvert(s):
def _autoconvert(s):
"""
Converts input to a numerical type if possible. Used for the MDP parser.
Modified from `utilities.py in GromacsWrapper <https://github.com/Becksteinlab/GromacsWrapper>`_.
Expand Down Expand Up @@ -179,7 +179,7 @@ def autoconvert(s):
raise ValueError("Failed to autoconvert {0!r}".format(s))


def get_subplot_dimension(n_panels):
def _get_subplot_dimension(n_panels):
"""
Gets the numbers of rows and columns in a subplot such that
the arrangement of the .
Expand Down

0 comments on commit 006b608

Please sign in to comment.