Skip to content

Commit

Permalink
fix check
Browse files Browse the repository at this point in the history
  • Loading branch information
haruishi43 committed Sep 8, 2021
1 parent 9aa8aa8 commit 0c082c4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion equilib/equi2equi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,12 @@ def equi2equi(
else:
raise ValueError

is_single = False
if len(src.shape) == 3 and isinstance(rots, dict):
# probably the input was a single image
src = src[None, ...]
rots = [rots]
is_single = True
elif len(src.shape) == 3:
# probably a grayscale image
src = src[:, None, ...]
Expand Down Expand Up @@ -121,7 +123,7 @@ def equi2equi(
raise ValueError

# make sure that the output batch dim is removed if it's only a single image
if out.shape[0] == 1:
if is_single:
out = out.squeeze(0)

return out
4 changes: 3 additions & 1 deletion equilib/equi2pers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,12 @@ def equi2pers(
else:
raise ValueError

is_single = False
if len(equi.shape) == 3 and isinstance(rots, dict):
# probably the input was a single image
equi = equi[None, ...]
rots = [rots]
is_single = True
elif len(equi.shape) == 3:
# probably a grayscale image
equi = equi[:, None, ...]
Expand Down Expand Up @@ -144,7 +146,7 @@ def equi2pers(
raise ValueError

# make sure that the output batch dim is removed if it's only a single image
if out.shape[0] == 1:
if is_single:
out = out.squeeze(0)

return out

0 comments on commit 0c082c4

Please sign in to comment.