diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ad8a148
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+__pycache__/
+train.py
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e69de29
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000..1579e7e
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,31 @@
+import torch
+from PIL import Image
+from torchvision import transforms
+from model import CheckerboardAutogressive
+
+torch.backends.cudnn.deterministic = True
+
+if __name__ == '__main__':
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    checkpoint = torch.load('checkpoint.pth.tar', map_location=device)
+
+    net = CheckerboardAutogressive().to(device).eval()
+    net.load_state_dict(checkpoint["state_dict"])
+
+    img = Image.open('./images/kodim01.png').convert('RGB')
+    x = transforms.ToTensor()(img).unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        # codec: full compress/decompress round trip through the entropy coder
+        out = net.compress(x)
+        rec = net.decompress(out['strings'], out['shape'])
+        rec = transforms.ToPILImage()(rec['x_hat'].squeeze().cpu())
+        rec.save('./images/codec.png', format="PNG")
+
+        # inference: forward pass only, no actual entropy coding
+        out = net(x)
+        rec = out['x_hat'].clamp(0, 1)
+        rec = transforms.ToPILImage()(rec.squeeze().cpu())
+        rec.save('./images/infer.png', format="PNG")
+
+    print('saved to ./images')
diff --git a/images/codec.png b/images/codec.png
new file mode 100644
index 0000000..4836e07
Binary files /dev/null and b/images/codec.png differ
diff --git a/images/infer.png b/images/infer.png
new file mode 100644
index 0000000..2d2a92b
Binary files /dev/null and b/images/infer.png differ
diff --git a/images/kodim01.png b/images/kodim01.png
new file mode 100644
index 0000000..14317f0
Binary files /dev/null and b/images/kodim01.png differ
diff --git a/layers/__init__.py b/layers/__init__.py
new file mode 100644
index 0000000..09ad170
--- /dev/null
+++ b/layers/__init__.py
@@ -0,0 +1 @@
+from .layers import *
\ No newline at end of file
diff --git a/layers/layers.py b/layers/layers.py
new file mode 100644
index 0000000..cabd4c5
--- /dev/null
+++ b/layers/layers.py
@@ -0,0 +1,41 @@
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+from torch import Tensor
+
+
+class CheckerboardMaskedConv2d(nn.Conv2d):
+    """
+    if kernel_size == (5, 5)
+    then mask:
+        [[0., 1., 0., 1., 0.],
+         [1., 0., 1., 0., 1.],
+         [0., 1., 0., 1., 0.],
+         [1., 0., 1., 0., 1.],
+         [0., 1., 0., 1., 0.]]
+    0: non-anchor
+    1: anchor
+    """
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+
+        self.register_buffer("mask", torch.zeros_like(self.weight.data))
+
+        self.mask[:, :, 0::2, 1::2] = 1
+        self.mask[:, :, 1::2, 0::2] = 1
+
+    def forward(self, x: Tensor) -> Tensor:
+        # TODO: weight assignment is not supported by torchscript
+        self.weight.data *= self.mask
+        return super().forward(x)
+
+
+if __name__ == '__main__':
+
+    # note that bias is set to True in practice
+    ckbd = CheckerboardMaskedConv2d(3, 3, kernel_size=5, padding=2, stride=1, bias=True)
+    x = torch.rand((1, 3, 8, 8))
+
+    print(ckbd(x))
\ No newline at end of file
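Reviewer note (not part of the diff): the docstring of CheckerboardMaskedConv2d above documents the expected 5x5 checkerboard mask. A minimal sketch to confirm that pattern, assuming only that the layers package is importable from the repo root as in demo.py:

    from layers import CheckerboardMaskedConv2d

    # bias is irrelevant for inspecting the registered mask buffer
    conv = CheckerboardMaskedConv2d(1, 1, kernel_size=5, padding=2, stride=1, bias=False)
    print(conv.mask[0, 0])
    # tensor([[0., 1., 0., 1., 0.],
    #         [1., 0., 1., 0., 1.],
    #         [0., 1., 0., 1., 0.],
    #         [1., 0., 1., 0., 1.],
    #         [0., 1., 0., 1., 0.]])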
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..159f605
--- /dev/null
+++ b/model.py
@@ -0,0 +1,150 @@
+import torch
+
+from compressai.models.google import JointAutoregressiveHierarchicalPriors
+from layers import CheckerboardMaskedConv2d
+from modules import Demultiplexer, Multiplexer
+
+class CheckerboardAutogressive(JointAutoregressiveHierarchicalPriors):
+    def __init__(self, N=192, M=192, **kwargs):
+        super().__init__(N, M, **kwargs)
+
+        self.context_prediction = CheckerboardMaskedConv2d(
+            M, 2 * M, kernel_size=5, padding=2, stride=1
+        )
+
+    def forward(self, x):
+        y = self.g_a(x)
+        z = self.h_a(y)
+        z_hat, z_likelihoods = self.entropy_bottleneck(z)
+        params = self.h_s(z_hat)
+
+        y_hat = self.gaussian_conditional.quantize(
+            y, "noise" if self.training else "dequantize"
+        )
+
+        # set the non-anchor positions to 0
+        y_half = y_hat.clone()
+        y_half[:, :, 0::2, 0::2] = 0
+        y_half[:, :, 1::2, 1::2] = 0
+
+        # set the anchors' ctx_params to 0, otherwise a bias is introduced
+        ctx_params = self.context_prediction(y_half)
+        ctx_params[:, :, 0::2, 1::2] = 0
+        ctx_params[:, :, 1::2, 0::2] = 0
+
+        gaussian_params = self.entropy_parameters(
+            torch.cat((params, ctx_params), dim=1)
+        )
+        scales_hat, means_hat = gaussian_params.chunk(2, 1)
+        _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
+        x_hat = self.g_s(y_hat)
+
+        return {
+            "x_hat": x_hat,
+            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
+        }
+
+    def compress(self, x):
+        y = self.g_a(x)
+        z = self.h_a(y)
+
+        z_strings = self.entropy_bottleneck.compress(z)
+        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
+
+        params = self.h_s(z_hat)
+
+        # Note: in CompressAI, the means must be subtracted before quantization.
+        # To obtain y_half we need to subtract y_anchor's means and then quantize;
+        # to obtain y_anchor's means, we first run the entropy parameters ('gep') with a zero context.
+        N, _, H, W = z_hat.shape
+        zero_ctx_params = torch.zeros([N, 2 * self.M, H * 4, W * 4]).to(z_hat.device)
+        gaussian_params = self.entropy_parameters(
+            torch.cat((params, zero_ctx_params), dim=1)
+        )
+        _, means_hat = gaussian_params.chunk(2, 1)
+        y_hat = self.gaussian_conditional.quantize(y, "dequantize", means=means_hat)
+
+        # set the non-anchor positions to 0
+        y_half = y_hat.clone()
+        y_half[:, :, 0::2, 0::2] = 0
+        y_half[:, :, 1::2, 1::2] = 0
+
+        # set the anchors' ctx_params to 0, otherwise a bias is introduced
+        ctx_params = self.context_prediction(y_half)
+        ctx_params[:, :, 0::2, 1::2] = 0
+        ctx_params[:, :, 1::2, 0::2] = 0
+
+        gaussian_params = self.entropy_parameters(
+            torch.cat((params, ctx_params), dim=1)
+        )
+
+        scales_hat, means_hat = gaussian_params.chunk(2, 1)
+
+        y_anchor, y_non_anchor = Demultiplexer(y)
+        scales_hat_anchor, scales_hat_non_anchor = Demultiplexer(scales_hat)
+        means_hat_anchor, means_hat_non_anchor = Demultiplexer(means_hat)
+
+        indexes_anchor = self.gaussian_conditional.build_indexes(scales_hat_anchor)
+        indexes_non_anchor = self.gaussian_conditional.build_indexes(scales_hat_non_anchor)
+
+        anchor_strings = self.gaussian_conditional.compress(y_anchor, indexes_anchor, means=means_hat_anchor)
+        non_anchor_strings = self.gaussian_conditional.compress(y_non_anchor, indexes_non_anchor, means=means_hat_non_anchor)
+
+        return {
+            "strings": [anchor_strings, non_anchor_strings, z_strings],
+            "shape": z.size()[-2:],
+        }
+
+    def decompress(self, strings, shape):
+        """
+        See Figure 5 in the paper: illustration of the proposed two-pass decoding.
+ """ + assert isinstance(strings, list) and len(strings) == 3 + z_hat = self.entropy_bottleneck.decompress(strings[2], shape) + params = self.h_s(z_hat) + + # PASS 1: anchor + N, _, H, W = z_hat.shape + zero_ctx_params = torch.zeros([N, 2 * self.M, H * 4, W * 4]).to(z_hat.device) + gaussian_params = self.entropy_parameters( + torch.cat((params, zero_ctx_params), dim=1) + ) + + scales_hat, means_hat = gaussian_params.chunk(2, 1) + scales_hat_anchor, _ = Demultiplexer(scales_hat) + means_hat_anchor, _ = Demultiplexer(means_hat) + + indexes_anchor = self.gaussian_conditional.build_indexes(scales_hat_anchor) + y_anchor = self.gaussian_conditional.decompress(strings[0], indexes_anchor, means=means_hat_anchor) # [1, 384, 8, 8] + y_anchor = Multiplexer(y_anchor, torch.zeros_like(y_anchor)) # [1, 192, 16, 16] + + # PASS 2: non-anchor + ctx_params = self.context_prediction(y_anchor) + gaussian_params = self.entropy_parameters( + torch.cat((params, ctx_params), dim=1) + ) + + scales_hat, means_hat = gaussian_params.chunk(2, 1) + _, scales_hat_non_anchor = Demultiplexer(scales_hat) + _, means_hat_non_anchor = Demultiplexer(means_hat) + + indexes_non_anchor = self.gaussian_conditional.build_indexes(scales_hat_non_anchor) + y_non_anchor = self.gaussian_conditional.decompress(strings[1], indexes_non_anchor, means=means_hat_non_anchor) # [1, 384, 8, 8] + y_non_anchor = Multiplexer(torch.zeros_like(y_non_anchor), y_non_anchor) # [1, 192, 16, 16] + + # gather + y_hat = y_anchor + y_non_anchor + x_hat = self.g_s(y_hat).clamp_(0, 1) + + return { + "x_hat": x_hat, + } + + +if __name__ == "__main__": + x = torch.randn([1, 3, 256, 256]) + model = CheckerboardAutogressive() + model.update(force=True) + + out = model.compress(x) + rec = model.decompress(out["strings"], out["shape"]) diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000..9a8067b --- /dev/null +++ b/modules/__init__.py @@ -0,0 +1 @@ +from .modules import * \ No newline at end of file diff --git a/modules/modules.py b/modules/modules.py new file mode 100644 index 0000000..9c0f849 --- /dev/null +++ b/modules/modules.py @@ -0,0 +1,79 @@ +from turtle import forward +import torch.nn as nn +import torch + + +class Space2Depth(nn.Module): + """ + ref: https://github.com/huzi96/Coarse2Fine-PyTorch/blob/master/networks.py + """ + + def __init__(self, r=2): + super().__init__() + self.r = r + + def forward(self, x): + r = self.r + b, c, h, w = x.size() + out_c = c * (r**2) + out_h = h // r + out_w = w // r + x_view = x.view(b, c, out_h, r, out_w, r) + x_prime = x_view.permute(0, 3, 5, 1, 2, 4).contiguous().view(b, out_c, out_h, out_w) + return x_prime + + +class Depth2Space(nn.Module): + def __init__(self, r=2): + super().__init__() + self.r = r + + def forward(self, x): + r = self.r + b, c, h, w = x.size() + out_c = c // (r**2) + out_h = h * r + out_w = w * r + x_view = x.view(b, r, r, out_c, h, w) + x_prime = x_view.permute(0, 3, 4, 1, 5, 2).contiguous().view(b, out_c, out_h, out_w) + return x_prime + + +def Demultiplexer(x): + """ + See Supplementary Material: Figure 2 + """ + x_prime = Space2Depth(r=2)(x) + + _, C, _, _ = x_prime.shape + anchor_index = tuple(range(C // 4, C * 3 // 4)) + non_anchor_index = tuple(range(0, C // 4)) + tuple(range(C * 3 // 4, C)) + + anchor = x_prime[:, anchor_index, :, :] + non_anchor = x_prime[:, non_anchor_index, :, :] + + return anchor, non_anchor + +def Multiplexer(anchor, non_anchor): + """ + The inverse opperation of Demultiplexer + """ + _, C, _, _ = non_anchor.shape + x_prime = 
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644
index 0000000..9a8067b
--- /dev/null
+++ b/modules/__init__.py
@@ -0,0 +1 @@
+from .modules import *
\ No newline at end of file
diff --git a/modules/modules.py b/modules/modules.py
new file mode 100644
index 0000000..9c0f849
--- /dev/null
+++ b/modules/modules.py
@@ -0,0 +1,78 @@
+import torch
+import torch.nn as nn
+
+
+class Space2Depth(nn.Module):
+    """
+    ref: https://github.com/huzi96/Coarse2Fine-PyTorch/blob/master/networks.py
+    """
+
+    def __init__(self, r=2):
+        super().__init__()
+        self.r = r
+
+    def forward(self, x):
+        r = self.r
+        b, c, h, w = x.size()
+        out_c = c * (r**2)
+        out_h = h // r
+        out_w = w // r
+        x_view = x.view(b, c, out_h, r, out_w, r)
+        x_prime = x_view.permute(0, 3, 5, 1, 2, 4).contiguous().view(b, out_c, out_h, out_w)
+        return x_prime
+
+
+class Depth2Space(nn.Module):
+    def __init__(self, r=2):
+        super().__init__()
+        self.r = r
+
+    def forward(self, x):
+        r = self.r
+        b, c, h, w = x.size()
+        out_c = c // (r**2)
+        out_h = h * r
+        out_w = w * r
+        x_view = x.view(b, r, r, out_c, h, w)
+        x_prime = x_view.permute(0, 3, 4, 1, 5, 2).contiguous().view(b, out_c, out_h, out_w)
+        return x_prime
+
+
+def Demultiplexer(x):
+    """
+    See Supplementary Material: Figure 2.
+    """
+    x_prime = Space2Depth(r=2)(x)
+
+    _, C, _, _ = x_prime.shape
+    anchor_index = tuple(range(C // 4, C * 3 // 4))
+    non_anchor_index = tuple(range(0, C // 4)) + tuple(range(C * 3 // 4, C))
+
+    anchor = x_prime[:, anchor_index, :, :]
+    non_anchor = x_prime[:, non_anchor_index, :, :]
+
+    return anchor, non_anchor
+
+def Multiplexer(anchor, non_anchor):
+    """
+    The inverse operation of Demultiplexer.
+    """
+    _, C, _, _ = non_anchor.shape
+    x_prime = torch.cat((non_anchor[:, : C // 2, :, :], anchor, non_anchor[:, C // 2 :, :, :]), dim=1)
+    return Depth2Space(r=2)(x_prime)
+
+
+if __name__ == '__main__':
+    x = torch.zeros(1, 1, 6, 6)
+    x[0, 0, 0, 0] = 0
+    x[0, 0, 0, 1] = 1
+    x[0, 0, 1, 0] = 2
+    x[0, 0, 1, 1] = 3
+    print(x)
+
+    anchor, non_anchor = Demultiplexer(x)
+    print(anchor)
+    print(non_anchor)
+
+    x = Multiplexer(anchor, non_anchor)
+    print(x)
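Reviewer note (not part of the diff): decompress() relies on Multiplexer being the exact inverse of Demultiplexer, so a quick round-trip check is useful alongside the 6x6 example above. A minimal sketch, assuming only the functions defined in modules/modules.py; the shapes mirror the M=192 latent used by the model:

    import torch
    from modules import Demultiplexer, Multiplexer

    x = torch.rand(2, 192, 16, 16)
    anchor, non_anchor = Demultiplexer(x)    # each: [2, 384, 8, 8]
    x_rec = Multiplexer(anchor, non_anchor)  # back to [2, 192, 16, 16]
    print(torch.equal(x, x_rec))             # expected: True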