
Commit

update
zhanghui1 committed Mar 19, 2023
1 parent 2bfb164 commit 2f28344
Showing 4 changed files with 92 additions and 67 deletions.
147 changes: 84 additions & 63 deletions diffusion.py
@@ -5,15 +5,16 @@

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
def compare_diff(a,b,decimal):
def compare_diff_betw_a_b(a,b,decimal):
    import numpy as np
    return np.testing.assert_almost_equal(to_numpy(a),to_numpy(b), decimal=decimal)
class GaussianDiffusion(nn.Module):
    def __init__(self, dtype, model, classifier,betas, w, v, device):
    def __init__(self, dtype, model, classifier,cemblayer,betas, w, v, device):
        super().__init__()
        self.dtype = dtype
        self.model = model.to(device)
        self.classifier=classifier
        self.cemblayer=cemblayer
        self.model.dtype = self.dtype
        self.betas = torch.tensor(betas,dtype=self.dtype).to(device)
        self.w = w
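For context, compare_diff_betw_a_b simply wraps np.testing.assert_almost_equal via to_numpy, so it returns None when the two tensors agree to the requested number of decimals and raises AssertionError otherwise. A toy illustration with made-up values (not from the commit):

import numpy as np
import torch

a = torch.tensor([1.00000, 2.00000])
b = torch.tensor([1.00004, 2.00003])
print(np.testing.assert_almost_equal(a.numpy(), b.numpy(), decimal=4))  # passes and prints None
# decimal=6 would raise AssertionError, since the differences exceed 1.5e-6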
@@ -184,59 +185,55 @@ def p_sample(self, x_t, t,return_all=True, **model_kwargs):


    ######## ######## ######## ######## Added classifier handling methods ######## ######## ######## ######## ########
    def p_mean_variance_with_classifier(self, x_t, t, **model_kwargs):
        """
        calculate the parameters of p_{theta}(x_{t-1}|x_t)
        """
        if model_kwargs == None:
            model_kwargs = {}
        B, C = x_t.shape[:2]
        assert t.shape == (B,)

        # cemb_shape = model_kwargs['cemb'].shape
        # pred_eps_cond = self.model(x_t, t, **model_kwargs)
        # model_kwargs['cemb'] = torch.zeros(cemb_shape, device = self.device)
        # pred_eps_uncond = self.model(x_t, t, **model_kwargs)
        # pred_eps = (1 + self.w) * pred_eps_cond - self.w * pred_eps_uncond
        pred_eps=self.model(x_t,t,**model_kwargs)
    def calc_diff(self,x_t,t,use_classifier=True,use_softmax=True):
        #assert x_t.size(0)==1
        conditions=torch.arange(0,10,1).to('cuda')  # the 10 class labels; device is hard-coded to 'cuda' here
        if use_classifier:
            logits=self.classifier(x_t,t)
            if use_softmax:
                #scores = F.log_softmax(logits, dim=-1)
                scores = F.softmax(logits, dim=-1)[0]  # probabilities for the first batch element (assumes batch size 1)
            else:
                scores=logits
        else:
            scores=torch.ones_like(conditions)*0.1  # uniform weight over the 10 classes

        assert torch.isnan(x_t).int().sum() == 0, f"nan in tensor x_t when t = {t[0]}"
        assert torch.isnan(t).int().sum() == 0, f"nan in tensor t when t = {t[0]}"
        assert torch.isnan(pred_eps).int().sum() == 0, f"nan in tensor pred_eps when t = {t[0]}"
        p_mean = self._predict_xt_prev_mean_from_eps(x_t, t.type(dtype=torch.long), pred_eps)
        p_var = self._extract(self.vars, t.type(dtype=torch.long), x_t.shape)
        return p_mean, p_var

    def p_mean_variance_for_compare(self, x_t, t,**model_kwargs):
        pred_eps=scores[0]*self.model(x_t,t,self.cemblayer(conditions[0].repeat(x_t.size(0))))
        for i,label in enumerate(conditions[1:]):
            c=self.cemblayer(label.repeat(x_t.size(0)))
            pred_eps+=scores[i+1]*self.model(x_t,t,c)  # scores[i+1] pairs each label in conditions[1:] with its own score
        cemb=torch.zeros(size=(x_t.size(0),10), device = self.device)  # all-zero embedding = unconditional prediction
        pred_eps_unc= self.model(x_t, t,cemb)
        return pred_eps,pred_eps_unc
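As a side note, the pair calc_diff returns can be read as a classifier-score-weighted average of conditional epsilon predictions versus the unconditional prediction. A self-contained toy sketch of that quantity (toy_eps and the random scores are made-up stand-ins for self.model, the classifier, and the label embeddings):

import torch

def toy_eps(x, t, cond):                                  # pretend epsilon-predictor
    return x * 0.1 + cond.mean()

x_t = torch.randn(1, 3, 32, 32)
t = torch.zeros(1)
scores = torch.softmax(torch.randn(10), dim=-1)           # classifier probabilities p(c | x_t)
cembs = [torch.full((1, 10), float(c)) for c in range(10)]  # toy per-class embeddings

eps_cond = sum(scores[c] * toy_eps(x_t, t, cembs[c]) for c in range(10))  # score-weighted sum, as in calc_diff
eps_uncond = toy_eps(x_t, t, torch.zeros(1, 10))           # all-zero embedding = unconditional
gap = (eps_cond - eps_uncond).abs().mean()                 # the per-timestep quantity logged later
print(gap)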
    def p_mean_variance_for_compare(self, x_t, t,compare=False,**model_kwargs):
        """
        calculate the parameters of p_{theta}(x_{t-1}|x_t)
        Modified so the input is y=None; the model's epsilon output is res = e(x_t, t, None)
        """
        if model_kwargs == None:
            model_kwargs = {}
        B, C = x_t.shape[:2]
        assert t.shape == (B,)
        #cemb_shape = model_kwargs['cemb'].shape
        #pred_eps_cond = self.model(x_t, t, **model_kwargs)
        # cemb_shape=torch.zeros(size=(x_t.size(0),10))
        model_kwargs['cemb'] = torch.zeros(size=(x_t.size(0),10), device = self.device)
        pred_eps_uncond = self.model(x_t, t, **model_kwargs)
        pred_eps=pred_eps_uncond
        #pred_eps = (1 + self.w) * pred_eps_cond - self.w * pred_eps_uncond
        if not compare:
            cemb_shape = model_kwargs['cemb'].shape
            pred_eps_cond = self.model(x_t, t, **model_kwargs)

        assert torch.isnan(x_t).int().sum() == 0, f"nan in tensor x_t when t = {t[0]}"
        assert torch.isnan(t).int().sum() == 0, f"nan in tensor t when t = {t[0]}"
        assert torch.isnan(pred_eps).int().sum() == 0, f"nan in tensor pred_eps when t = {t[0]}"
        p_mean = self._predict_xt_prev_mean_from_eps(x_t, t.type(dtype=torch.long), pred_eps)
        p_var = self._extract(self.vars, t.type(dtype=torch.long), x_t.shape)
            model_kwargs['cemb'] = torch.zeros(cemb_shape, device = self.device)
            pred_eps_uncond = self.model(x_t, t, **model_kwargs)
            pred_eps = (1 + self.w) * pred_eps_cond - self.w * pred_eps_uncond

            assert torch.isnan(x_t).int().sum() == 0, f"nan in tensor x_t when t = {t[0]}"
            assert torch.isnan(t).int().sum() == 0, f"nan in tensor t when t = {t[0]}"
            assert torch.isnan(pred_eps).int().sum() == 0, f"nan in tensor pred_eps when t = {t[0]}"
            p_mean = self._predict_xt_prev_mean_from_eps(x_t, t.type(dtype=torch.long), pred_eps)
            p_var = self._extract(self.vars, t.type(dtype=torch.long), x_t.shape)
        # else:

        return p_mean, p_var
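The guidance combination above is the usual classifier-free guidance formula, eps = (1 + w) * eps_cond - w * eps_uncond, where w is the guidance weight. A quick numeric sanity check with made-up values:

import torch

w = 1.8
eps_cond = torch.tensor([0.5])
eps_uncond = torch.tensor([0.2])
eps_guided = (1 + w) * eps_cond - w * eps_uncond   # = 0.5 + w * (0.5 - 0.2) = 1.04
print(eps_guided)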

    def p_sample_for_compare(self,x_t, t,return_all=True,**model_kwargs):
        """
        sample x_{t-1} from p_{theta}(x_{t-1}|x_t)
        """

        B, C = x_t.shape[:2]
        assert t.shape == (B,), f"size of t is not batch size {B}"
        mean, var = self.p_mean_variance_for_compare(x_t , t)
@@ -266,44 +263,68 @@ def condition_mean(self, p_mean,p_var, x, t,sum_type='prob',classifier_scale=1.0
        gradient = self.cond_fn(self.classifier,x, t, classifier_scale=classifier_scale,y=y_list[0])
        for y_i in y_list[1:]:
            gradient += self.cond_fn(self.classifier,x, t, classifier_scale=classifier_scale,y=y_i)

        new_mean = (
            p_mean.float() + p_var * gradient.float()
        )
        return new_mean
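condition_mean applies the standard classifier-guidance shift, new_mean = p_mean + p_var * gradient, where gradient is the (scaled) classifier log-probability gradient. A toy numeric check with made-up values:

import torch

p_mean = torch.tensor([0.10])
p_var = torch.tensor([0.05])
gradient = torch.tensor([2.0])                          # stand-in for classifier_scale * d/dx log p(y|x)
new_mean = p_mean.float() + p_var * gradient.float()    # 0.10 + 0.05 * 2.0 = 0.20
print(new_mean)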
    def compare_cond_uncond_diff(self,shape,compare_t,**model_kwargs):  # main sampling entry point
    # def compare_cond_uncond_diff(self,shape,compare_t,**model_kwargs):  # main sampling entry point
    #     print('Start generating...')
    #     if model_kwargs == None:
    #         model_kwargs = {}
    #     x_t = torch.randn(shape, device = self.device)
    #     tlist = torch.ones([x_t.shape[0]], device = self.device) * self.T
    #     for _ in tqdm(range(self.T),dynamic_ncols=True):
    #         tlist -= 1
    #         if not isinstance(compare_t,list):
    #             compare_t=[compare_t]
    #         if tlist[0] in compare_t:
    #             # compute the new mean and accumulate onto it
    #             with torch.no_grad():
    #                 #x_t_with_cond=self.p_sample_for_compare(x_t, tlist,**model_kwargs)
    #                 noise = torch.randn_like(x_t)
    #                 noise[tlist[0] <= 0] = 0
    #                 mean_cond,var_mean=self.p_sample_for_compare(x_t, tlist,return_all=False,**model_kwargs)
    #                 mean_uc,var_uc=self.p_sample(x_t, tlist,return_all=False,**model_kwargs)

    #                 x_t_with_cond=mean_cond+torch.sqrt(var_mean) * noise
    #                 x_t_no_cond=mean_uc+torch.sqrt(var_uc) * noise
    #                 try:
    #                     res=compare_diff(x_t_no_cond,x_t_with_cond,decimal=4)
    #                     print(res)
    #                 except Exception as e:
    #                     print(e)
    #                 x_t=x_t_no_cond
    #         else:
    #             with torch.no_grad():
    #                 x_t = self.p_sample(x_t, tlist,**model_kwargs)  # handling without the classifier

    #     x_t = torch.clamp(x_t, -1, 1)
    #     print('ending sampling process...')
    #     return x_t
    def compare_cond_uncond_diff(self,shape,compare_t,clear_compare_results=False,**model_kwargs):  # main sampling entry point
        print('Start generating...')
        logger_list=[]
        if model_kwargs == None:
            model_kwargs = {}
        x_t = torch.randn(shape, device = self.device)
        tlist = torch.ones([x_t.shape[0]], device = self.device) * self.T
        for _ in tqdm(range(self.T),dynamic_ncols=True):
            tlist -= 1

            if not isinstance(compare_t,list):
                compare_t=[compare_t]
            if tlist[0] in compare_t:
                # compute the new mean and accumulate onto it
                with torch.no_grad():
                    #x_t_with_cond=self.p_sample_for_compare(x_t, tlist,**model_kwargs)
                    noise = torch.randn_like(x_t)
                    noise[tlist[0] <= 0] = 0
                    mean_cond,var_mean=self.p_sample_for_compare(x_t, tlist,return_all=False,**model_kwargs)
                    mean_uc,var_uc=self.p_sample(x_t, tlist,return_all=False,**model_kwargs)

                    x_t_with_cond=mean_cond+torch.sqrt(var_mean) * noise
                    x_t_no_cond=mean_uc+torch.sqrt(var_uc) * noise
                    sum_eps_condition,eps_uncond=self.calc_diff(x_t,tlist)
                    logger_list.append(abs(sum_eps_condition.cpu().numpy()-eps_uncond.cpu().numpy()).mean())
                    if clear_compare_results:
                        try:
                            res=compare_diff(x_t_no_cond,x_t_with_cond,decimal=4)
                            print(res)
                        except Exception as e:
                            res=compare_diff_betw_a_b(sum_eps_condition,eps_uncond,decimal=3)
                            print('diff in time step :{} :\n {}'.format(tlist,res))
                        except Exception as e:
                            print(e)
                    x_t=x_t_no_cond

            else:
                with torch.no_grad():
                    x_t = self.p_sample(x_t, tlist,**model_kwargs)  # handling without the classifier

            # this just calls p_sample one extra time here...
            x_t = self.p_sample(x_t, tlist,**model_kwargs)  # handling without the classifier

        x_t = torch.clamp(x_t, -1, 1)
        print('ending sampling process...')
        return x_t
        return x_t,logger_list
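For reference, a minimal call pattern for the new comparison entry point might look like the sketch below; diffusion, cemblayer and params are assumed to be set up exactly as in verification_sample.py, and the label 3 is just an example.

genshape = (params.batchsize, 3, 32, 32)
lab = torch.tensor([3]).to(params.device)           # example class label
cemb = cemblayer(lab)                               # label embedding passed through model_kwargs
samples, gaps = diffusion.compare_cond_uncond_diff(
    genshape,
    compare_t=list(range(1000)),                    # log the epsilon gap at every timestep
    clear_compare_results=False,                    # True additionally runs the assert_almost_equal check
    cemb=cemb,
)
# gaps[i] is the mean |score-weighted conditional eps - unconditional eps| at the i-th logged step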
Binary file modified sample_96_pict_1.8.png
12 changes: 8 additions & 4 deletions verification_sample.py
@@ -60,7 +60,8 @@ def sample(params):
        w = params.w,
        v = params.v,
        device = params.device,
        classifier=classifier)
        classifier=classifier,
        cemblayer=cemblayer)
    # eval mode
    diffusion.model.eval()
    cemblayer.eval()
@@ -76,12 +77,15 @@ def sample(params):
    #lab=torch.tensor([1]).to(params.device)
    # get label embeddings
    cemb = cemblayer(lab)

    genshape = (params.batchsize, 3, 32, 32)
    # generated = diffusion.sample(genshape, cemb = cemb)
    generated=diffusion.compare_cond_uncond_diff(genshape,compare_t=999, cemb = cemb)
    print(generated)

    generated,logger_list=diffusion.compare_cond_uncond_diff(genshape,compare_t=list(range(1000)), cemb = cemb)

    import matplotlib.pyplot as plt
    f=plt.figure()
    plt.plot(list(range(1000)), logger_list)
    f.savefig('z.jpg')
    #transform samples into images
    img = transback(generated)
    # save images
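If axis labels were wanted on the diagnostic plot saved as z.jpg, the plotting block above could be extended along these lines (a suggestion only, not part of the commit; logger_list is the list returned by compare_cond_uncond_diff):

import matplotlib.pyplot as plt

f = plt.figure()
plt.plot(list(range(1000)), logger_list)            # one value per sampling step
plt.xlabel('sampling step index')                   # steps run from t = T-1 down to 0
plt.ylabel('mean |weighted conditional eps - unconditional eps|')
f.savefig('z.jpg')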
Binary file added z.jpg
