Skip to content

Commit

Permalink
refactor code style in DCA
Browse files Browse the repository at this point in the history
  • Loading branch information
maybenotilya committed Sep 26, 2024
1 parent ea79ffc commit 673b80f
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 46 deletions.
72 changes: 39 additions & 33 deletions DCA/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,7 @@ class DcaAdapter(AdapterBase):
Adapter to work with DCA algorithm
"""

def __init__(
self,
factor: int,
model: Path,
device: str
):
def __init__(self, factor: int, model: Path, device: str):
"""
Args:
factor (int): factor to divide image into patches
Expand All @@ -57,37 +52,39 @@ def process(self, image: np.ndarray):
def _transform_image(self, image: np.ndarray):
mean = np.mean(image, axis=(0, 1))
std = np.std(image, axis=(0, 1))
transformer = Compose([
Normalize(mean=mean,
std=std,
max_pixel_value=1, always_apply=True),
ever.preprocess.albu.ToTensor()
], is_check_shapes=False)
transformer = Compose(
[
Normalize(mean=mean, std=std, max_pixel_value=1, always_apply=True),
ever.preprocess.albu.ToTensor(),
],
is_check_shapes=False,
)
blob = transformer(image=image)
image = blob["image"].to(self._device)
image = image[None, :]
return image

def _build_model(self):
model = Deeplabv2(dict(
backbone=dict(
resnet_type="resnet50",
output_stride=16,
pretrained=True,
),
multi_layer=True,
cascade=False,
use_ppm=True,
ppm=dict(
model = Deeplabv2(
dict(
backbone=dict(
resnet_type="resnet50",
output_stride=16,
pretrained=True,
),
multi_layer=True,
cascade=False,
use_ppm=True,
ppm=dict(
num_classes=7,
use_aux=False,
fc_dim=2048,
),
inchannels=2048,
num_classes=7,
use_aux=False,
fc_dim=2048,
),
inchannels=2048,
num_classes=7
)).to(self._device)
model_state_dict = torch.load(self._model,
map_location=self._device)
)
).to(self._device)
model_state_dict = torch.load(self._model, map_location=self._device)
model.load_state_dict(model_state_dict, strict=True)
model.eval()
return model
Expand All @@ -99,11 +96,20 @@ def _process(self, model, image):
model=model,
image=image,
num_classes=7,
tile_size=(image.shape[2] // self._factor, image.shape[3] // self._factor),
tile_size=(
image.shape[2] // self._factor,
image.shape[3] // self._factor,
),
tta=True,
device=self._device
device=self._device,
)
cls = (
cls.argmax(dim=1)
.to(self._device)
.numpy()
.reshape(shape)
.astype(np.uint8)
)
cls = cls.argmax(dim=1).to(self._device).numpy().reshape(shape).astype(np.uint8)
return cls

def _postprocess_predictions(self, raw_predictions):
Expand Down
4 changes: 2 additions & 2 deletions DCA/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,6 @@
image = imread(file_path)
logging.info(f"Image shape: {image.shape}")
mask = adapter.process(image)
output_path = output_dir / Path(file_path.stem).with_suffix('.npy')
output_path = output_dir / Path(file_path.stem).with_suffix(".npy")
logging.info(f" Saving to {output_path}")
np.save(output_path, mask)
np.save(output_path, mask)
47 changes: 36 additions & 11 deletions DCA/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,40 @@

def get_args():
parser = ArgumentParser()
parser.add_argument("-f", "--factor", type=int, default=2,
help="Factor shows how images must be scaled to create patches, for factor = n there will be "
"n^2 patches (default: 2)")
parser.add_argument("-m", "--model", type=Path, default=Path("/", "DCA", "weights", "Urban.pth"),
help="Pretrained model path (default: weights/Urban.pth)")
parser.add_argument("-d", "--device", type=str, default="cuda",
help="Which device to run network on (default: GPU)")
parser.add_argument("-i", "--input", type=Path, default=Path("/", "DCA", "input"),
help="Images input folder")
parser.add_argument("-o", "--output", type=Path, default=Path("/", "DCA", "output"),
help="Masks output folder")
parser.add_argument(
"-f",
"--factor",
type=int,
default=2,
help="Factor shows how images must be scaled to create patches, for factor = n there will be "
"n^2 patches (default: 2)",
)
parser.add_argument(
"-m",
"--model",
type=Path,
default=Path("/", "DCA", "weights", "Urban.pth"),
help="Pretrained model path (default: weights/Urban.pth)",
)
parser.add_argument(
"-d",
"--device",
type=str,
default="cuda",
help="Which device to run network on (default: GPU)",
)
parser.add_argument(
"-i",
"--input",
type=Path,
default=Path("/", "DCA", "input"),
help="Images input folder",
)
parser.add_argument(
"-o",
"--output",
type=Path,
default=Path("/", "DCA", "output"),
help="Masks output folder",
)
return parser.parse_args()

0 comments on commit 673b80f

Please sign in to comment.