import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
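

# Gaussian diffusion with classifier-free guidance: the forward process
# q(x_t | x_0) corrupts data with Gaussian noise according to a fixed beta
# schedule, and the reverse process p_theta(x_{t-1} | x_t) denoises with a
# learned noise predictor, mixing conditional and unconditional predictions.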
class GaussianDiffusion(nn.Module):
    def __init__(self, dtype, model, betas, w, v, device):
        """
        dtype  : torch dtype for all precomputed schedule tensors
        model  : noise predictor eps_theta(x_t, t, **model_kwargs)
        betas  : length-T sequence of forward-process variances beta_t
        w      : classifier-free guidance strength
        v      : interpolation weight between the two posterior-variance bounds
        device : device the model runs on
        """
        super().__init__()
        self.dtype = dtype
        self.model = model.to(device)
        self.model.dtype = self.dtype
        self.betas = torch.tensor(betas, dtype=self.dtype)
        self.w = w
        self.v = v
        self.T = len(betas)
        self.device = device
        self.alphas = 1 - self.betas
        self.log_alphas = torch.log(self.alphas)
        self.log_alphas_bar = torch.cumsum(self.log_alphas, dim=0)
        self.alphas_bar = torch.exp(self.log_alphas_bar)
        # equivalently: self.alphas_bar = torch.cumprod(self.alphas, dim=0)
        self.log_alphas_bar_prev = F.pad(self.log_alphas_bar[:-1], [1, 0], 'constant', 0)
        self.alphas_bar_prev = torch.exp(self.log_alphas_bar_prev)
        self.log_one_minus_alphas_bar_prev = torch.log(1.0 - self.alphas_bar_prev)
        # equivalently: self.alphas_bar_prev = F.pad(self.alphas_bar[:-1], [1, 0], 'constant', 1)
        # parameters of q(x_t | x_{t-1})
        self.log_sqrt_alphas = 0.5 * self.log_alphas
        self.sqrt_alphas = torch.exp(self.log_sqrt_alphas)
        # equivalently: self.sqrt_alphas = torch.sqrt(self.alphas)
        # parameters of q(x_t | x_0)
        self.log_sqrt_alphas_bar = 0.5 * self.log_alphas_bar
        self.sqrt_alphas_bar = torch.exp(self.log_sqrt_alphas_bar)
        # equivalently: self.sqrt_alphas_bar = torch.sqrt(self.alphas_bar)
        self.log_one_minus_alphas_bar = torch.log(1.0 - self.alphas_bar)
        self.sqrt_one_minus_alphas_bar = torch.exp(0.5 * self.log_one_minus_alphas_bar)
        # parameters of q(x_{t-1} | x_t, x_0)
        # the log is clipped because tilde_beta_1 = 0, so its log would be -inf
        self.tilde_betas = self.betas * torch.exp(self.log_one_minus_alphas_bar_prev - self.log_one_minus_alphas_bar)
        self.log_tilde_betas_clipped = torch.log(torch.cat((self.tilde_betas[1].view(-1), self.tilde_betas[1:]), 0))
        self.mu_coef_x0 = self.betas * torch.exp(0.5 * self.log_alphas_bar_prev - self.log_one_minus_alphas_bar)
        self.mu_coef_xt = torch.exp(0.5 * self.log_alphas + self.log_one_minus_alphas_bar_prev - self.log_one_minus_alphas_bar)
        self.vars = torch.cat((self.tilde_betas[1:2], self.betas[1:]), 0)
        self.coef1 = torch.exp(-self.log_sqrt_alphas)
        self.coef2 = self.coef1 * self.betas / self.sqrt_one_minus_alphas_bar
        # parameters for the predicted x_0
        self.sqrt_recip_alphas_bar = torch.exp(-self.log_sqrt_alphas_bar)
        # equivalently: self.sqrt_recip_alphas_bar = torch.sqrt(1.0 / self.alphas_bar)
        # note the 0.5 factor: sqrt(1/alpha_bar - 1) = exp(0.5*log(1 - alpha_bar) - 0.5*log(alpha_bar))
        self.sqrt_recipm1_alphas_bar = torch.exp(0.5 * self.log_one_minus_alphas_bar - self.log_sqrt_alphas_bar)
        # equivalently: self.sqrt_recipm1_alphas_bar = torch.sqrt(1.0 / self.alphas_bar - 1)
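
    # Sanity sketch (illustrative, not executed): the log-space quantities above
    # should agree with their direct counterparts, e.g.
    #   torch.allclose(self.alphas_bar, torch.cumprod(self.alphas, dim=0))
    #   torch.allclose(self.sqrt_recipm1_alphas_bar, torch.sqrt(1.0 / self.alphas_bar - 1))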

    @staticmethod
    def _extract(coef, t, x_shape):
        """
        Gather coef[t] for a batch of timesteps and reshape the result to
        [batch_size, 1, ..., 1] (as many dims as x_shape) so it broadcasts
        against a tensor of shape x_shape.
        """
        assert t.shape[0] == x_shape[0]
        neo_shape = [x_shape[0]] + [1] * (len(x_shape) - 1)
        # index on t's device so CPU-resident schedules work with GPU timesteps
        chosen = coef.to(t.device)[t]
        return chosen.reshape(neo_shape)
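
    # Shape example (for reference): with coef of shape [T], t of shape [B],
    # and x_shape == (B, C, H, W), _extract returns a [B, 1, 1, 1] tensor, so
    # per-timestep scalars broadcast over channels and spatial dims.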

    def q_mean_variance(self, x_0, t):
        """
        Parameters of q(x_t | x_0): mean = sqrt(alpha_bar_t) * x_0,
        var = 1 - alpha_bar_t.
        """
        mean = self._extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0
        var = self._extract(1.0 - self.alphas_bar, t, x_0.shape)
        return mean, var

    def q_sample(self, x_0, t):
        """
        Sample x_t ~ q(x_t | x_0) via the reparameterization
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
        with eps ~ N(0, I). Returns x_t together with the noise eps.
        """
        eps = torch.randn_like(x_0, requires_grad=False)
        x_t = self._extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 \
            + self._extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * eps
        return x_t, eps
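
    # Note: q_sample jumps directly to step t using the closed-form marginal,
    # so training never needs to iterate the forward chain; trainloss() below
    # relies on exactly this.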

    def q_posterior_mean_variance(self, x_0, x_t, t):
        """
        Parameters of the forward-process posterior q(x_{t-1} | x_t, x_0).
        Returns the mean, the exact posterior variance tilde_beta_t, and a
        log-domain interpolation between log(tilde_beta_t) and log(beta_t)
        weighted by self.v.
        """
        posterior_mean = self._extract(self.mu_coef_x0, t, x_0.shape) * x_0 \
            + self._extract(self.mu_coef_xt, t, x_t.shape) * x_t
        posterior_var = self._extract(self.tilde_betas, t, x_t.shape)
        log_posterior_var_min = self._extract(self.log_tilde_betas_clipped, t, x_t.shape)
        log_posterior_var_max = self._extract(torch.log(self.betas), t, x_t.shape)
        log_posterior_var = self.v * log_posterior_var_max + (1 - self.v) * log_posterior_var_min
        neo_posterior_var = torch.exp(log_posterior_var)
        return posterior_mean, posterior_var, neo_posterior_var
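
    # For reference, the closed-form coefficients used above are
    #   tilde_beta_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
    #   mu_t(x_t, x_0) = [beta_t * sqrt(alpha_bar_{t-1}) / (1 - alpha_bar_t)] * x_0
    #                  + [sqrt(alpha_t) * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)] * x_t
    # and the v-interpolation of the two log-variance bounds follows the range
    # used in Improved DDPM (Nichol & Dhariwal), here with a fixed scalar v.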

    def p_mean_variance(self, x_t, t, **model_kwargs):
        """
        Parameters of p_theta(x_{t-1} | x_t), using classifier-free guidance
        to combine conditional and unconditional noise predictions.
        """
        if model_kwargs is None:
            model_kwargs = {}
        B, C = x_t.shape[:2]
        assert t.shape == (B,)
        cemb_shape = model_kwargs['cemb'].shape
        pred_eps_cond = self.model(x_t, t, **model_kwargs)
        # zero the conditioning embedding for the unconditional pass; this only
        # mutates the local dict created by ** unpacking, not the caller's
        model_kwargs['cemb'] = torch.zeros(cemb_shape, device=self.device)
        pred_eps_uncond = self.model(x_t, t, **model_kwargs)
        pred_eps = (1 + self.w) * pred_eps_cond - self.w * pred_eps_uncond
        assert torch.isnan(x_t).int().sum() == 0, f"nan in tensor x_t when t = {t[0]}"
        assert torch.isnan(t).int().sum() == 0, f"nan in tensor t when t = {t[0]}"
        assert torch.isnan(pred_eps).int().sum() == 0, f"nan in tensor pred_eps when t = {t[0]}"
        p_mean = self._predict_xt_prev_mean_from_eps(x_t, t.type(dtype=torch.long), pred_eps)
        p_var = self._extract(self.vars, t.type(dtype=torch.long), x_t.shape)
        return p_mean, p_var
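
    # Guidance note: w = 0 recovers the purely conditional model, and larger w
    # sharpens conditioning via eps = (1 + w) * eps_cond - w * eps_uncond,
    # the classifier-free guidance rule of Ho & Salimans.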

    def _predict_x0_from_eps(self, x_t, t, eps):
        # x_0 = sqrt(1/alpha_bar_t) * x_t - sqrt(1/alpha_bar_t - 1) * eps
        return self._extract(coef=self.sqrt_recip_alphas_bar, t=t, x_shape=x_t.shape) * x_t \
            - self._extract(coef=self.sqrt_recipm1_alphas_bar, t=t, x_shape=x_t.shape) * eps

    def _predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        # mu_theta(x_t, t) = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        return self._extract(coef=self.coef1, t=t, x_shape=x_t.shape) * x_t \
            - self._extract(coef=self.coef2, t=t, x_shape=x_t.shape) * eps

    def p_sample(self, x_t, t, **model_kwargs):
        """
        Sample x_{t-1} ~ p_theta(x_{t-1} | x_t).
        """
        if model_kwargs is None:
            model_kwargs = {}
        B, C = x_t.shape[:2]
        assert t.shape == (B,), f"size of t is not batch size {B}"
        mean, var = self.p_mean_variance(x_t, t, **model_kwargs)
        assert torch.isnan(mean).int().sum() == 0, f"nan in tensor mean when t = {t[0]}"
        assert torch.isnan(var).int().sum() == 0, f"nan in tensor var when t = {t[0]}"
        noise = torch.randn_like(x_t)
        # the final step (t == 0) is deterministic: no noise is added
        noise[t <= 0] = 0
        return mean + torch.sqrt(var) * noise

    def sample(self, shape, **model_kwargs):
        """
        Generate samples from p_theta by running the full reverse chain,
        starting from x_T ~ N(0, I).
        """
        print('Start generating...')
        if model_kwargs is None:
            model_kwargs = {}
        x_t = torch.randn(shape, device=self.device)
        # timesteps count down from T-1 to 0
        tlist = torch.ones([x_t.shape[0]], device=self.device) * self.T
        for _ in tqdm(range(self.T), dynamic_ncols=True):
            tlist -= 1
            with torch.no_grad():
                x_t = self.p_sample(x_t, tlist, **model_kwargs)
        x_t = torch.clamp(x_t, -1, 1)
        print('Finished generating.')
        return x_t

    def trainloss(self, x_0, **model_kwargs):
        """
        Simple DDPM training loss: MSE between the true noise eps and the
        model's prediction at a uniformly sampled timestep.
        """
        if model_kwargs is None:
            model_kwargs = {}
        t = torch.randint(self.T, size=(x_0.shape[0],), device=self.device)
        x_t, eps = self.q_sample(x_0, t)
        pred_eps = self.model(x_t, t, **model_kwargs)
        loss = F.mse_loss(pred_eps, eps, reduction='mean')
        return loss
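

# A minimal smoke test, a sketch only: `ToyEps` and its `cemb` keyword are
# illustrative stand-ins for the conditional UNet this class is normally
# paired with, and the schedule values here are arbitrary.
if __name__ == '__main__':
    class ToyEps(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Conv2d(3, 3, kernel_size=3, padding=1)

        def forward(self, x, t, cemb=None):
            # a real model would embed t and cemb; this toy one ignores them
            return self.net(x)

    device = 'cpu'
    betas = torch.linspace(1e-4, 0.02, 100).tolist()
    diffusion = GaussianDiffusion(torch.float32, ToyEps(), betas, w=1.8, v=0.3, device=device)
    x_0 = torch.randn(4, 3, 8, 8)
    cemb = torch.zeros(4, 16, device=device)
    print('loss:', diffusion.trainloss(x_0, cemb=cemb).item())
    print('samples:', diffusion.sample((4, 3, 8, 8), cemb=cemb).shape)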