added code for training with bgr #254

Open · wants to merge 1 commit into base: master
24 changes: 12 additions & 12 deletions dataset/videomatte.py
@@ -9,14 +9,14 @@
 class VideoMatteDataset(Dataset):
     def __init__(self,
                  videomatte_dir,
-                 background_image_dir,
+                 # background_image_dir,
                  background_video_dir,
                  size,
                  seq_length,
                  seq_sampler,
                  transform=None):
-        self.background_image_dir = background_image_dir
-        self.background_image_files = os.listdir(background_image_dir)
+        # self.background_image_dir = background_image_dir
+        # self.background_image_files = os.listdir(background_image_dir)
         self.background_video_dir = background_video_dir
         self.background_video_clips = sorted(os.listdir(background_video_dir))
         self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
@@ -38,10 +38,10 @@ def __len__(self):
         return len(self.videomatte_idx)
 
     def __getitem__(self, idx):
-        if random.random() < 0.5:
-            bgrs = self._get_random_image_background()
-        else:
-            bgrs = self._get_random_video_background()
+        # if random.random() < 0.5:
+        #     bgrs = self._get_random_image_background()
+        # else:
+        bgrs = self._get_random_video_background()
 
         fgrs, phas = self._get_videomatte(idx)
 
@@ -50,11 +50,11 @@ def __getitem__(self, idx):
 
         return fgrs, phas, bgrs
 
-    def _get_random_image_background(self):
-        with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
-            bgr = self._downsample_if_needed(bgr.convert('RGB'))
-        bgrs = [bgr] * self.seq_length
-        return bgrs
+    # def _get_random_image_background(self):
+    #     with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
+    #         bgr = self._downsample_if_needed(bgr.convert('RGB'))
+    #     bgrs = [bgr] * self.seq_length
+    #     return bgrs
 
     def _get_random_video_background(self):
         clip_idx = random.choice(range(len(self.background_video_clips)))
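
Reviewer note: this change removes the 50/50 choice between still-image and video backgrounds, so every training sample is now composited over a background video clip. A minimal construction sketch under that reading (the paths, sampler, and transform here are placeholders, not values from this PR):

    # Hypothetical usage: background_image_dir is gone from the signature,
    # so only a background *video* directory is supplied.
    dataset = VideoMatteDataset(
        videomatte_dir='datasets/VideoMatte240K/train',          # placeholder path
        background_video_dir='datasets/BackgroundVideos/train',  # placeholder path
        size=512,
        seq_length=15,
        seq_sampler=seq_sampler,   # whatever frame sampler the training script already uses
        transform=transform)       # paired augmentation for fgr/pha/bgr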
4 changes: 4 additions & 0 deletions inference.py
@@ -120,6 +120,10 @@ def convert_video(model,
     rec = [None] * 4
     for src in reader:
 
+        if src.shape[-1] % 2 == 1:
+            src = src[:, :, :, :-1]
+        if src.shape[-2] % 2 == 1:
+            src = src[:, :, :-1, :]
         if downsample_ratio is None:
             downsample_ratio = auto_downsample_ratio(*src.shape[2:])
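
Reviewer note: the added lines crop an odd width or height down to even before inference. A plausible reason (my reading, not stated in the PR) is that the network halves the resolution at each encoder stage, and odd input dimensions make the upsampled decoder features mismatch the skip connections. A standalone sketch of the same guard:

    import torch

    def crop_to_even(src: torch.Tensor) -> torch.Tensor:
        # Drop the last column/row when width or height is odd so that
        # repeated stride-2 downsampling and upsampling stay aligned.
        if src.shape[-1] % 2 == 1:
            src = src[..., :-1]
        if src.shape[-2] % 2 == 1:
            src = src[..., :-1, :]
        return src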
19 changes: 12 additions & 7 deletions inference_utils.py
@@ -5,7 +5,7 @@
 from torch.utils.data import Dataset
 from torchvision.transforms.functional import to_pil_image
 from PIL import Image
-
+import torch
 
 class VideoReader(Dataset):
     def __init__(self, path, transform=None):
@@ -55,18 +55,23 @@ def close(self):
 class ImageSequenceReader(Dataset):
     def __init__(self, path, transform=None):
         self.path = path
-        self.files = sorted(os.listdir(path))
+        self.files_fgr = sorted(os.listdir(path + "fgr/"))
+        self.files_bgr = sorted(os.listdir(path + "bgr/"))
         self.transform = transform
 
     def __len__(self):
-        return len(self.files)
+        return len(self.files_fgr)
 
     def __getitem__(self, idx):
-        with Image.open(os.path.join(self.path, self.files[idx])) as img:
-            img.load()
+        with Image.open(os.path.join(self.path + "fgr/", self.files_fgr[idx])) as fgr_img:
+            fgr_img.load()
+
+        with Image.open(os.path.join(self.path + "bgr/", self.files_bgr[idx])) as bgr_img:
+            bgr_img.load()
+
         if self.transform is not None:
-            return self.transform(img)
-        return img
+            return torch.cat([self.transform(fgr_img), self.transform(bgr_img)], dim=0)
+        return fgr_img
 
 
 class ImageSequenceWriter:
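
Reviewer note: ImageSequenceReader now expects two parallel subfolders under path and stacks each foreground frame with its same-named background frame into one 6-channel tensor. Two caveats worth flagging: path + "fgr/" only resolves correctly when path ends with a slash (os.path.join would be safer), and when no transform is given only the foreground image is returned. An assumed layout and usage sketch (directory name is a placeholder):

    # input_dir/
    #   fgr/0000.png, 0001.png, ...   foreground frames
    #   bgr/0000.png, 0001.png, ...   background frames, same names and order
    from torchvision.transforms import ToTensor

    reader = ImageSequenceReader('input_dir/', transform=ToTensor())
    sample = reader[0]   # shape (6, H, W): fgr RGB stacked on bgr RGB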
10 changes: 5 additions & 5 deletions model/decoder.py
@@ -1,18 +1,18 @@
 import torch
 from torch import Tensor
 from torch import nn
-from torch.nn import functional as F
+# from torch.nn import functional as F
 from typing import Tuple, Optional
 
 class RecurrentDecoder(nn.Module):
     def __init__(self, feature_channels, decoder_channels):
         super().__init__()
         self.avgpool = AvgPool()
         self.decode4 = BottleneckBlock(feature_channels[3])
-        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
-        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
-        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
-        self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
+        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 6, decoder_channels[0])
+        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 6, decoder_channels[1])
+        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 6, decoder_channels[2])
+        self.decode0 = OutputBlock(decoder_channels[2], 6, decoder_channels[3])
 
     def forward(self,
                 s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,
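
Reviewer note: as I read the decoder, the changed argument is the channel count of the downsampled source that gets concatenated back in at each scale; it goes from 3 to 6 because the source now carries foreground RGB plus background RGB. A shape sketch of the new input (sizes are illustrative):

    import torch

    fgr = torch.rand(1, 3, 256, 256)    # foreground frame
    bgr = torch.rand(1, 3, 256, 256)    # known background frame
    src = torch.cat([fgr, bgr], dim=1)  # (1, 6, 256, 256): the source the decoder now expects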
31 changes: 28 additions & 3 deletions model/mobilenetv3.py
@@ -3,6 +3,21 @@
 from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
 from torchvision.transforms.functional import normalize
 
+def load_matched_state_dict(model, state_dict, print_stats=True):
+    """
+    Only loads weights that match in key and shape. Ignores other weights.
+    """
+    num_matched, num_total = 0, 0
+    curr_state_dict = model.state_dict()
+    for key in curr_state_dict.keys():
+        num_total += 1
+        if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
+            curr_state_dict[key] = state_dict[key]
+            num_matched += 1
+    model.load_state_dict(curr_state_dict)
+    if print_stats:
+        print(f'Loaded state_dict: {num_matched}/{num_total} matched')
+
 class MobileNetV3LargeEncoder(MobileNetV3):
     def __init__(self, pretrained: bool = False):
         super().__init__(
@@ -27,14 +42,24 @@ def __init__(self, pretrained: bool = False):
         )
 
         if pretrained:
-            self.load_state_dict(torch.hub.load_state_dict_from_url(
-                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
+            pretrained_state_dict = torch.hub.load_state_dict_from_url(
+                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth')
+
+            load_matched_state_dict(self, pretrained_state_dict)
+
             del self.avgpool
             del self.classifier
 
     def forward_single_frame(self, x):
-        x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        x = torch.cat((normalize(x[:, :3, ...], [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), normalize(x[:, 3:, ...], [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])), dim=-3)
 
         x = self.features[0](x)
         x = self.features[1](x)
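
Reviewer note: two things happen here. The forward pass normalizes the foreground (first 3) and background (last 3) channels separately with the same ImageNet statistics before re-concatenating them. And load_matched_state_dict is what lets the widened encoder reuse the ImageNet checkpoint: any tensor whose shape no longer matches (presumably the stem convolution, whose input went from 3 to 6 channels) stays at its fresh initialization while everything else loads. A usage sketch with an illustrative, not actual, match count:

    encoder = MobileNetV3LargeEncoder(pretrained=True)
    # prints something like: Loaded state_dict: 310/312 matched   (illustrative numbers)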
5 changes: 3 additions & 2 deletions model/model.py
@@ -1,6 +1,7 @@
 import torch
 from torch import Tensor
 from torch import nn
+from torchsummary import summary
 from torch.nn import functional as F
 from typing import Optional, List
 
@@ -58,8 +59,8 @@ def forward(self,
         if not segmentation_pass:
             fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
             if downsample_ratio != 1:
-                fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
-            fgr = fgr_residual + src
+                fgr_residual, pha = self.refiner(src[:, :, :3, ...], src_sm[:, :, :3, ...], fgr_residual, pha, hid)
+            fgr = fgr_residual + src[:, :, :3, ...]
             fgr = fgr.clamp(0., 1.)
             pha = pha.clamp(0., 1.)
             return [fgr, pha, *rec]
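
Reviewer note: with the source widened to fgr+bgr, the predicted foreground must be composed from the foreground channels only, hence the [:, :, :3, ...] slices; the torchsummary import is unused in the shown hunk and pairs with the new torchsummary requirement, presumably for inspecting model shapes. A shape sketch (dimensions illustrative):

    import torch

    src = torch.rand(2, 4, 6, 224, 224)   # (batch, time, fgr RGB + bgr RGB, H, W)
    fgr_rgb = src[:, :, :3, ...]          # (2, 4, 3, 224, 224): foreground channels only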
10 changes: 6 additions & 4 deletions requirements_training.txt
@@ -1,5 +1,7 @@
 easing_functions==1.0.4
-tensorboard==2.5.0
-torch==1.9.0
-torchvision==0.10.0
-tqdm==4.61.1
+tensorboard
+torch
+torchvision
+tqdm==4.61.1
+opencv-python==4.6.0.66
+torchsummary