Skip to content

Commit

Permalink
PR fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MarIniOnz committed Oct 10, 2024
1 parent 3ce6cb2 commit 584c868
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 37 deletions.
70 changes: 53 additions & 17 deletions medmodels/data_synthesis/mtgan/model/critic/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@


class Critic(nn.Module):
"""Critic for MTGAN.
The Critic tries to distinguish between real data and synthetic data generated
by the Generator. The Critic is trained to maximize the Wasserstein distance
between the real and synthetic data distributions.
"""

real_gru: RealGRU
device: torch.device

Expand All @@ -40,10 +47,10 @@ def __init__(
between the real and synthetic data distributions.
Args:
real_gru (RealGRU): Real GRU model
total_number_of_concepts (int): Size of input features (number of codes)
hyperparameters (TrainingHyperparameters): Hyperparameters for training
device (torch.device): Device to run the model on
real_gru (RealGRU): Real GRU model.
total_number_of_concepts (int): Size of input features (number of codes).
hyperparameters (TrainingHyperparameters): Hyperparameters for training.
device (torch.device): Device to run the model on.
"""
super().__init__()

Expand All @@ -66,7 +73,7 @@ def __init__(
gamma=hyperparameters["decay_rate"],
)
self.loss_function = CriticLoss(
self, lambda_=hyperparameters["lambda_gradient"]
self, lambda_gradient=hyperparameters["lambda_gradient"]
)

self.critic_layers = nn.Sequential(
Expand All @@ -92,11 +99,11 @@ def forward(
Args:
data (torch.Tensor): input data, of shape (batch size, maximum number of
windows, total number of concepts)
windows, total number of concepts).
hidden_states (torch.Tensor): hidden states of the RealGRU, of shape
(batch size, maximum number of windows, generator hidden dimension)
(batch size, maximum number of windows, generator hidden dimension).
number_of_windows_per_patient (torch.Tensor): number of windows each
patient has, of shape (batch size)
patient has, of shape (batch size).
Returns:
torch.Tensor: scores of the critic. If the score is high, the data is considered to be
Expand Down Expand Up @@ -128,18 +135,18 @@ def _train_critic_iteration(
Args:
real_data (torch.Tensor): Real data, of shape (batch size, maximum number of
windows, total number of concepts)
windows, total number of concepts).
number_of_windows_per_patient (torch.Tensor): number of windows each
patient has, of shape (batch size)
patient has, of shape (batch size).
generator (Generator): Generator of the MTGAN, counterpart of the Critic in
the GAN architecture
the GAN architecture.
target_concepts (torch.Tensor): Array of concepts chosen for each training
batch, drawn uniformly from all concepts to ensure all concepts are
included in the training and synthesis process. Shape: batch size.
Returns:
Tuple[float, float]: Critic loss and Wasserstein distance between real and
synthetic data for a single iteration
synthetic data for a single iteration.
"""
real_hiddens = self.real_gru.calculate_hidden(
real_data, number_of_windows_per_patient
Expand Down Expand Up @@ -176,36 +183,65 @@ def train_critic(
Args:
real_data (torch.Tensor): Real data, of shape (batch size, maximum number of
windows, total number of concepts)
windows, total number of concepts).
number_of_windows_per_patient (torch.Tensor): number of windows each
patient has, of shape (batch size)
patient has, of shape (batch size).
generator (Generator): Generator of the MTGAN, counterpart of the Critic in
the GAN architecture
the GAN architecture.
target_concepts (torch.Tensor): Array of concepts chosen for each training
batch, drawn uniformly from all concepts to ensure all concepts are
included in the training and synthesis process. Shape: batch size.
Returns:
Tuple[float, float]: Critic loss and Wasserstein distance between real and
synthetic data
synthetic data.
"""
self.train()
generator.eval()

loss = 0
wasserstein_distance = 0

for _ in range(self.critic_iterations):
loss_iteration, wasserstein_distance_iteration = (
self._train_critic_iteration(
real_data, number_of_windows_per_patient, generator, target_concepts
)
)

loss += loss_iteration
wasserstein_distance += wasserstein_distance_iteration

loss /= self.critic_iterations
wasserstein_distance /= self.critic_iterations

self.scheduler.step()

return loss, wasserstein_distance

def evaluate(self, data_loader: MTGANDataLoader) -> float: ...
def evaluate(self, test_data: MTGANDataLoader) -> float:
    """Evaluate the Critic on test data.

    Switches the Critic to evaluation mode and, with gradient tracking
    disabled, scores every batch yielded by ``test_data`` using the hidden
    states computed by ``self.real_gru``.

    Args:
        test_data (MTGANDataLoader): Data loader for the test data. Each item
            it yields is unpacked as ``(data, number_windows_per_patient)``.

    Returns:
        float: Negated mean critic score, averaged over the batches of
            ``test_data`` (every batch carries the same weight, regardless of
            its size). Only the data in ``test_data`` is scored; no synthetic
            data is generated or compared here.

    Raises:
        ValueError: If ``test_data`` contains no batches.
    """
    number_of_batches = len(test_data)
    if number_of_batches == 0:
        # Fail with an explicit message instead of a bare ZeroDivisionError
        # when averaging over zero batches.
        msg = "Cannot evaluate the Critic on an empty data loader."
        raise ValueError(msg)

    self.eval()

    with torch.no_grad():
        loss = 0.0

        for data, number_windows_per_patient in test_data:
            # NOTE(review): the training iteration calls
            # `self.real_gru.calculate_hidden(...)`; confirm which of the two
            # method names RealGRU actually exposes.
            hidden_states = self.real_gru.calculate_hidden_states(
                data, number_windows_per_patient
            )

            # `.item()` turns the per-batch mean score into a Python float.
            loss += (
                self(data, hidden_states, number_windows_per_patient).mean().item()
            )

    # The critic loss is the negated score, hence the sign flip.
    return -loss / number_of_batches
47 changes: 27 additions & 20 deletions medmodels/data_synthesis/mtgan/model/critic/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
class CriticLoss(nn.Module):
"""Critic Loss: Wasserstein-Loss with gradient penalty for Critic."""

def __init__(self, critic: Critic, lambda_: float) -> None:
def __init__(self, critic: Critic, lambda_gradient: float) -> None:
"""Constructor for the Critic Wasserstein-Loss with gradient penalty.
Args:
critic (Critic): Critic
lambda_ (int): Gradient penalty coefficient.
critic (Critic): Critic model.
lambda_gradient (float): Gradient penalty coefficient.
"""
super().__init__()
self.critic = critic
self.lambda_ = lambda_
self.lambda_gradient = lambda_gradient

def forward(
self,
Expand All @@ -37,31 +37,38 @@ def forward(
Args:
real_data (torch.Tensor): Real data of shape (batch size, maximum number
of windows, total number of concepts)
of windows, total number of concepts).
real_hiddens (torch.Tensor): Real hidden states from RealGRU of shape
(batch size, maximum number of windows, generator hidden dimension)
(batch size, maximum number of windows, generator hidden dimension).
synthetic_data (torch.Tensor): Synthetic data generated, same shape as
real data
real data.
synthetic_hiddens (torch.Tensor): Synthetic hidden states from SyntheticGRU
with the same shape as real hiddens
number_windows_per_patient (torch.Tensor): number of windows per patient
with the same shape as real hiddens.
number_windows_per_patient (torch.Tensor): number of windows per patient.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Critic loss and Wasserstein distance
Tuple[torch.Tensor, torch.Tensor]: Critic loss and Wasserstein distance.
"""
critic_real = self.critic(real_data, real_hiddens, number_windows_per_patient)
critic_synthetic = self.critic(
real_scores_critic = self.critic(
real_data, real_hiddens, number_windows_per_patient
)
synthetic_scores_critic = self.critic(
synthetic_data, synthetic_hiddens, number_windows_per_patient
)

gradient_penalty = self.compute_gradient_penalty(
real_data,
real_hiddens,
synthetic_data,
synthetic_hiddens,
number_windows_per_patient,
)
wasserstein_distance = critic_real.mean() - critic_synthetic.mean()

wasserstein_distance = (
real_scores_critic.mean() - synthetic_scores_critic.mean()
)
critic_loss = -wasserstein_distance + gradient_penalty

return critic_loss, wasserstein_distance

def compute_gradient_penalty(
Expand All @@ -76,17 +83,17 @@ def compute_gradient_penalty(
Args:
real_data (torch.Tensor): Real data of shape (batch size, maximum number
of windows, total number of concepts)
of windows, total number of concepts).
real_hiddens (torch.Tensor): Real hidden states from RealGRU of shape
(batch size, maximum number of windows, generator hidden dimension)
(batch size, maximum number of windows, generator hidden dimension).
synthetic_data (torch.Tensor): Synthetic data generated, same shape as
real data
real data.
synthetic_hiddens (torch.Tensor): Synthetic hidden states from SyntheticGRU
with the same shape as real hiddens
number_windows_per_patient (torch.Tensor): number of windows per patient
with the same shape as real hiddens.
number_windows_per_patient (torch.Tensor): number of windows per patient.
Returns:
torch.Tensor: gradient penalty
torch.Tensor: gradient penalty.
"""
batch_size = len(real_data)

Expand Down Expand Up @@ -120,4 +127,4 @@ def compute_gradient_penalty(
gradient_penalty = (gradients.norm(2, dim=1) - 1) ** 2

# Return the scaled gradient penalty
return gradient_penalty.mean() * self.lambda_
return gradient_penalty.mean() * self.lambda_gradient

0 comments on commit 584c868

Please sign in to comment.