From a2b5e181b0f765bec4b902aa9bee4c2bfa5d29f1 Mon Sep 17 00:00:00 2001
From: Chang Sun
Date: Mon, 4 Dec 2023 12:50:20 +0100
Subject: [PATCH] data sampler: guard empty discrete-column stats; dp_cgan:
 per-epoch loss tracking, tqdm progress, W&B checkpoints

---
 src/dp_cgans/data_sampler.py         |  8 ++-
 src/dp_cgans/synthesizers/dp_cgan.py | 99 +++++++++++++++++++---------
 2 files changed, 73 insertions(+), 34 deletions(-)

diff --git a/src/dp_cgans/data_sampler.py b/src/dp_cgans/data_sampler.py
index f622ca9..84c8a0c 100644
--- a/src/dp_cgans/data_sampler.py
+++ b/src/dp_cgans/data_sampler.py
@@ -98,7 +98,8 @@ def is_discrete_column(column_info):
                 self._categories_each_column.append(column_info[0].dim)
                 self.get_position.append(position_cnt)
                 position_cnt += 1
-        max_category = max(self._categories_each_column)
+
+        max_category = max(self._categories_each_column, default=0)
 
         self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32')
         self._discrete_column_n_category = np.zeros(
@@ -116,7 +117,10 @@ def is_discrete_column(column_info):
         self._n_categories = sum(self._categories_each_column)
         self._categories_each_column = np.array(self._categories_each_column)
-        second_max_category = np.partition(self._categories_each_column.flatten(), -2)[-2]
+        if len(self._categories_each_column) > 1:  # np.partition(..., -2) needs at least two entries
+            second_max_category = np.partition(self._categories_each_column.flatten(), -2)[-2]
+        else:
+            second_max_category = 0  # with zero or one discrete column there are no column pairs
 
         self._discrete_pair_cond_st = np.zeros((int(((n_discrete_columns)*(n_discrete_columns-1))/2),int((max_category+1) * (second_max_category+1))),dtype='int32')
         self._discrete_pair_n_category = np.zeros(int(((n_discrete_columns)*(n_discrete_columns-1))/2), dtype='int32')
diff --git a/src/dp_cgans/synthesizers/dp_cgan.py b/src/dp_cgans/synthesizers/dp_cgan.py
index dd97f8f..51f3b8c 100644
--- a/src/dp_cgans/synthesizers/dp_cgan.py
+++ b/src/dp_cgans/synthesizers/dp_cgan.py
@@ -11,13 +11,12 @@
 from packaging import version
 from torch import optim
 from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional, BCEWithLogitsLoss, utils
+from tqdm import tqdm
 
 from dp_cgans.data_sampler import DataSampler
 from dp_cgans.data_transformer import DataTransformer
 from dp_cgans.synthesizers.base import BaseSynthesizer
-import scipy.stats
-
 ######## ADDED ########
 from datetime import datetime
 from contextlib import redirect_stdout
@@ -60,9 +59,8 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb
             create_graph=True, retain_graph=True, only_inputs=True
         )[0]
-        gradient_penalty = ((
-            gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
-        ) ** 2).mean() * lambda_
+        gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
+        gradient_penalty = ((gradients_view) ** 2).mean() * lambda_
 
         return gradient_penalty
 
@@ -198,6 +196,8 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di
         self._data_sampler = None
         self._generator = None
         self._discriminator = None
+
+        self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Discriminator Loss'])
 
     @staticmethod
@@ -245,7 +245,7 @@ def _apply_activate(self, data):
                     data_t.append(transformed)
                     st = ed
                 else:
-                    assert 0
+                    raise ValueError(f'Unexpected activation function {span_info.activation_fn}.')
 
         return torch.cat(data_t, dim=1)
@@ -463,12 +463,20 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
         mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
         std = mean + 1
+        self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Discriminator Loss'])
+
+        epoch_iterator = tqdm(range(epochs), disable=(not self._verbose))
+        if self._verbose:
+            description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})'
+            epoch_iterator.set_description(description.format(gen=0, dis=0))
+
+        steps_per_epoch = max(len(train_data) // self._batch_size, 1)
         ######## ADDED ########
         with open('loss_output_%s.txt'%str(epochs), 'w') as f:
             with redirect_stdout(f):
                 ######## ADDED ########
-                for i in range(epochs):
+                for i in epoch_iterator:
                     for id_ in range(steps_per_epoch):
                         for n in range(self._discriminator_steps):
@@ -611,22 +619,46 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
                         loss_g = -torch.mean(y_fake) + cross_entropy_pair # + rules_penalty
 
-                        optimizerG.zero_grad()
+                        optimizerG.zero_grad(set_to_none=False)
                         loss_g.backward()
                         optimizerG.step()
+
+                    generator_loss = loss_g.detach().cpu()
+                    discriminator_loss = loss_d.detach().cpu()
+
+                    epoch_loss_df = pd.DataFrame({
+                        'Epoch': [i],
+                        'Generator Loss': [generator_loss],
+                        'Discriminator Loss': [discriminator_loss]
+                    })
+                    if not self.loss_values.empty:
+                        self.loss_values = pd.concat(
+                            [self.loss_values, epoch_loss_df]
+                        ).reset_index(drop=True)
+                    else:
+                        self.loss_values = epoch_loss_df
+
                     if self._verbose:
                         ######## ADDED ########
-                        now = datetime.now()
-                        current_time = now.strftime("%H:%M:%S")
+                        # now = datetime.now()
+                        # current_time = now.strftime("%H:%M:%S")
                         # Calculate the current privacy cost using the accountant
                         # https://github.com/BorealisAI/private-data-generation/blob/master/models/dp_wgan.py
                         # https://github.com/tensorflow/privacy/tree/master/tutorials/walkthrough
-                        print(current_time, f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f},"
-                              f"Loss D: {loss_d.detach().cpu(): .4f}", flush=True)
+                        # print(current_time, f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f},"
+                        #       f"Loss D: {loss_d.detach().cpu(): .4f}", flush=True)
+
+                        epoch_iterator.set_description(
+                            description.format(gen=generator_loss, dis=discriminator_loss)
+                        )
+
                     if self.wandb == True : ## Add WB logs
@@ -639,29 +671,32 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
                         }
                         wandb.log(metrics)
-
                         SAVE_DIR = Path('./data/weights/')
                         SAVE_DIR.mkdir(exist_ok=True, parents=True)
-                        # if i%50 == 0:
-                        #     ckpt_file = SAVE_DIR/f"context_model_{i}.pkl"
-                        #     ### torch.save(nn_model.state_dict(), ckpt_file)
-                        #     self.save(ckpt_file)
-
-                        #     artifact_name = f"{wandb.run.id}_context_model"
-                        #     at = wandb.Artifact(artifact_name, type="model")
-                        #     at.add_file(ckpt_file)
-                        #     wandb.log_artifact(at, aliases=[f"epoch_{i}"])
-
-                        #     syn_data = self.sample(len(train_data))[real_data_columns]
-                        #     # real_data.columns = syn_data.columns
-                        #     corr_diff_plot = self.corr_plot(real_data, syn_data)
-
-                        #     wandb.log({
-                        #         "sample_differences_with_realData": wandb.Image(plt)
-                        #         # "train_samples": wandb.Table(dataframe=self.sample(len(train_data)))
-                        #         ### "train_samples": [wandb.Image(img) for img in samples.split(1)]
-                        #     })
+                        if i % 200 == 0:
+                            ckpt_file = SAVE_DIR / f"context_model_{i}.pkl"
+                            ### torch.save(nn_model.state_dict(), ckpt_file)
+                            self.save(ckpt_file)
+
+                            artifact_name = f"{wandb.run.id}_context_model"
+                            at = wandb.Artifact(artifact_name, type="model")
+                            at.add_file(ckpt_file)
+                            wandb.log_artifact(at, aliases=[f"epoch_{i}"])
+
+                            syn_data = self.sample(len(train_data))  # [real_data_columns]
+                            syn_data_columns = syn_data.columns
+                            # real_data.columns = syn_data.columns
+
+                            fig, ax = plt.subplots(figsize=(12, 10))  # was `f, ax`; `f` shadowed the open log-file handle
+                            syn_data[['anchor_age', 'drug_Dasatinib', 'systolic']].plot.kde(ax=ax)
+                            # self.corr_plot(real_data, syn_data)
+
+                            wandb.log({
+                                "sample_differences_with_realData": wandb.Image(fig)
+                                # "train_samples": wandb.Table(dataframe=self.sample(len(train_data)))
+                                ### "train_samples": [wandb.Image(img) for img in samples.split(1)]
+                            })
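
For reference, a minimal standalone sketch of the two empty-input guards introduced
in data_sampler.py above. It assumes only numpy; `category_stats` is a hypothetical
helper that mirrors the guarded statistics, not a function in the codebase.

    import numpy as np

    def category_stats(categories_each_column):
        """Mirror of the guarded max/second-max computation in DataSampler."""
        categories = np.array(categories_each_column)
        # `default=0` keeps max() from raising ValueError on an empty sequence.
        max_category = max(categories_each_column, default=0)
        # np.partition(a, -2) requires at least two elements, hence the > 1 guard.
        if len(categories) > 1:
            second_max_category = int(np.partition(categories.flatten(), -2)[-2])
        else:
            second_max_category = 0
        return max_category, second_max_category

    print(category_stats([4, 2, 7]))  # (7, 4)
    print(category_stats([5]))        # (5, 0)
    print(category_stats([]))         # (0, 0)

With zero or one discrete column, the pair matrices sized from these statistics
have n*(n-1)/2 = 0 rows, so the fallback value of 0 is never actually read.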