Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/precommit' into precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jul 16, 2024
2 parents 2860e8e + feeda94 commit 520892d
Show file tree
Hide file tree
Showing 31 changed files with 388 additions and 174 deletions.
8 changes: 8 additions & 0 deletions docs/getting_started/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

Model + Paper | Local/global | Data layout | Architecture/method | Implementation
-------------------------------------------------------------|--------------|--------------------------|---------------------|----------------
PatchTST<br>[Nie et al., 2023][Nie2023] | Global | Univariate | MLP, multi-head attention | [Pytorch][PatchTST_torch]
LagTST<br> | Global | Univariate | MLP, multi-head attention | [Pytorch][LagTST_torch]
DLinear<br>[Zeng et al., 2023][Zeng2023] | Global | Univariate | MLP | [Pytorch][DLinear_torch]
DeepAR<br>[Salinas et al. 2020][Salinas2020] | Global | Univariate | RNN | [MXNet][DeepAR_mx], [PyTorch][DeepAR_torch]
DeepState<br>[Rangapuram et al. 2018][Rangapuram2018] | Global | Univariate | RNN, state-space model | [MXNet][DeepState]
DeepFactor<br>[Wang et al. 2019][Wang2019] | Global | Univariate | RNN, state-space model, Gaussian process | [MXNet][DeepFactor]
Expand Down Expand Up @@ -30,6 +33,8 @@ NPTS | Local | Un

<!-- Links to bibliography -->

[Nie2023]: https://arxiv.org/abs/2211.14730
[Zeng2023]: https://arxiv.org/abs/2205.13504
[Rangapuram2021]: https://proceedings.mlr.press/v139/rangapuram21a.html
[Salinas2020]: https://doi.org/10.1016/j.ijforecast.2019.07.001
[Rangapuram2018]: https://papers.nips.cc/paper/2018/hash/5cf68969fb67aa6082363a6d4e6468e2-Abstract.html
Expand All @@ -52,6 +57,9 @@ NPTS | Local | Un

<!-- Links to code -->

[PatchTST_torch]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/torch/model/patch_tst/estimator.py
[LagTST_torch]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/torch/model/lag_tst/estimator.py
[DLinear_torch]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/torch/model/d_linear/estimator.py
[DeepAR_mx]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/deepar/_estimator.py
[DeepAR_torch]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/torch/model/deepar/estimator.py
[DeepState]: https://github.com/awslabs/gluonts/blob/dev/src/gluonts/mx/model/deepstate/_estimator.py
Expand Down
20 changes: 10 additions & 10 deletions docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ class FeedForwardNetwork(nn.Module):
torch.nn.init.zeros_(lin.bias)
return lin

def forward(self, context):
scale = self.scaling(context)
scaled_context = context / scale
nn_out = self.nn(scaled_context)
def forward(self, past_target):
scale = self.scaling(past_target)
scaled_past_target = past_target / scale
nn_out = self.nn(scaled_past_target)
nn_out_reshaped = nn_out.reshape(-1, self.prediction_length, self.hidden_dimensions[-1])
distr_args = self.args_proj(nn_out_reshaped)
return distr_args, torch.zeros_like(scale), scale
Expand Down Expand Up @@ -143,15 +143,15 @@ class LightningFeedForwardNetwork(FeedForwardNetwork, pl.LightningModule):
super().__init__(*args, **kwargs)

def training_step(self, batch, batch_idx):
context = batch["past_target"]
target = batch["future_target"]
past_target = batch["past_target"]
future_target = batch["future_target"]

assert context.shape[-1] == self.context_length
assert target.shape[-1] == self.prediction_length
assert past_target.shape[-1] == self.context_length
assert future_target.shape[-1] == self.prediction_length

distr_args, loc, scale = self(context)
distr_args, loc, scale = self(past_target)
distr = self.distr_output.distribution(distr_args, loc, scale)
loss = -distr.log_prob(target)
loss = -distr.log_prob(future_target)

return loss.mean()

Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-extras-anomaly-evaluation.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
numba~=0.51,<0.54
scikit-learn~=0.22
scikit-learn~=1.0
2 changes: 1 addition & 1 deletion requirements/requirements-extras-sagemaker-sdk.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
sagemaker~=2.0
sagemaker~=2.0,>2.214.3
s3fs~=0.6; python_version >= "3.7.0"
s3fs~=0.5; python_version < "3.7.0"
fsspec~=0.8,<0.9; python_version < "3.7.0"
4 changes: 2 additions & 2 deletions requirements/requirements-pytorch.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
torch>=1.9,<3
lightning>=2.0,<2.2
lightning>=2.2.2,<2.4
# Capping `lightning` does not cap `pytorch_lightning`, so we cap manually
pytorch_lightning>=2.0,<2.2
pytorch_lightning>=2.2.2,<2.4
scipy~=1.10; python_version > "3.7.0"
scipy~=1.7.3; python_version <= "3.7.0"
2 changes: 1 addition & 1 deletion requirements/requirements-rotbaum.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
xgboost>=0.90,<2
scikit-learn>=0.22,<2
scikit-learn~=1.0
14 changes: 13 additions & 1 deletion src/gluonts/core/serde/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,15 @@ def encode_partial(v: partial) -> Any:
}


decode_disallow = [
eval,
exec,
compile,
open,
input,
]


def decode(r: Any) -> Any:
"""
Decodes a value from an intermediate representation `r`.
Expand All @@ -312,7 +321,10 @@ def decode(r: Any) -> Any:
kind = r["__kind__"]
cls = cast(Any, locate(r["class"]))

assert cls is not None, f"Can not locate {r['class']}."
if cls is None:
raise ValueError(f"Cannot locate {r['class']}.")
if cls in decode_disallow:
raise ValueError(f"{r['class']} cannot be run.")

if kind == Kind.Type:
return cls
Expand Down
10 changes: 5 additions & 5 deletions src/gluonts/ext/rotbaum/_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import concurrent.futures
import logging
import pickle
from itertools import chain
from typing import Iterator, List, Optional, Any, Dict
from toolz import first
Expand All @@ -24,6 +23,7 @@
from itertools import compress

from gluonts.core.component import validated
from gluonts.core.serde import dump_json, load_json
from gluonts.dataset.common import Dataset
from gluonts.dataset.util import forecast_start
from gluonts.model.forecast import Forecast
Expand Down Expand Up @@ -355,8 +355,8 @@ class name, version information and constructor arguments.
generated when pickling the TreePredictor.
"""
super().serialize(path)
with (path / "predictor.pkl").open("wb") as f:
pickle.dump(self.model_list, f)
with (path / "model_list.json").open("w") as fp:
print(dump_json(self.model_list), file=fp)

@classmethod
def deserialize(cls, path: Path, **kwargs: Any) -> "TreePredictor":
Expand All @@ -369,8 +369,8 @@ def deserialize(cls, path: Path, **kwargs: Any) -> "TreePredictor":

predictor = super().deserialize(path)
assert isinstance(predictor, cls)
with (path / "predictor.pkl").open("rb") as f:
predictor.model_list = pickle.load(f)
with (path / "model_list.json").open("r") as fp:
predictor.model_list = load_json(fp.read())
return predictor

def explain(
Expand Down
22 changes: 18 additions & 4 deletions src/gluonts/model/forecast_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast:
raise NotImplementedError


def make_predictions(prediction_net, inputs: dict):
try:
# Feed inputs as positional arguments for MXNet block predictors
import mxnet as mx

if isinstance(prediction_net, mx.gluon.Block):
return prediction_net(*inputs.values())
except ImportError:
pass
return prediction_net(**inputs)


class ForecastGenerator:
"""
Classes used to bring the output of a network into a class.
Expand Down Expand Up @@ -115,7 +127,7 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
(outputs,), loc, scale = prediction_net(*inputs.values())
(outputs,), loc, scale = make_predictions(prediction_net, inputs)
outputs = to_numpy(outputs)
if scale is not None:
outputs = outputs * to_numpy(scale[..., None])
Expand Down Expand Up @@ -159,14 +171,16 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
outputs = to_numpy(prediction_net(*inputs.values()))
outputs = to_numpy(make_predictions(prediction_net, inputs))
if output_transform is not None:
outputs = output_transform(batch, outputs)
if num_samples:
num_collected_samples = outputs[0].shape[0]
collected_samples = [outputs]
while num_collected_samples < num_samples:
outputs = to_numpy(prediction_net(*inputs.values()))
outputs = to_numpy(
make_predictions(prediction_net, inputs)
)
if output_transform is not None:
outputs = output_transform(batch, outputs)
collected_samples.append(outputs)
Expand Down Expand Up @@ -209,7 +223,7 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
outputs = prediction_net(*inputs.values())
outputs = make_predictions(prediction_net, inputs)

if output_transform:
log_once(OUTPUT_TRANSFORM_NOT_SUPPORTED_MSG)
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/nursery/few_shot_prediction/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ gluonts = {git = "https://github.com/awslabs/gluon-ts.git"}
pandas = "^1.3.1"
python = "^3.8,<3.10"
pytorch-lightning = "^1.4.4"
sagemaker = "^2.40.0,<2.41.0"
sagemaker = "^2.218.0"
scikit-learn = "^1.4.0"
torch = "^1.9.0"
sagemaker-training = "^3.9.2"
Expand All @@ -27,7 +27,7 @@ catch22 = "^0.2.0"
seaborn = "^0.11.2"

[tool.poetry.dev-dependencies]
black = "^21.7b0"
black = "^24.3.0"
isort = "^5.9.3"
jupyter = "^1.0.0"
pylint = "^2.10.2"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,10 @@
TransformedImplicitQuantile,
)
from gluonts.core.component import validated
from gluonts.torch.modules.distribution_output import (
DistributionOutput,
LambdaLayer,
PtArgProj,
)
from gluonts.torch.distributions.distribution_output import DistributionOutput
from gluonts.torch.modules.lambda_layer import LambdaLayer
from gluonts.torch.distributions.output import PtArgProj

from pts.modules.iqn_modules import ImplicitQuantileModule


Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/nursery/tsbench/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ plotly = "^5.3.1"
pyarrow = "^14.0.1"
pydantic = "^1.8.2"
pygmo = "^2.16.1"
pymongo = "^3.12.0"
pymongo = "^4.6.3"
pystan = "^2.0.0"
python = ">=3.8,<3.9"
pytorch-lightning = "^1.5.0"
Expand All @@ -43,7 +43,7 @@ ujson = "^5.1.0"
xgboost = "^1.4.1"

[tool.poetry.dev-dependencies]
black = "^21.5b1"
black = "^24.3.0"
isort = "^5.8.0"
jupyter = "^1.0.0"
mypy = "^0.812"
Expand Down
26 changes: 13 additions & 13 deletions src/gluonts/time_feature/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from packaging.version import Version
from typing import Any, Callable, Dict, List

import numpy as np
Expand Down Expand Up @@ -196,7 +197,10 @@ def norm_freq_str(freq_str: str) -> str:
# Note: Secondly ("S") frequency exists, where we don't want to remove the
# "S"!
if len(base_freq) >= 2 and base_freq.endswith("S"):
return base_freq[:-1]
base_freq = base_freq[:-1]
# In pandas >= 2.2, period end frequencies have been renamed, e.g. "M" -> "ME"
if Version(pd.__version__) >= Version("2.2.0"):
base_freq += "E"

return base_freq

Expand Down Expand Up @@ -252,17 +256,13 @@ def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
Unsupported frequency {freq_str}
The following frequencies are supported:
Y - yearly
alias: A
Q - quarterly
M - monthly
W - weekly
D - daily
B - business days
H - hourly
T - minutely
alias: min
S - secondly
"""

for offset_cls in features_by_offsets:
offset = offset_cls()
supported_freq_msg += (
f"\t{offset.freqstr.split('-')[0]} - {offset_cls.__name__}"
)

raise RuntimeError(supported_freq_msg)
1 change: 1 addition & 0 deletions src/gluonts/time_feature/seasonality.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"ME": 12,
"B": 5,
"Q": 4,
"QE": 4,
}


Expand Down
8 changes: 0 additions & 8 deletions src/gluonts/torch/distributions/distribution_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,6 @@ def loss(
nll = nll * (variance.detach() ** self.beta)
return nll

@property
def event_shape(self) -> Tuple:
r"""
Shape of each individual event contemplated by the distributions that
this object constructs.
"""
raise NotImplementedError()

@property
def event_dim(self) -> int:
r"""
Expand Down
7 changes: 7 additions & 0 deletions src/gluonts/torch/distributions/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ def loss(
"""
raise NotImplementedError()

@property
def event_shape(self) -> Tuple:
r"""
Shape of each individual event compatible with the output object.
"""
raise NotImplementedError()

@property
def forecast_generator(self) -> ForecastGenerator:
raise NotImplementedError()
Expand Down
4 changes: 4 additions & 0 deletions src/gluonts/torch/distributions/quantile_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def __init__(self, quantiles: List[float]) -> None:
def forecast_generator(self) -> ForecastGenerator:
return QuantileForecastGenerator(quantiles=self.quantiles)

@property
def event_shape(self) -> Tuple:
return ()

def domain_map(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
return args

Expand Down
Loading

0 comments on commit 520892d

Please sign in to comment.