From 39440acd9094e95f86dffadeb60e1dff61b09fc9 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Fri, 31 May 2024 14:25:17 +0000 Subject: [PATCH 1/3] Fix doc build --- src/gluonts/model/forecast_generator.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index 33b0320808..b9ff10c8dc 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -83,12 +83,15 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast: def make_predictions(prediction_net, inputs: dict): - # MXNet predictors only support positional arguments - class_name = prediction_net.__class__.__module__ - if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"): - return prediction_net(*inputs.values()) - else: - return prediction_net(**inputs) + try: + # Feed inputs as positional arguments for MXNet predictors + import mxnet as mx + + if isinstance(prediction_net, (mx.gluon.Block)): + return prediction_net(*inputs.values()) + except ImportError: + pass + return prediction_net(**inputs) class ForecastGenerator: From 839907bc29a23d00390c08a8168e1755642d477b Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Fri, 31 May 2024 14:27:36 +0000 Subject: [PATCH 2/3] Fix parentheses --- src/gluonts/model/forecast_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index b9ff10c8dc..2db978ea72 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -87,7 +87,7 @@ def make_predictions(prediction_net, inputs: dict): # Feed inputs as positional arguments for MXNet predictors import mxnet as mx - if isinstance(prediction_net, (mx.gluon.Block)): + if isinstance(prediction_net, mx.gluon.Block): return prediction_net(*inputs.values()) except ImportError: pass From 3ebc2b423c4b828e8bc98fe04eab81a962d25f06 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Fri, 31 May 2024 14:31:02 +0000 Subject: [PATCH 3/3] Trigger build --- src/gluonts/model/forecast_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index 2db978ea72..0148a8e1e6 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -84,7 +84,7 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast: def make_predictions(prediction_net, inputs: dict): try: - # Feed inputs as positional arguments for MXNet predictors + # Feed inputs as positional arguments for MXNet block predictors import mxnet as mx if isinstance(prediction_net, mx.gluon.Block):