WaveNet model crashes if num_feat_static_real > 0 #3021

Closed
shchur opened this issue Oct 20, 2023 · 0 comments · Fixed by #3022
Labels: bug (Something isn't working), torch (This concerns the PyTorch side of GluonTS)
Milestone: v0.14

Comments

shchur (Contributor) commented Oct 20, 2023

Description

If WaveNetEstimator is created with num_feat_static_real > 0, the model crashes in the forward pass. The value of num_feat_static_real is used to compute the number of input channels the network expects, but the static real features themselves are never passed to the model, so the actual input has fewer channels than the network was built for.
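To make the failure mode concrete, here is a toy sketch in plain NumPy. `ToyModel`, `num_time_channels` and the `forward` signature are hypothetical and do not mirror the actual GluonTS WaveNet code; they only show how a channel count fixed at construction time from `num_feat_static_real` breaks when the static features are never concatenated into the input.

```python
import numpy as np


# Hypothetical toy model (NOT GluonTS code) illustrating the failure mode:
# the expected input width is fixed from num_feat_static_real when the model
# is built, but forward() only appends static features if they are passed in.
class ToyModel:
    def __init__(self, num_time_channels: int, num_feat_static_real: int):
        # Expected number of input channels, fixed at construction time.
        self.in_channels = num_time_channels + num_feat_static_real

    def forward(self, time_feat, feat_static_real=None):
        # time_feat has shape (batch, channels, time), like a Conv1d input.
        length = time_feat.shape[2]
        parts = [time_feat]
        if feat_static_real is not None:
            # Repeat each static feature along the time axis as an extra channel.
            parts.append(np.repeat(feat_static_real[:, :, None], length, axis=2))
        x = np.concatenate(parts, axis=1)
        if x.shape[1] != self.in_channels:
            raise RuntimeError(
                f"expected {self.in_channels} channels, got {x.shape[1]}"
            )
        return x


model = ToyModel(num_time_channels=3, num_feat_static_real=5)
time_feat = np.zeros((2, 3, 10))
static = np.full((2, 5), 5.0)
try:
    model.forward(time_feat)  # static features never reach the model
except RuntimeError as err:
    print(err)  # expected 8 channels, got 3
print(model.forward(time_feat, static).shape)  # (2, 8, 10)
```

In this sketch the fix is simply to thread the static features through to the point where the input is assembled, so the concatenated width matches the width declared at construction.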

To Reproduce

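import numpy as np
import pandas as pd
from gluonts.torch.model.wavenet import WaveNetEstimator

N = 50
freq = "D"
# A single daily series that carries 5 static real features.
data = [
    {
        "target": np.random.rand(N),
        "feat_static_real": [5.0] * 5,
        "start": pd.Period("2020-01-01", freq=freq),
    }
]

# Declaring the 5 static real features is what triggers the crash.
model = WaveNetEstimator(
    freq=freq,
    prediction_length=5,
    num_feat_static_real=5,
    trainer_kwargs={"accelerator": "cpu"},
)
# Raises a RuntimeError in the first forward pass (see below).
model.train(data)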

Error message or code output

RuntimeError: Given groups=1, weight of size [24, 39, 1], expected input[32, 35, 20] to have 39 channels, but got 35 channels instead

shchur added the bug label Oct 20, 2023
shchur added this to the v0.14 milestone Oct 20, 2023
lostella added the torch label Oct 20, 2023
shchur linked a pull request that will close this issue (#3022) Oct 20, 2023