update data sampler
Chang Sun committed Dec 4, 2023
1 parent 85a9afe commit a2b5e18
Showing 2 changed files with 73 additions and 34 deletions.
8 changes: 6 additions & 2 deletions src/dp_cgans/data_sampler.py
@@ -98,7 +98,8 @@ def is_discrete_column(column_info):
self._categories_each_column.append(column_info[0].dim)
self.get_position.append(position_cnt)
position_cnt += 1
-max_category = max(self._categories_each_column)

+max_category = max(self._categories_each_column, default=0)

self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32')
self._discrete_column_n_category = np.zeros(
@@ -116,7 +117,10 @@ def is_discrete_column(column_info):
self._n_categories = sum(self._categories_each_column)

self._categories_each_column = np.array(self._categories_each_column)
-second_max_category = np.partition(self._categories_each_column.flatten(), -2)[-2]
+if len(self._categories_each_column) > 0:  # if self._categories_each_column is not empty
+    second_max_category = np.partition(self._categories_each_column.flatten(), -2)[-2]
+else:
+    second_max_category = 0

self._discrete_pair_cond_st = np.zeros((int(((n_discrete_columns)*(n_discrete_columns-1))/2),int((max_category+1) * (second_max_category+1))),dtype='int32')
self._discrete_pair_n_category = np.zeros(int(((n_discrete_columns)*(n_discrete_columns-1))/2), dtype='int32')
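Why the two guards matter: `max()` on an empty sequence raises `ValueError`, and `np.partition(..., kth=-2)` needs at least two elements, so a dataset with no discrete columns would crash the sampler before training starts. Below is a minimal standalone sketch of the same pattern. `top_two_category_counts` is an illustrative helper, not part of the sampler's API, and it uses the slightly stricter guard `size > 1` so a single discrete column is covered as well.

```python
import numpy as np

def top_two_category_counts(categories_each_column):
    """Return (max, second_max) category counts, tolerating short input."""
    # default=0 keeps max() from raising ValueError on an empty sequence.
    max_category = max(categories_each_column, default=0)

    categories = np.asarray(categories_each_column)
    if categories.size > 1:
        # np.partition places the two largest values in the last two slots.
        second_max_category = np.partition(categories.flatten(), -2)[-2]
    else:
        second_max_category = 0
    return max_category, int(second_max_category)

print(top_two_category_counts([3, 5, 2]))  # (5, 3)
print(top_two_category_counts([]))         # (0, 0)
```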
99 changes: 67 additions & 32 deletions src/dp_cgans/synthesizers/dp_cgan.py
@@ -11,13 +11,12 @@
from packaging import version
from torch import optim
from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional, BCEWithLogitsLoss, utils
+from tqdm import tqdm

from dp_cgans.data_sampler import DataSampler
from dp_cgans.data_transformer import DataTransformer
from dp_cgans.synthesizers.base import BaseSynthesizer

-import scipy.stats

######## ADDED ########
from datetime import datetime
from contextlib import redirect_stdout
@@ -60,9 +59,8 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb
create_graph=True, retain_graph=True, only_inputs=True
)[0]

-gradient_penalty = ((
-    gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
-) ** 2).mean() * lambda_
+gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
+gradient_penalty = ((gradients_view) ** 2).mean() * lambda_

return gradient_penalty
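For context, this is the standard WGAN-GP penalty (Gulrajani et al., 2017) evaluated on pac-sized sample groups, in the style of CTGAN; the refactor only names the intermediate term. A self-contained sketch of the whole computation, including the interpolation step the hunk truncates. It assumes `discriminator` maps a `(batch, features)` tensor to one critic score per pac group, and that the batch size is divisible by `pac`:

```python
import torch

def calc_gradient_penalty_sketch(discriminator, real, fake,
                                 device='cpu', pac=10, lambda_=10.0):
    # One interpolation coefficient per pac group, broadcast across the group.
    alpha = torch.rand(real.size(0) // pac, 1, 1, device=device)
    alpha = alpha.repeat(1, pac, real.size(1)).view(-1, real.size(1))

    # Mix real and fake samples; the penalty is taken at these points.
    interpolates = alpha * real.detach() + (1 - alpha) * fake.detach()
    interpolates.requires_grad_(True)
    disc_out = discriminator(interpolates)

    gradients = torch.autograd.grad(
        outputs=disc_out, inputs=interpolates,
        grad_outputs=torch.ones_like(disc_out),
        create_graph=True, retain_graph=True, only_inputs=True,
    )[0]

    # L2 norm per pac group, pushed toward 1; squared deviation is the penalty.
    slopes = gradients.view(-1, pac * real.size(1)).norm(2, dim=1)
    return ((slopes - 1) ** 2).mean() * lambda_
```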

@@ -198,6 +196,8 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di
self._data_sampler = None
self._generator = None
self._discriminator = None

+self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Discriminator Loss'])


@staticmethod
@@ -245,7 +245,7 @@ def _apply_activate(self, data):
data_t.append(transformed)
st = ed
else:
-assert 0
+raise ValueError(f'Unexpected activation function {span_info.activation_fn}.')

return torch.cat(data_t, dim=1)
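Replacing `assert 0` with a `ValueError` is more than cosmetic: assertions vanish under `python -O`, and the new message names the offending activation. Below is a sketch of the dispatch around it, shaped like `_apply_activate`. Assumptions: `output_info_list` is CTGAN-style transformer metadata (a list of span lists with `.dim` and `.activation_fn`), and a plain softmax stands in for the Gumbel-softmax the real code applies to discrete spans.

```python
import torch

def apply_activate_sketch(data, output_info_list):
    data_t, st = [], 0
    for column_info in output_info_list:
        for span_info in column_info:
            ed = st + span_info.dim
            if span_info.activation_fn == 'tanh':
                # Continuous span: squash into [-1, 1].
                data_t.append(torch.tanh(data[:, st:ed]))
            elif span_info.activation_fn == 'softmax':
                # Discrete span: normalize logits into a one-hot-like vector.
                data_t.append(torch.softmax(data[:, st:ed], dim=-1))
            else:
                raise ValueError(
                    f'Unexpected activation function {span_info.activation_fn}.')
            st = ed
    return torch.cat(data_t, dim=1)
```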

@@ -463,12 +463,20 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
std = mean + 1

+self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Discriminator Loss'])

+epoch_iterator = tqdm(range(epochs), disable=(not self._verbose))
+if self._verbose:
+    description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})'
+    epoch_iterator.set_description(description.format(gen=0, dis=0))


steps_per_epoch = max(len(train_data) // self._batch_size, 1)
######## ADDED ########
with open('loss_output_%s.txt'%str(epochs), 'w') as f:
with redirect_stdout(f):
######## ADDED ########
-for i in range(epochs):
+for i in epoch_iterator:
for id_ in range(steps_per_epoch):
for n in range(self._discriminator_steps):
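Swapping `range(epochs)` for `epoch_iterator` in the loop header above is what lets the later `set_description` calls paint live loss values onto the progress bar. The pattern in isolation, with placeholder losses:

```python
from tqdm import tqdm

epochs, verbose = 5, True
epoch_iterator = tqdm(range(epochs), disable=(not verbose))
description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})'
epoch_iterator.set_description(description.format(gen=0, dis=0))

for i in epoch_iterator:
    gen_loss, dis_loss = 1.0 / (i + 1), -0.5  # placeholder losses
    # Refresh the bar prefix each epoch, as fit() now does.
    epoch_iterator.set_description(description.format(gen=gen_loss, dis=dis_loss))
```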

@@ -611,22 +619,46 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
loss_g = -torch.mean(y_fake) + cross_entropy_pair # + rules_penalty


-optimizerG.zero_grad()
+optimizerG.zero_grad(set_to_none=False)
loss_g.backward()
optimizerG.step()


+generator_loss = loss_g.detach().cpu()
+discriminator_loss = loss_d.detach().cpu()


+epoch_loss_df = pd.DataFrame({
+    'Epoch': [i],
+    'Generator Loss': [generator_loss],
+    'Discriminator Loss': [discriminator_loss]
+})
+if not self.loss_values.empty:
+    self.loss_values = pd.concat(
+        [self.loss_values, epoch_loss_df]
+    ).reset_index(drop=True)
+else:
+    self.loss_values = epoch_loss_df



if self._verbose:
######## ADDED ########
-now = datetime.now()
-current_time = now.strftime("%H:%M:%S")
+# now = datetime.now()
+# current_time = now.strftime("%H:%M:%S")

# Calculate the current privacy cost using the accountant
# https://github.com/BorealisAI/private-data-generation/blob/master/models/dp_wgan.py
# https://github.com/tensorflow/privacy/tree/master/tutorials/walkthrough

-print(current_time, f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f},"
-      f"Loss D: {loss_d.detach().cpu(): .4f}", flush=True)
+# print(current_time, f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f},"
+#       f"Loss D: {loss_d.detach().cpu(): .4f}", flush=True)


+epoch_iterator.set_description(
+    description.format(gen=generator_loss, dis=discriminator_loss)
+)


if self.wandb == True :
## Add WB logs
@@ -639,29 +671,32 @@ }
}
wandb.log(metrics)


SAVE_DIR = Path('./data/weights/')
SAVE_DIR.mkdir(exist_ok=True, parents=True)

-# if i%50 == 0:
-#     ckpt_file = SAVE_DIR/f"context_model_{i}.pkl"
-#     ### torch.save(nn_model.state_dict(), ckpt_file)
-#     self.save(ckpt_file)

-#     artifact_name = f"{wandb.run.id}_context_model"
-#     at = wandb.Artifact(artifact_name, type="model")
-#     at.add_file(ckpt_file)
-#     wandb.log_artifact(at, aliases=[f"epoch_{i}"])

-#     syn_data = self.sample(len(train_data))[real_data_columns]
-#     # real_data.columns = syn_data.columns
-#     corr_diff_plot = self.corr_plot(real_data, syn_data)

-#     wandb.log({
-#         "sample_differences_with_realData": wandb.Image(plt)
-#         # "train_samples": wandb.Table(dataframe=self.sample(len(train_data)))
-#         ### "train_samples": [wandb.Image(img) for img in samples.split(1)]
-#     })
+if i % 200 == 0:
+    ckpt_file = SAVE_DIR/f"context_model_{i}.pkl"
+    ### torch.save(nn_model.state_dict(), ckpt_file)
+    self.save(ckpt_file)

+    artifact_name = f"{wandb.run.id}_context_model"
+    at = wandb.Artifact(artifact_name, type="model")
+    at.add_file(ckpt_file)
+    wandb.log_artifact(at, aliases=[f"epoch_{i}"])

+    syn_data = self.sample(len(train_data))  # [real_data_columns]
+    syn_data_columns = syn_data.columns
+    # real_data.columns = syn_data.columns

+    f, ax = plt.subplots(figsize=(12, 10))
+    syn_data[['anchor_age','drug_Dasatinib','systolic']].plot.kde()
+    # self.corr_plot(real_data, syn_data)

+    wandb.log({
+        "sample_differences_with_realData": wandb.Image(plt)
+        # "train_samples": wandb.Table(dataframe=self.sample(len(train_data)))
+        ### "train_samples": [wandb.Image(img) for img in samples.split(1)]
+    })



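Finally, the per-epoch `loss_values` bookkeeping added in this commit is the post-pandas-2.0 idiom: `DataFrame.append()` is gone, so rows accumulate through `pd.concat`, and the `.empty` branch avoids concatenating against the empty starter frame, which recent pandas versions warn about. The same pattern in isolation, with placeholder losses:

```python
import pandas as pd

loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Discriminator Loss'])

for epoch, (g_loss, d_loss) in enumerate([(1.2, -0.3), (0.9, -0.1)]):  # placeholders
    epoch_loss_df = pd.DataFrame({
        'Epoch': [epoch],
        'Generator Loss': [g_loss],
        'Discriminator Loss': [d_loss],
    })
    # pd.concat replaces the DataFrame.append() API removed in pandas 2.0.
    loss_values = (
        epoch_loss_df if loss_values.empty
        else pd.concat([loss_values, epoch_loss_df]).reset_index(drop=True)
    )

print(loss_values)
```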
