diff --git a/diffusion.py b/diffusion.py
index 3830a92..032a441 100644
--- a/diffusion.py
+++ b/diffusion.py
@@ -5,15 +5,16 @@
 def to_numpy(tensor):
     return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
 
-def compare_diff(a,b,decimal):
+def compare_diff_betw_a_b(a, b, decimal):
     import numpy as np
     return np.testing.assert_almost_equal(to_numpy(a), to_numpy(b), decimal=decimal)
 class GaussianDiffusion(nn.Module):
-    def __init__(self, dtype, model, classifier,betas, w, v, device):
+    def __init__(self, dtype, model, classifier, cemblayer, betas, w, v, device):
         super().__init__()
         self.dtype = dtype
         self.model = model.to(device)
         self.classifier=classifier
+        self.cemblayer=cemblayer
         self.model.dtype = self.dtype
         self.betas = torch.tensor(betas,dtype=self.dtype).to(device)
         self.w = w
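A note on the renamed helper above: np.testing.assert_almost_equal returns None when the arrays agree and raises AssertionError otherwise, which is why the sampling loop further down wraps the call in try/except. A minimal, self-contained sketch of that behavior (the tensors here are made up for illustration):

import numpy as np
import torch

a = torch.zeros(4)
b = torch.full((4,), 0.01)
try:
    # returns None on success; raises AssertionError with a mismatch
    # summary once the difference exceeds the decimal tolerance
    np.testing.assert_almost_equal(a.cpu().numpy(), b.cpu().numpy(), decimal=3)
    print('tensors agree to 3 decimals')
except AssertionError as e:
    print(e)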
@@ -184,59 +185,55 @@ def p_sample(self, x_t, t,return_all=True, **model_kwargs):
     ######## ######## ######## ######## classifier handling methods ######## ######## ######## ######## ########
-    def p_mean_variance_with_classifier(self, x_t, t, **model_kwargs):
-        """
-        calculate the parameters of p_{theta}(x_{t-1}|x_t)
-        """
-        if model_kwargs == None:
-            model_kwargs = {}
-        B, C = x_t.shape[:2]
-        assert t.shape == (B,)
-
-        # cemb_shape = model_kwargs['cemb'].shape
-        # pred_eps_cond = self.model(x_t, t, **model_kwargs)
-        # model_kwargs['cemb'] = torch.zeros(cemb_shape, device = self.device)
-        # pred_eps_uncond = self.model(x_t, t, **model_kwargs)
-        # pred_eps = (1 + self.w) * pred_eps_cond - self.w * pred_eps_uncond
-        pred_eps=self.model(x_t,t,**model_kwargs)
+    def calc_diff(self, x_t, t, use_classifier=True, use_softmax=True):
+        # assert x_t.size(0)==1
+        conditions = torch.arange(0, 10, 1).to(self.device)  # the ten class labels
+        if use_classifier:
+            logits = self.classifier(x_t, t)
+            if use_softmax:
+                # scores = F.log_softmax(logits, dim=-1)
+                scores = F.softmax(logits, dim=-1)[0]  # class probabilities of sample 0 only
+            else:
+                scores = logits
+        else:
+            scores = torch.ones_like(conditions) * 0.1  # uniform weights over the ten classes
-        assert torch.isnan(x_t).int().sum() == 0, f"nan in tensor x_t when t = {t[0]}"
-        assert torch.isnan(t).int().sum() == 0, f"nan in tensor t when t = {t[0]}"
-        assert torch.isnan(pred_eps).int().sum() == 0, f"nan in tensor pred_eps when t = {t[0]}"
-        p_mean = self._predict_xt_prev_mean_from_eps(x_t, t.type(dtype=torch.long), pred_eps)
-        p_var = self._extract(self.vars, t.type(dtype=torch.long), x_t.shape)
-        return p_mean, p_var
-
-    def p_mean_variance_for_compare(self, x_t, t,**model_kwargs):
+        # classifier-weighted sum of conditional eps predictions over all labels
+        pred_eps = scores[0] * self.model(x_t, t, self.cemblayer(conditions[0].repeat(x_t.size(0))))
+        for i, label in enumerate(conditions[1:]):
+            c = self.cemblayer(label.repeat(x_t.size(0)))
+            pred_eps += scores[i + 1] * self.model(x_t, t, c)  # scores[0] was consumed above
+        cemb = torch.zeros(size=(x_t.size(0), 10), device=self.device)
+        pred_eps_unc = self.model(x_t, t, cemb)  # unconditional prediction, zero embedding
+        return pred_eps, pred_eps_unc
+    def p_mean_variance_for_compare(self, x_t, t, compare=False, **model_kwargs):
         """
         calculate the parameters of p_{theta}(x_{t-1}|x_t)
-
-
-        changed to pass y=None, so the model outputs epsilon as res=e(x_t,t,None)
         """
         if model_kwargs == None:
             model_kwargs = {}
         B, C = x_t.shape[:2]
         assert t.shape == (B,)
-        #cemb_shape = model_kwargs['cemb'].shape
-        #pred_eps_cond = self.model(x_t, t, **model_kwargs)
-        # cemb_shape=torch.zeros(size=(x_t.size(0),10))
-        model_kwargs['cemb'] = torch.zeros(size=(x_t.size(0),10), device = self.device)
-        pred_eps_uncond = self.model(x_t, t, **model_kwargs)
-        pred_eps=pred_eps_uncond
-        #pred_eps = (1 + self.w) * pred_eps_cond - self.w * pred_eps_uncond
+        if not compare:
+            cemb_shape = model_kwargs['cemb'].shape
+            pred_eps_cond = self.model(x_t, t, **model_kwargs)
-        assert torch.isnan(x_t).int().sum() == 0, f"nan in tensor x_t when t = {t[0]}"
-        assert torch.isnan(t).int().sum() == 0, f"nan in tensor t when t = {t[0]}"
-        assert torch.isnan(pred_eps).int().sum() == 0, f"nan in tensor pred_eps when t = {t[0]}"
-        p_mean = self._predict_xt_prev_mean_from_eps(x_t, t.type(dtype=torch.long), pred_eps)
-        p_var = self._extract(self.vars, t.type(dtype=torch.long), x_t.shape)
+            model_kwargs['cemb'] = torch.zeros(cemb_shape, device = self.device)
+            pred_eps_uncond = self.model(x_t, t, **model_kwargs)
+            pred_eps = (1 + self.w) * pred_eps_cond - self.w * pred_eps_uncond
+
+            assert torch.isnan(x_t).int().sum() == 0, f"nan in tensor x_t when t = {t[0]}"
+            assert torch.isnan(t).int().sum() == 0, f"nan in tensor t when t = {t[0]}"
+            assert torch.isnan(pred_eps).int().sum() == 0, f"nan in tensor pred_eps when t = {t[0]}"
+            p_mean = self._predict_xt_prev_mean_from_eps(x_t, t.type(dtype=torch.long), pred_eps)
+            p_var = self._extract(self.vars, t.type(dtype=torch.long), x_t.shape)
+        # else:
+
         return p_mean, p_var
+
     def p_sample_for_compare(self,x_t, t,return_all=True,**model_kwargs):
         """
         sample x_{t-1} from p_{theta}(x_{t-1}|x_t)
         """
-
         B, C = x_t.shape[:2]
         assert t.shape == (B,), f"size of t is not batch size {B}"
         mean, var = self.p_mean_variance_for_compare(x_t , t)
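Restated outside the class, calc_diff returns a classifier-weighted mixture of conditional noise predictions, sum over y of p(y|x_t) * eps(x_t, t, emb(y)), next to the unconditional prediction eps(x_t, t, 0). A self-contained sketch with stub modules standing in for self.model, self.classifier and self.cemblayer (the stubs and shapes are assumptions for illustration; the CIFAR-style 3x32x32 inputs and 10 classes match the rest of the diff):

import torch
import torch.nn.functional as F

B, C, H, W, n_classes = 2, 3, 32, 32, 10

def eps_model(x_t, t, cemb):
    # stub for self.model: any (x_t, t, cemb) -> tensor shaped like x_t
    return x_t * 0.1 + cemb.mean()

def classifier(x_t, t):
    # stub for self.classifier: per-sample class logits
    return torch.randn(x_t.size(0), n_classes)

def cemblayer(labels):
    # stub for self.cemblayer: label -> embedding (here a one-hot of width 10)
    return F.one_hot(labels, n_classes).float()

x_t = torch.randn(B, C, H, W)
t = torch.full((B,), 500)

# classifier-weighted mixture: sum_y p(y | x_t[0]) * eps(x_t, t, emb(y))
scores = F.softmax(classifier(x_t, t), dim=-1)[0]
pred_eps = sum(scores[y] * eps_model(x_t, t, cemblayer(torch.tensor(y).repeat(B)))
               for y in range(n_classes))

# unconditional prediction with an all-zero condition embedding
pred_eps_unc = eps_model(x_t, t, torch.zeros(B, n_classes))

# the per-step scalar that compare_cond_uncond_diff logs
print((pred_eps - pred_eps_unc).abs().mean().item())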
@@ -266,44 +263,68 @@ def condition_mean(self, p_mean,p_var, x, t,sum_type='prob',classifier_scale=1.0
         gradient = self.cond_fn(self.classifier,x, t, classifier_scale=classifier_scale,y=y_list[0])
         for y_i in y_list[1:]:
             gradient += self.cond_fn(self.classifier,x, t, classifier_scale=classifier_scale,y=y_i)
-
         new_mean = ( p_mean.float() + p_var * gradient.float() )
         return new_mean
-    def compare_cond_uncond_diff(self,shape,compare_t,**model_kwargs):  # main sampling entry point
+    # def compare_cond_uncond_diff(self,shape,compare_t,**model_kwargs):  # main sampling entry point
+    #     print('Start generating...')
+    #     if model_kwargs == None:
+    #         model_kwargs = {}
+    #     x_t = torch.randn(shape, device = self.device)
+    #     tlist = torch.ones([x_t.shape[0]], device = self.device) * self.T
+    #     for _ in tqdm(range(self.T),dynamic_ncols=True):
+    #         tlist -= 1
+    #         if not isinstance(compare_t,list):
+    #             compare_t=[compare_t]
+    #         if tlist[0] in compare_t:
+    #             # compute the new mean and accumulate onto it
+    #             with torch.no_grad():
+    #                 # x_t_with_cond=self.p_sample_for_compare(x_t, tlist,**model_kwargs)
+    #                 noise = torch.randn_like(x_t)
+    #                 noise[tlist[0] <= 0] = 0
+    #                 mean_cond,var_mean=self.p_sample_for_compare(x_t, tlist,return_all=False,**model_kwargs)
+    #                 mean_uc,var_uc=self.p_sample(x_t, tlist,return_all=False,**model_kwargs)
+
+    #                 x_t_with_cond=mean_cond+torch.sqrt(var_mean) * noise
+    #                 x_t_no_cond=mean_uc+torch.sqrt(var_uc) * noise
+    #             try:
+    #                 res=compare_diff(x_t_no_cond,x_t_with_cond,decimal=4)
+    #                 print(res)
+    #             except Exception as e:
+    #                 print(e)
+    #             x_t=x_t_no_cond
+    #         else:
+    #             with torch.no_grad():
+    #                 x_t = self.p_sample(x_t, tlist,**model_kwargs)  # step without the classifier
+
+    #     x_t = torch.clamp(x_t, -1, 1)
+    #     print('ending sampling process...')
+    #     return x_t
+    def compare_cond_uncond_diff(self,shape,compare_t,clear_compare_results=False,**model_kwargs):  # main sampling entry point
         print('Start generating...')
+        logger_list=[]
         if model_kwargs == None:
             model_kwargs = {}
         x_t = torch.randn(shape, device = self.device)
         tlist = torch.ones([x_t.shape[0]], device = self.device) * self.T
         for _ in tqdm(range(self.T),dynamic_ncols=True):
             tlist -= 1
-
             if not isinstance(compare_t,list):
                 compare_t=[compare_t]
             if tlist[0] in compare_t:
-                # compute the new mean and accumulate onto it
-                with torch.no_grad():
-                    #x_t_with_cond=self.p_sample_for_compare(x_t, tlist,**model_kwargs)
-                    noise = torch.randn_like(x_t)
-                    noise[tlist[0] <= 0] = 0
-                    mean_cond,var_mean=self.p_sample_for_compare(x_t, tlist,return_all=False,**model_kwargs)
-                    mean_uc,var_uc=self.p_sample(x_t, tlist,return_all=False,**model_kwargs)
-
-                    x_t_with_cond=mean_cond+torch.sqrt(var_mean) * noise
-                    x_t_no_cond=mean_uc+torch.sqrt(var_uc) * noise
+                sum_eps_condition, eps_uncond = self.calc_diff(x_t, tlist)  # assumes the caller wraps sampling in torch.no_grad()
+                logger_list.append(abs(to_numpy(sum_eps_condition) - to_numpy(eps_uncond)).mean())  # to_numpy detaches if needed
+                if clear_compare_results:
                     try:
-                        res=compare_diff(x_t_no_cond,x_t_with_cond,decimal=4)
-                        print(res)
-                    except Exception as e:
+                        res = compare_diff_betw_a_b(sum_eps_condition, eps_uncond, decimal=3)
+                        print('diff at time step {}:\n{}'.format(tlist[0], res))
+                    except Exception as e:
                         print(e)
-                x_t=x_t_no_cond
-
-            else:
-                with torch.no_grad():
-                    x_t = self.p_sample(x_t, tlist,**model_kwargs)  # step without the classifier
+
+            # note: this now calls the model one extra time per compare step
+            x_t = self.p_sample(x_t, tlist,**model_kwargs)  # plain step without the classifier
 
         x_t = torch.clamp(x_t, -1, 1)
         print('ending sampling process...')
-        return x_t
\ No newline at end of file
+        return x_t, logger_list
\ No newline at end of file
diff --git a/sample_96_pict_1.8.png b/sample_96_pict_1.8.png
index 59ba8b0..7b7d79a 100644
Binary files a/sample_96_pict_1.8.png and b/sample_96_pict_1.8.png differ
diff --git a/verification_sample.py b/verification_sample.py
index bf66047..dcea3ad 100644
--- a/verification_sample.py
+++ b/verification_sample.py
@@ -60,7 +60,8 @@ def sample(params):
                     w = params.w,
                     v = params.v,
                     device = params.device,
-                    classifier=classifier)
+                    classifier=classifier,
+                    cemblayer=cemblayer)
     # eval mode
     diffusion.model.eval()
     cemblayer.eval()
@@ -76,12 +77,15 @@ def sample(params):
         #lab=torch.tensor([1]).to(params.device)
         # get label embeddings
         cemb = cemblayer(lab)
+
         genshape = (params.batchsize, 3, 32, 32)
         # generated = diffusion.sample(genshape, cemb = cemb)
-        generated=diffusion.compare_cond_uncond_diff(genshape,compare_t=999, cemb = cemb)
-        print(generated)
-
+        generated, logger_list = diffusion.compare_cond_uncond_diff(genshape, compare_t=list(range(1000)), cemb=cemb)
+        import matplotlib.pyplot as plt
+        f = plt.figure()
+        plt.plot(list(range(1000)), logger_list)
+        f.savefig('z.jpg')
         #transform samples into images
         img = transback(generated)
         # save images
diff --git a/z.jpg b/z.jpg
new file mode 100644
index 0000000..988c7bf
Binary files a/z.jpg and b/z.jpg differ
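The plotting block in verification_sample.py writes an unlabeled curve to z.jpg. A slightly fuller sketch of the same idea, with the timestep direction made explicit (the function name, axis labels, and descending-t convention are my additions, not part of the commit):

import matplotlib.pyplot as plt

def plot_eps_gap(logger_list, T=1000, out_path='z.jpg'):
    # logger_list[k] was recorded at timestep T-1-k, since tlist counts down
    fig = plt.figure()
    plt.plot(range(T - 1, -1, -1), logger_list)
    plt.gca().invert_xaxis()  # left-to-right now follows sampling order (t: 999 -> 0)
    plt.xlabel('diffusion timestep t')
    plt.ylabel('mean |eps_cond_mix - eps_uncond|')
    fig.savefig(out_path)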