From 0c082c43a49fbe35a9c5a184a39cefeab0f0b3f6 Mon Sep 17 00:00:00 2001 From: Toraudonn Date: Wed, 8 Sep 2021 06:43:24 +0000 Subject: [PATCH] fix check --- equilib/equi2equi/base.py | 4 +++- equilib/equi2pers/base.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/equilib/equi2equi/base.py b/equilib/equi2equi/base.py index 68275fe3..29ea864c 100644 --- a/equilib/equi2equi/base.py +++ b/equilib/equi2equi/base.py @@ -88,10 +88,12 @@ def equi2equi( else: raise ValueError + is_single = False if len(src.shape) == 3 and isinstance(rots, dict): # probably the input was a single image src = src[None, ...] rots = [rots] + is_single = True elif len(src.shape) == 3: # probably a grayscale image src = src[:, None, ...] @@ -121,7 +123,7 @@ def equi2equi( raise ValueError # make sure that the output batch dim is removed if it's only a single image - if out.shape[0] == 1: + if is_single: out = out.squeeze(0) return out diff --git a/equilib/equi2pers/base.py b/equilib/equi2pers/base.py index 60e678b8..f18f13e3 100644 --- a/equilib/equi2pers/base.py +++ b/equilib/equi2pers/base.py @@ -107,10 +107,12 @@ def equi2pers( else: raise ValueError + is_single = False if len(equi.shape) == 3 and isinstance(rots, dict): # probably the input was a single image equi = equi[None, ...] rots = [rots] + is_single = True elif len(equi.shape) == 3: # probably a grayscale image equi = equi[:, None, ...] @@ -144,7 +146,7 @@ def equi2pers( raise ValueError # make sure that the output batch dim is removed if it's only a single image - if out.shape[0] == 1: + if is_single: out = out.squeeze(0) return out