first commit
leelitian committed Mar 14, 2022
0 parents commit a8e3ecc
Showing 11 changed files with 305 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
__pycache__/
train.py
Empty file added README.md
31 changes: 31 additions & 0 deletions demo.py
@@ -0,0 +1,31 @@
import torch
from PIL import Image
from torchvision import transforms
from model import CheckerboardAutogressive

torch.backends.cudnn.deterministic = True

if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load('checkpoint.pth.tar', map_location=device)

    net = CheckerboardAutogressive().to(device).eval()
    net.load_state_dict(checkpoint["state_dict"])

    img = Image.open('./images/kodim01.png').convert('RGB')
    x = transforms.ToTensor()(img).unsqueeze(0).to(device)

    with torch.no_grad():
        # codec: compress to bitstreams, then decompress them back
        out = net.compress(x)
        rec = net.decompress(out['strings'], out['shape'])
        rec = transforms.ToPILImage()(rec['x_hat'].squeeze().cpu())
        rec.save('./images/codec.png', format="PNG")

        # inference: a single forward pass, without actual entropy coding
        out = net(x)
        rec = out['x_hat'].clamp(0, 1)
        rec = transforms.ToPILImage()(rec.squeeze().cpu())
        rec.save('./images/infer.png', format="PNG")

    print('saved in ./images')
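For reference, the demo could also report rate and distortion. The following sketch reuses net and x from above; the bpp and PSNR formulas are the standard ones, added here for illustration rather than taken from this commit:

import math

with torch.no_grad():
    out = net.compress(x)
    x_hat = net.decompress(out['strings'], out['shape'])['x_hat']

# rate: total length of all bitstreams, in bits per pixel
num_pixels = x.size(2) * x.size(3)
num_bytes = sum(len(s) for strings in out['strings'] for s in strings)
print('bpp:', num_bytes * 8 / num_pixels)

# distortion: PSNR between original and reconstruction (both in [0, 1])
mse = torch.mean((x - x_hat) ** 2).item()
print('psnr: %.2f dB' % (10 * math.log10(1.0 / mse)))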
Binary file added images/codec.png
Binary file added images/infer.png
Binary file added images/kodim01.png
1 change: 1 addition & 0 deletions layers/__init__.py
@@ -0,0 +1 @@
from .layers import *
41 changes: 41 additions & 0 deletions layers/layers.py
@@ -0,0 +1,41 @@
from typing import Any

import torch
import torch.nn as nn

from torch import Tensor


class CheckerboardMaskedConv2d(nn.Conv2d):
    """
    if kernel_size == (5, 5)
    then mask:
        [[0., 1., 0., 1., 0.],
         [1., 0., 1., 0., 1.],
         [0., 1., 0., 1., 0.],
         [1., 0., 1., 0., 1.],
         [0., 1., 0., 1., 0.]]
    0: non-anchor
    1: anchor
    """
    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)

        self.register_buffer("mask", torch.zeros_like(self.weight.data))

        # only the anchor positions of the checkerboard are left visible
        self.mask[:, :, 0::2, 1::2] = 1
        self.mask[:, :, 1::2, 0::2] = 1

    def forward(self, x: Tensor) -> Tensor:
        # TODO: weight assignment is not supported by torchscript
        self.weight.data *= self.mask
        return super().forward(x)


if __name__ == '__main__':
    # note that bias is True in practice
    ckbd = CheckerboardMaskedConv2d(3, 3, kernel_size=5, padding=2, stride=1, bias=True)
    x = torch.rand((1, 3, 8, 8))

    print(ckbd(x))
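As a quick sanity check (illustrative, not part of the commit), the registered buffer can be printed to confirm it matches the mask pattern in the docstring:

from layers import CheckerboardMaskedConv2d  # assumes the package layout of this repo

conv = CheckerboardMaskedConv2d(1, 1, kernel_size=5, padding=2)
print(conv.mask[0, 0])
# tensor([[0., 1., 0., 1., 0.],
#         [1., 0., 1., 0., 1.],
#         [0., 1., 0., 1., 0.],
#         [1., 0., 1., 0., 1.],
#         [0., 1., 0., 1., 0.]])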
150 changes: 150 additions & 0 deletions model.py
@@ -0,0 +1,150 @@
import torch

from compressai.models.google import JointAutoregressiveHierarchicalPriors
from layers import CheckerboardMaskedConv2d
from modules import Demultiplexer, Multiplexer

class CheckerboardAutogressive(JointAutoregressiveHierarchicalPriors):
    def __init__(self, N=192, M=192, **kwargs):
        super().__init__(N, M, **kwargs)

        self.context_prediction = CheckerboardMaskedConv2d(
            M, 2 * M, kernel_size=5, padding=2, stride=1
        )

    def forward(self, x):
        y = self.g_a(x)
        z = self.h_a(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        params = self.h_s(z_hat)

        y_hat = self.gaussian_conditional.quantize(
            y, "noise" if self.training else "dequantize"
        )

        # set the non-anchor positions to 0
        y_half = y_hat.clone()
        y_half[:, :, 0::2, 0::2] = 0
        y_half[:, :, 1::2, 1::2] = 0

        # set the anchors' ctx_params to 0; otherwise they would be biased,
        # since the anchors must be decodable without any context
        ctx_params = self.context_prediction(y_half)
        ctx_params[:, :, 0::2, 1::2] = 0
        ctx_params[:, :, 1::2, 0::2] = 0

        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1)
        )
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        x_hat = self.g_s(y_hat)

        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    def compress(self, x):
        y = self.g_a(x)
        z = self.h_a(y)

        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        params = self.h_s(z_hat)

        # Note: in compressai, the means must be subtracted before quantization.
        # To get y_half we need to subtract the anchors' means before quantizing,
        # and to get the anchors' means we have to run the entropy-parameters
        # network ('gep') once with an all-zero context.
        N, _, H, W = z_hat.shape
        zero_ctx_params = torch.zeros([N, 2 * self.M, H * 4, W * 4]).to(z_hat.device)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, zero_ctx_params), dim=1)
        )
        _, means_hat = gaussian_params.chunk(2, 1)
        y_hat = self.gaussian_conditional.quantize(y, "dequantize", means=means_hat)

        # set the non-anchor positions to 0
        y_half = y_hat.clone()
        y_half[:, :, 0::2, 0::2] = 0
        y_half[:, :, 1::2, 1::2] = 0

        # set the anchors' ctx_params to 0; otherwise they would be biased
        ctx_params = self.context_prediction(y_half)
        ctx_params[:, :, 0::2, 1::2] = 0
        ctx_params[:, :, 1::2, 0::2] = 0

        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1)
        )

        scales_hat, means_hat = gaussian_params.chunk(2, 1)

        y_anchor, y_non_anchor = Demultiplexer(y)
        scales_hat_anchor, scales_hat_non_anchor = Demultiplexer(scales_hat)
        means_hat_anchor, means_hat_non_anchor = Demultiplexer(means_hat)

        indexes_anchor = self.gaussian_conditional.build_indexes(scales_hat_anchor)
        indexes_non_anchor = self.gaussian_conditional.build_indexes(scales_hat_non_anchor)

        anchor_strings = self.gaussian_conditional.compress(y_anchor, indexes_anchor, means=means_hat_anchor)
        non_anchor_strings = self.gaussian_conditional.compress(y_non_anchor, indexes_non_anchor, means=means_hat_non_anchor)

        return {
            "strings": [anchor_strings, non_anchor_strings, z_strings],
            "shape": z.size()[-2:],
        }

    def decompress(self, strings, shape):
        """
        See Figure 5 of the paper: illustration of the proposed two-pass decoding.
        """
        assert isinstance(strings, list) and len(strings) == 3
        z_hat = self.entropy_bottleneck.decompress(strings[2], shape)
        params = self.h_s(z_hat)

        # PASS 1: decode the anchors with an all-zero context
        N, _, H, W = z_hat.shape
        zero_ctx_params = torch.zeros([N, 2 * self.M, H * 4, W * 4]).to(z_hat.device)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, zero_ctx_params), dim=1)
        )

        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        scales_hat_anchor, _ = Demultiplexer(scales_hat)
        means_hat_anchor, _ = Demultiplexer(means_hat)

        indexes_anchor = self.gaussian_conditional.build_indexes(scales_hat_anchor)
        y_anchor = self.gaussian_conditional.decompress(strings[0], indexes_anchor, means=means_hat_anchor)  # [1, 384, 8, 8]
        y_anchor = Multiplexer(y_anchor, torch.zeros_like(y_anchor))  # [1, 192, 16, 16]

        # PASS 2: decode the non-anchors, conditioned on the decoded anchors
        ctx_params = self.context_prediction(y_anchor)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1)
        )

        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        _, scales_hat_non_anchor = Demultiplexer(scales_hat)
        _, means_hat_non_anchor = Demultiplexer(means_hat)

        indexes_non_anchor = self.gaussian_conditional.build_indexes(scales_hat_non_anchor)
        y_non_anchor = self.gaussian_conditional.decompress(strings[1], indexes_non_anchor, means=means_hat_non_anchor)  # [1, 384, 8, 8]
        y_non_anchor = Multiplexer(torch.zeros_like(y_non_anchor), y_non_anchor)  # [1, 192, 16, 16]

        # gather the two passes
        y_hat = y_anchor + y_non_anchor
        x_hat = self.g_s(y_hat).clamp_(0, 1)

        return {
            "x_hat": x_hat,
        }


if __name__ == "__main__":
    x = torch.randn([1, 3, 256, 256])
    model = CheckerboardAutogressive()
    model.update(force=True)

    out = model.compress(x)
    rec = model.decompress(out["strings"], out["shape"])
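train.py is excluded by the .gitignore above; for completeness, here is a minimal sketch of the usual compressai-style rate-distortion objective such a script would optimize. The lambda value is an illustrative assumption, not code from this commit:

import math
import torch
import torch.nn as nn

class RateDistortionLoss(nn.Module):
    def __init__(self, lmbda=0.01):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lmbda = lmbda

    def forward(self, out, target):
        # out is the dict returned by CheckerboardAutogressive.forward
        N, _, H, W = target.size()
        num_pixels = N * H * W

        # rate: total bits of y and z, normalized to bits per pixel
        bpp_loss = sum(
            torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
            for likelihoods in out["likelihoods"].values()
        )

        # distortion: MSE scaled to the 0-255 range, weighted by lambda
        mse_loss = self.mse(out["x_hat"], target)
        return self.lmbda * 255 ** 2 * mse_loss + bpp_loss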
1 change: 1 addition & 0 deletions modules/__init__.py
@@ -0,0 +1 @@
from .modules import *
79 changes: 79 additions & 0 deletions modules/modules.py
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn


class Space2Depth(nn.Module):
    """
    ref: https://github.com/huzi96/Coarse2Fine-PyTorch/blob/master/networks.py
    """

    def __init__(self, r=2):
        super().__init__()
        self.r = r

    def forward(self, x):
        r = self.r
        b, c, h, w = x.size()
        out_c = c * (r**2)
        out_h = h // r
        out_w = w // r
        x_view = x.view(b, c, out_h, r, out_w, r)
        x_prime = x_view.permute(0, 3, 5, 1, 2, 4).contiguous().view(b, out_c, out_h, out_w)
        return x_prime


class Depth2Space(nn.Module):
    def __init__(self, r=2):
        super().__init__()
        self.r = r

    def forward(self, x):
        r = self.r
        b, c, h, w = x.size()
        out_c = c // (r**2)
        out_h = h * r
        out_w = w * r
        x_view = x.view(b, r, r, out_c, h, w)
        x_prime = x_view.permute(0, 3, 4, 1, 5, 2).contiguous().view(b, out_c, out_h, out_w)
        return x_prime


def Demultiplexer(x):
    """
    See Supplementary Material: Figure 2.
    Splits a latent into its anchor and non-anchor halves: after
    Space2Depth, the anchor phases occupy the middle half of the channels.
    """
    x_prime = Space2Depth(r=2)(x)

    _, C, _, _ = x_prime.shape
    anchor_index = tuple(range(C // 4, C * 3 // 4))
    non_anchor_index = tuple(range(0, C // 4)) + tuple(range(C * 3 // 4, C))

    anchor = x_prime[:, anchor_index, :, :]
    non_anchor = x_prime[:, non_anchor_index, :, :]

    return anchor, non_anchor

def Multiplexer(anchor, non_anchor):
    """
    The inverse operation of Demultiplexer.
    """
    _, C, _, _ = non_anchor.shape
    x_prime = torch.cat((non_anchor[:, :C // 2, :, :], anchor, non_anchor[:, C // 2:, :, :]), dim=1)
    return Depth2Space(r=2)(x_prime)


if __name__ == '__main__':
    x = torch.zeros(1, 1, 6, 6)
    x[0, 0, 0, 0] = 0
    x[0, 0, 0, 1] = 1
    x[0, 0, 1, 0] = 2
    x[0, 0, 1, 1] = 3
    print(x)

    anchor, non_anchor = Demultiplexer(x)
    print(anchor)      # holds the values at anchor positions (1 and 2)
    print(non_anchor)  # holds the values at non-anchor positions (0 and 3)

    x = Multiplexer(anchor, non_anchor)
    print(x)  # identical to the input: Multiplexer inverts Demultiplexer
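A quick round-trip check (an illustrative addition, not part of the commit) confirms that Multiplexer exactly inverts Demultiplexer for any tensor with even spatial dimensions:

import torch
from modules import Demultiplexer, Multiplexer  # assumes the package layout of this repo

x = torch.randn(2, 8, 16, 16)
anchor, non_anchor = Demultiplexer(x)
assert anchor.shape == (2, 16, 8, 8)  # half of the 4*C channels, at half resolution
assert torch.equal(Multiplexer(anchor, non_anchor), x)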
