Skip to content

Commit

Permalink
Fix WaveNet inputs (awslabs#3022)
Browse files Browse the repository at this point in the history
  • Loading branch information
shchur authored and lostella committed Oct 27, 2023
1 parent be52a5f commit 2ddcd05
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/gluonts/torch/model/wavenet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

PREDICTION_INPUT_NAMES = [
"feat_static_cat",
"feat_static_real",
"past_target",
"past_observed_values",
"past_time_feat",
Expand Down
4 changes: 4 additions & 0 deletions src/gluonts/torch/model/wavenet/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def training_step(self, batch, batch_idx: int): # type: ignore
Execute training step.
"""
feat_static_cat = batch["feat_static_cat"]
feat_static_real = batch["feat_static_real"]
past_target = batch["past_target"]
past_observed_values = batch["past_observed_values"]
past_time_feat = batch["past_time_feat"]
Expand All @@ -63,6 +64,7 @@ def training_step(self, batch, batch_idx: int): # type: ignore

train_loss = self.model.loss(
feat_static_cat=feat_static_cat,
feat_static_real=feat_static_real,
past_target=past_target,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
Expand All @@ -87,6 +89,7 @@ def validation_step(self, batch, batch_idx: int): # type: ignore
Execute validation step.
"""
feat_static_cat = batch["feat_static_cat"]
feat_static_real = batch["feat_static_real"]
past_target = batch["past_target"]
past_observed_values = batch["past_observed_values"]
past_time_feat = batch["past_time_feat"]
Expand All @@ -97,6 +100,7 @@ def validation_step(self, batch, batch_idx: int): # type: ignore

val_loss = self.model.loss(
feat_static_cat=feat_static_cat,
feat_static_real=feat_static_real,
past_target=past_target,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
Expand Down
13 changes: 13 additions & 0 deletions src/gluonts/torch/model/wavenet/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def __init__(
+ num_feat_dynamic_real
+ num_feat_static_real
+ int(use_log_scale_feature) # the log(scale)
+ 1 # for observed value indicator
)
self.use_log_scale_feature = use_log_scale_feature

Expand Down Expand Up @@ -217,6 +218,7 @@ def get_receptive_field(dilation_depth: int, num_stacks: int) -> int:
def get_full_features(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_observed_values: torch.Tensor,
past_time_feat: torch.Tensor,
future_time_feat: torch.Tensor,
Expand All @@ -230,6 +232,8 @@ def get_full_features(
----------
feat_static_cat
Static categorical features: (batch_size, num_cat_features)
feat_static_real
Static real-valued features: (batch_size, num_feat_static_real)
past_observed_values
Observed value indicator for the past target: (batch_size,
receptive_field)
Expand All @@ -256,6 +260,7 @@ def get_full_features(
static_feat = torch.cat(
[static_feat, torch.log(scale + 1.0)], dim=1
)
static_feat = torch.cat([static_feat, feat_static_real], dim=1)
repeated_static_feat = torch.repeat_interleave(
static_feat[..., None],
self.prediction_length + self.receptive_field,
Expand Down Expand Up @@ -361,6 +366,7 @@ def base_net(
def loss(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
past_time_feat: torch.Tensor,
Expand All @@ -375,6 +381,8 @@ def loss(
----------
feat_static_cat
Static categorical features: (batch_size, num_cat_features)
feat_static_real
Static real-valued features: (batch_size, num_feat_static_real)
past_target
Past target: (batch_size, receptive_field)
past_observed_values
Expand All @@ -401,6 +409,7 @@ def loss(
full_target = torch.cat([past_target, future_target], dim=-1).long()
full_features = self.get_full_features(
feat_static_cat=feat_static_cat,
feat_static_real=feat_static_real,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
future_time_feat=future_time_feat,
Expand Down Expand Up @@ -457,6 +466,7 @@ def _initialize_conv_queues(
def forward(
self,
feat_static_cat: torch.Tensor,
feat_static_real: torch.Tensor,
past_target: torch.Tensor,
past_observed_values: torch.Tensor,
past_time_feat: torch.Tensor,
Expand All @@ -472,6 +482,8 @@ def forward(
----------
feat_static_cat
Static categorical features: (batch_size, num_cat_features)
feat_static_real
Static real-valued features: (batch_size, num_feat_static_real)
past_target
Past target: (batch_size, receptive_field)
past_observed_values
Expand Down Expand Up @@ -508,6 +520,7 @@ def forward(
past_target = past_target.long()
full_features = self.get_full_features(
feat_static_cat=feat_static_cat,
feat_static_real=feat_static_real,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
future_time_feat=future_time_feat,
Expand Down

0 comments on commit 2ddcd05

Please sign in to comment.