From 1f00c9a38d52180677946156890569029f3d187b Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Tue, 26 Nov 2024 13:34:57 -0500 Subject: [PATCH 01/23] relax parameters strictness --- services/boilerplate/inference_payloads.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/services/boilerplate/inference_payloads.py b/services/boilerplate/inference_payloads.py index 25436f98..fcab7934 100644 --- a/services/boilerplate/inference_payloads.py +++ b/services/boilerplate/inference_payloads.py @@ -100,13 +100,10 @@ class ForecastingMetadataInput(BaseMetadataInput): ) -class BaseParameters(BaseModel): - model_config = ConfigDict(extra="forbid", protected_namespaces=()) - +class BaseParameters(BaseModel): ... -class ForecastingParameters(BaseModel): - model_config = ConfigDict(extra="forbid", protected_namespaces=()) +class ForecastingParameters(BaseParameters): prediction_length: Optional[int] = Field( description="The prediction length for the forecast." " The service will return this many periods beyond the last" From 27f6a3c03f5c9b8b89aa76d8fc02def0c76ffbb4 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Tue, 26 Nov 2024 14:10:45 -0500 Subject: [PATCH 02/23] allow extra --- services/boilerplate/inference_payloads.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/services/boilerplate/inference_payloads.py b/services/boilerplate/inference_payloads.py index fcab7934..92a084d5 100644 --- a/services/boilerplate/inference_payloads.py +++ b/services/boilerplate/inference_payloads.py @@ -100,7 +100,8 @@ class ForecastingMetadataInput(BaseMetadataInput): ) -class BaseParameters(BaseModel): ... +class BaseParameters(BaseModel): + model_config = ConfigDict(extra="allow", protected_namespaces=()) class ForecastingParameters(BaseParameters): From 6dd4d4d2ec64d94d4d9d0e30ac4cd01a7810cf7d Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Mon, 2 Dec 2024 19:09:38 -0500 Subject: [PATCH 03/23] clarify citations --- wiki.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wiki.md b/wiki.md index 5cb003db..ad2f1faa 100644 --- a/wiki.md +++ b/wiki.md @@ -17,7 +17,7 @@ In this section, we highlight the papers, blogs, pre-trained models, and open-so ## Publications 4 KDD, 1 ICLR, 2 AAAI, 1 ICML. -🚀 _**Total citations: 1700**_ (as of 26 Aug 2024). 🚀 +🚀 _**Total citations (Google Scholar): 1700**_ (as of 26 Aug 2024). 🚀 1. **TST:** Zerveas, G., Jayaraman, S., Patel, D., Bhamidipaty, A., & Eickhoff, C. [_A transformer-based framework for multivariate time series representation learning._](https://arxiv.org/abs/2010.02803) In KDD 2021. 
**(citations: 840)** From 2c7d487fb83c80ce29c144f5bdcbaa13eb668eb0 Mon Sep 17 00:00:00 2001 From: Arindam Jati Date: Tue, 3 Dec 2024 12:20:26 -0500 Subject: [PATCH 04/23] gluonts data wrapper, and ttm gluonts predictor --- .../tinytimemixer/test_gluonts_predictor.py | 52 ++ tests/toolkit/test_get_model.py | 6 + tests/toolkit/test_gluonts_data_wrapper.py | 124 +++++ .../models/tinytimemixer/gluonts/__init__.py | 5 + .../gluonts/ttm_gluonts_predictor.py | 522 ++++++++++++++++++ tsfm_public/toolkit/ddp_utils.py | 24 + tsfm_public/toolkit/get_model.py | 18 +- tsfm_public/toolkit/gluonts_data_wrapper.py | 245 ++++++++ tsfm_public/toolkit/visualization.py | 2 +- 9 files changed, 994 insertions(+), 4 deletions(-) create mode 100644 tests/models/tinytimemixer/test_gluonts_predictor.py create mode 100644 tests/toolkit/test_gluonts_data_wrapper.py create mode 100644 tsfm_public/models/tinytimemixer/gluonts/__init__.py create mode 100644 tsfm_public/models/tinytimemixer/gluonts/ttm_gluonts_predictor.py create mode 100644 tsfm_public/toolkit/ddp_utils.py create mode 100644 tsfm_public/toolkit/gluonts_data_wrapper.py diff --git a/tests/models/tinytimemixer/test_gluonts_predictor.py b/tests/models/tinytimemixer/test_gluonts_predictor.py new file mode 100644 index 00000000..6e8dbb56 --- /dev/null +++ b/tests/models/tinytimemixer/test_gluonts_predictor.py @@ -0,0 +1,52 @@ +# Copyright contributors to the TSFM project +# + +"""Tests get_model""" + +import numpy as np +import pandas as pd +import pytest +from gluonts.dataset.common import ListDataset + +from tsfm_public.models.tinytimemixer.gluonts import TTMGluonTSPredictor + + +@pytest.fixture(scope="module") +def gluonts_data_with_nan(): + # Step 1: Define the multivariate time series data + num_time_series = 3 # Number of time series + num_variables = 2 # Number of variables (dimensions) per time series + + # Create random multivariate time series data + time_series_data = [ + { + "item_id": f"ts{i+1}", + "start": pd.Timestamp("2024-01-01"), # Start time for each series + "target": np.concatenate( + ( + np.array([[np.nan, np.nan, np.nan, np.nan], [0, 1, np.nan, 2]]), + np.random.rand(num_variables, 600), + np.array([[np.nan, np.nan, np.nan, np.nan], [np.nan, 1, np.nan, 2]]), + np.random.rand(num_variables, 4), + ), + axis=1, + ), # 2D array: (num_variables, length) + } + for i in range(num_time_series) + ] + + # Step 2: Create the ListDataset + freq = "D" # Daily frequency + dataset = ListDataset( + time_series_data, + freq=freq, + one_dim_target=False, + ) + return dataset + + +def test_ttm_gluonts_predictor(gluonts_data_with_nan): + dataset = gluonts_data_with_nan + predictor = TTMGluonTSPredictor(context_length=512, prediction_length=96) + forecasts = predictor.predict(dataset) + assert forecasts[0].samples.shape == (1, 96, 2) diff --git a/tests/toolkit/test_get_model.py b/tests/toolkit/test_get_model.py index 8292a307..9d377169 100644 --- a/tests/toolkit/test_get_model.py +++ b/tests/toolkit/test_get_model.py @@ -66,3 +66,9 @@ def test_get_model(): model = get_model(model_path=mp, context_length=cl, prediction_length=fl) assert model.config.prediction_length == fl assert model.config.context_length == cl + + mp = "ibm/ttm-research-r2" + for cl in range(512, 5000, 500): + for fl in range(1, 720, 90): + model = get_model(model_path=mp, context_length=cl, prediction_length=fl) + assert model.config.prediction_filter_length == fl diff --git a/tests/toolkit/test_gluonts_data_wrapper.py b/tests/toolkit/test_gluonts_data_wrapper.py new file mode 100644 
index 00000000..8f74855f --- /dev/null +++ b/tests/toolkit/test_gluonts_data_wrapper.py @@ -0,0 +1,124 @@ +# Copyright contributors to the TSFM project +# + +"""Tests get_model""" + +import numpy as np +import pandas as pd +import pytest +from gluonts.dataset.common import ListDataset +from gluonts.dataset.split import split + +from tsfm_public.toolkit.gluonts_data_wrapper import ( + StandardScalingGluonTSDataset, + TorchDatasetFromGluonTSTestDataset, + TorchDatasetFromGluonTSTrainingDataset, +) + + +@pytest.fixture(scope="module") +def gluonts_data(): + # Step 1: Define the multivariate time series data + num_time_series = 3 # Number of time series + length = 50 # Length of each time series + num_variables = 2 # Number of variables (dimensions) per time series + + # Create random multivariate time series data + time_series_data = [ + { + "item_id": f"ts{i+1}", + "start": pd.Timestamp("2024-01-01"), # Start time for each series + "target": np.random.rand(num_variables, length), # 2D array: (num_variables, length) + } + for i in range(num_time_series) + ] + + # Step 2: Create the ListDataset + freq = "D" # Daily frequency + dataset = ListDataset( + time_series_data, + freq=freq, + one_dim_target=False, + ) + return dataset + + +def test_gluonts_standard_scaling(gluonts_data): + dataset = gluonts_data + + # Split the dataset into train and test + prediction_length = 10 + train_dataset, test_template = split(dataset, offset=-prediction_length) + + # Test shapes + for entry in train_dataset: + assert entry["target"].shape == (2, 40) + + test_dataset = test_template.generate_instances( + prediction_length=prediction_length, + ) + test_dataset_input = test_dataset.input + test_dataset_label = test_dataset.label + # Test shapes + for entry in test_dataset_input: + assert entry["target"].shape == (2, 40) + for entry in test_dataset_label: + assert entry["target"].shape == (2, 10) + + # Test scaler + scaler = StandardScalingGluonTSDataset() + scaler.fit(train_dataset) + train_dataset_scaled = scaler.transform(train_dataset) + test_dataset_scaled = scaler.transform(test_dataset_input) + + # Test scaling + for entry in train_dataset_scaled: + np.testing.assert_almost_equal(entry["target"].mean(axis=1), np.array([0.0, 0.0]), decimal=4) + np.testing.assert_almost_equal(entry["target"].std(axis=1), np.array([1.0, 1.0]), decimal=4) + + for entry in test_dataset_scaled: + np.testing.assert_almost_equal(entry["target"].mean(axis=1), np.array([0.0, 0.0]), decimal=4) + np.testing.assert_almost_equal(entry["target"].std(axis=1), np.array([1.0, 1.0]), decimal=4) + + # inverse + test_label_scaled = scaler.transform(test_dataset_label) + Y = [] + for entry in test_label_scaled: + Y.append(entry["target"].T) + Y = np.array(Y) + Y_inv = scaler.inverse_transform(Y) + + Y_org = [] + for entry in test_dataset_label: + Y_org.append(entry["target"].T) + Y_org = np.array(Y_org) + + np.testing.assert_almost_equal(Y_inv.mean(), Y_org.mean(), decimal=4) + + +def test_pytorch_data_wrappers(gluonts_data): + dataset = gluonts_data + + # Split the dataset into train and test + prediction_length = 10 + train_dataset, test_template = split(dataset, offset=-prediction_length) + test_dataset = test_template.generate_instances( + prediction_length=prediction_length, + ) + test_dataset_input = test_dataset.input + test_dataset_label = test_dataset.label + + torch_train_dset = TorchDatasetFromGluonTSTrainingDataset(train_dataset, seq_len=20, forecast_len=5) + assert torch_train_dset[1]["past_values"].shape == (20, 2) + assert 
torch_train_dset[1]["future_values"].shape == (5, 2) + + torch_train_dset = TorchDatasetFromGluonTSTrainingDataset(train_dataset, seq_len=35, forecast_len=5) + assert torch_train_dset[1]["past_values"].shape == (35, 2) + assert torch_train_dset[1]["future_values"].shape == (5, 2) + assert len(torch_train_dset) == 3 + + torch_test_dset = TorchDatasetFromGluonTSTestDataset( + gluon_test_input=test_dataset_input, gluon_test_label=test_dataset_label, seq_len=20, forecast_len=5 + ) + assert torch_test_dset[0]["past_values"].shape == (20, 2) + assert torch_test_dset[0]["future_values"].shape == (5, 2) diff --git a/tsfm_public/models/tinytimemixer/gluonts/__init__.py b/tsfm_public/models/tinytimemixer/gluonts/__init__.py new file mode 100644 index 00000000..67f0d162 --- /dev/null +++ b/tsfm_public/models/tinytimemixer/gluonts/__init__.py @@ -0,0 +1,5 @@ +from tsfm_public.models.tinytimemixer.gluonts.ttm_gluonts_predictor import ( + TTM_AVAILABLE_CONTEXTS, + TTM_MAX_FORECAST_HORIZON, + TTMGluonTSPredictor, +) diff --git a/tsfm_public/models/tinytimemixer/gluonts/ttm_gluonts_predictor.py b/tsfm_public/models/tinytimemixer/gluonts/ttm_gluonts_predictor.py new file mode 100644 index 00000000..1f94a621 --- /dev/null +++ b/tsfm_public/models/tinytimemixer/gluonts/ttm_gluonts_predictor.py @@ -0,0 +1,522 @@ +# Copyright contributors to the TSFM project +# +"""Tools for building TTM Predictor that works with GluonTS datasets""" + +import copy +import math +import os +import tempfile +from typing import List + +import numpy as np +import pandas as pd +import torch +from gluonts.dataset.split import InputDataset, LabelDataset, TrainingDataset +from gluonts.itertools import batcher +from gluonts.model.forecast import SampleForecast +from torch.optim import AdamW +from torch.optim.lr_scheduler import OneCycleLR +from torch.utils.data import ConcatDataset, Subset +from tqdm.auto import tqdm +from transformers import EarlyStoppingCallback, Trainer, TrainingArguments +from transformers.integrations import INTEGRATION_TO_CALLBACK +from transformers.utils import logging + +from tsfm_public import ( + TrackingCallback, + count_parameters, +) +from tsfm_public.toolkit.get_model import get_model +from tsfm_public.toolkit.gluonts_data_wrapper import ( + StandardScalingGluonTSDataset, + TorchDatasetFromGluonTSTestDataset, + TorchDatasetFromGluonTSTrainingDataset, + impute_series, +) +from tsfm_public.toolkit.lr_finder import optimal_lr_finder +from tsfm_public.toolkit.visualization import plot_predictions + + +logger = logging.get_logger(__name__) + +# TTM Constants +TTM_MAX_FORECAST_HORIZON = 720 +TTM_AVAILABLE_CONTEXTS = [1536, 1024, 512] + +# Fewshot max allowed number of samples +# For example, if 5% few-shot for a dataset exceeds this number, +# this `FEWSHOT_MAX_NUM_SAMPLES` upper bound will be used. +FEWSHOT_MAX_NUM_SAMPLES = 500_000 + + +class TTMGluonTSPredictor: + """Wrapper to TTM that can be directly trained, validated, and tested with GluonTS datasets.""" + + def __init__( + self, + context_length: int, + prediction_length: int, + model_path: str = "ibm-granite/granite-timeseries-ttm-r2", + test_data_label: LabelDataset = None, # provide this for plotting + scale: bool = False, + random_seed: int = 42, + term: str = None, + ds_name: str = None, + out_dir: str = None, + plot_test_forecast: bool = False, + upper_bound_fewshot_samples: bool = True, + **kwargs, + ): + """Initialize a TTMGluonTSPredictor object. + + Args: + context_length (int): Context length. + prediction_length (int): Forecast length. 
+            model_path (str, optional): TTM model path. Defaults to "ibm-granite/granite-timeseries-ttm-r2".
+            test_data_label (LabelDataset, optional): Test data label object. Only used for plotting. Defaults to None.
+            scale (bool, optional): To scale the data or not. Defaults to False. (Recommended to set to `True` for the fine-tuning workflow.)
+            random_seed (int, optional): Seed. Defaults to 42.
+            term (str, optional): Term (short/medium/long). Defaults to None.
+            ds_name (str, optional): Dataset name. Only used for plotting. Defaults to None.
+            out_dir (str, optional): Output directory. Defaults to None.
+            plot_test_forecast (bool, optional): Whether to plot forecasts. Defaults to False.
+            upper_bound_fewshot_samples (bool, optional): If True, the number of samples in an x% few-shot run
+                will be upper-bounded by FEWSHOT_MAX_NUM_SAMPLES. Defaults to True.
+        """
+        self.context_length = context_length
+        self.prediction_length = prediction_length
+        self.test_data_label = test_data_label
+        self.scale = scale
+        self.scaler = None
+        self.random_seed = random_seed
+        self.term = term
+        self.ds_name = ds_name
+        self.out_dir = out_dir
+        self.plot_test_forecast = plot_test_forecast
+        self.upper_bound_fewshot_samples = upper_bound_fewshot_samples
+
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        if "dropout" in kwargs and kwargs["dropout"] is None:
+            del kwargs["dropout"]
+        if "head_dropout" in kwargs and kwargs["head_dropout"] is None:
+            del kwargs["head_dropout"]
+
+        # Call get_model() function to load TTM model automatically.
+        self.ttm = get_model(
+            model_path,
+            context_length=self.context_length,
+            prediction_length=min(self.prediction_length, TTM_MAX_FORECAST_HORIZON),
+            **kwargs,
+        ).to(self.device)
+
+    def _process_time_series(self, dataset: TrainingDataset) -> List:
+        """
+        Processes a time series by truncating initial NaNs and forward filling intermittent NaNs.
+        Returns a new truncated dataset, and does not modify the original one.
+
+        Args:
+            dataset (TrainingDataset): Every series is of shape [channels, length].
+
+        Returns:
+            List: Processed time series, each of shape [channels, truncated_length].
+        """
+        truncated_dataset = list(copy.deepcopy(dataset))
+
+        for i, item in enumerate(truncated_dataset):
+            data = item["target"]
+            if len(data.shape) == 1:
+                data = data.reshape(1, -1)  # [channels, length]
+
+            # Step 1: Determine the longest stretch of initial NaNs across all channels
+            valid_mask = ~np.isnan(data)  # Mask of valid (non-NaN) values
+
+            if valid_mask.all():
+                continue  # Continue if no NaN
+
+            first_valid = np.argmax(valid_mask.any(axis=0))  # First col with any valid value across channels
+            data = data[:, first_valid:]  # Truncate cols before the first valid col
+
+            # Step 2: Perform forward fill for NaNs
+            df = pd.DataFrame(data.T, columns=range(data.shape[0]))
+            df = df.ffill(axis=0)
+
+            data = df.values.T
+            if data.shape[0] == 1:  # [1, truncated_length]
+                data = data.reshape(-1)  # [truncated_length]
+
+            truncated_dataset[i]["target"] = data
+
+        return truncated_dataset
+
+    def train(
+        self,
+        train_dataset: TrainingDataset,
+        valid_dataset: TrainingDataset,
+        batch_size: int = 64,
+        freeze_backbone: bool = True,
+        learning_rate: float = None,
+        num_epochs: int = 30,
+        num_workers: int = 8,
+        fewshot_fraction: float = 1.0,
+        use_valid_from_train: bool = True,
+        save_model: bool = False,
+    ):
+        """Finetune the TTM.
+
+        Args:
+            train_dataset (TrainingDataset): Training dataset.
+            valid_dataset (TrainingDataset): Validation dataset.
+            batch_size (int, optional): Batch size.
Defaults to 64.
+            freeze_backbone (bool, optional): To freeze TTM backbone. Defaults to True.
+            learning_rate (float, optional): Learning rate. Defaults to None.
+            num_epochs (int, optional): Number of epochs. Defaults to 30.
+            num_workers (int, optional): Number of workers. Defaults to 8.
+            fewshot_fraction (float, optional): Few-shot fraction. Defaults to 1.0.
+            use_valid_from_train (bool, optional): Utilize unused train data for validation. Defaults to True.
+            save_model (bool, optional): Save model to `self.out_dir`. Defaults to False.
+
+        Raises:
+            ValueError: If the requested few-shot fraction results in zero training samples.
+        """
+        train_dataset_scaled = self._process_time_series(train_dataset)
+        valid_dataset_scaled = self._process_time_series(valid_dataset)
+
+        # Standard scale
+        if self.scale:
+            self.scaler = StandardScalingGluonTSDataset()
+            self.scaler.fit(train_dataset_scaled)
+            train_dataset_scaled = self.scaler.transform(train_dataset_scaled)
+            valid_dataset_scaled = self.scaler.transform(valid_dataset_scaled)
+
+        temp_dir = tempfile.mkdtemp()
+        dset_train = TorchDatasetFromGluonTSTrainingDataset(
+            train_dataset_scaled, self.context_length, self.prediction_length
+        )
+
+        dset_valid_from_train = None
+        if fewshot_fraction < 1.0:
+            # Choose randomly
+            rng = np.random.default_rng(seed=self.random_seed)
+            if self.upper_bound_fewshot_samples:
+                list_size = min(int(fewshot_fraction * len(dset_train)), FEWSHOT_MAX_NUM_SAMPLES)
+            else:
+                list_size = int(fewshot_fraction * len(dset_train))
+
+            lst_fewshot_indx = rng.integers(
+                low=0,
+                high=len(dset_train),
+                size=list_size,
+            )
+
+            logger.info(f"Length of original train set = {len(dset_train)}")
+            org_dset_train = copy.deepcopy(dset_train)
+            dset_train = Subset(org_dset_train, lst_fewshot_indx)
+            logger.info(f"Length of {fewshot_fraction*100} % train set = {len(dset_train)}")
+
+            if len(dset_train) < 1:
+                raise ValueError(
+                    f"Data too small for finetuning in fewshot {fewshot_fraction*100}%. Resulting in 0 samples."
+                )
+
+            if use_valid_from_train:
+                all_indx = list(range(0, len(org_dset_train)))
+                valid_indx = list(set(all_indx) - set(lst_fewshot_indx))
+
+                # we don't use a huge validation set
+                valid_size = min(len(dset_train), len(org_dset_train) - len(dset_train))
+
+                valid_indx = np.random.choice(valid_indx, valid_size, replace=False)
+                dset_valid_from_train = Subset(org_dset_train, valid_indx)
+
+        dset_valid = TorchDatasetFromGluonTSTrainingDataset(
+            valid_dataset_scaled,
+            self.context_length,
+            self.prediction_length,
+            last_window_only=True,
+        )
+
+        if dset_valid_from_train is not None:
+            dset_valid = ConcatDataset((dset_valid_from_train, dset_valid))
+
+        self.train_num_samples = len(dset_train)
+        self.valid_num_samples = len(dset_valid)
+
+        if freeze_backbone:
+            print(
+                "Number of params before freezing backbone",
+                count_parameters(self.ttm),
+            )
+
+            # Freeze the backbone of the model
+            for param in self.ttm.backbone.parameters():
+                param.requires_grad = False
+
+            # Count params
+            print(
+                "Number of params after freezing the backbone",
+                count_parameters(self.ttm),
+            )
+
+        # Find optimal learning rate
+        # Use with caution: Set it manually if the suggested learning rate is not suitable
+        if learning_rate is None:
+            learning_rate, self.ttm = optimal_lr_finder(
+                self.ttm,
+                dset_train,
+                batch_size=batch_size,
+            )
+            print("OPTIMAL SUGGESTED LEARNING RATE =", learning_rate)
+
+        print(f"Using learning rate = {learning_rate}")
+        finetune_forecast_args = TrainingArguments(
+            output_dir=os.path.join(temp_dir, "output"),
+            overwrite_output_dir=True,
+            learning_rate=learning_rate,
+            num_train_epochs=num_epochs,
+            do_eval=True,
+            evaluation_strategy="epoch",
+            per_device_train_batch_size=batch_size,
+            per_device_eval_batch_size=batch_size,
+            dataloader_num_workers=num_workers,
+            report_to="none",
+            save_strategy="epoch",
+            logging_strategy="epoch",
+            save_total_limit=1,
+            logging_dir=os.path.join(temp_dir, "logs"),  # Make sure to specify a logging directory
+            load_best_model_at_end=True,  # Load the best model when training ends
+            metric_for_best_model="eval_loss",  # Metric to monitor for early stopping
+            greater_is_better=False,  # For loss
+            seed=self.random_seed,
+        )
+
+        # Create the early stopping callback
+        early_stopping_callback = EarlyStoppingCallback(
+            early_stopping_patience=5,  # Number of epochs with no improvement after which to stop
+            early_stopping_threshold=1e-5,  # Minimum improvement required to consider as improvement
+        )
+        tracking_callback = TrackingCallback()
+
+        # Optimizer and scheduler
+        optimizer = AdamW(self.ttm.parameters(), lr=learning_rate)
+        scheduler = OneCycleLR(
+            optimizer,
+            learning_rate,
+            epochs=num_epochs,
+            steps_per_epoch=math.ceil(len(dset_train) / (batch_size)),
+        )
+
+        hf_trainer = Trainer(
+            model=self.ttm,
+            args=finetune_forecast_args,
+            train_dataset=dset_train,
+            eval_dataset=dset_valid,
+            callbacks=[early_stopping_callback, tracking_callback],
+            optimizers=(optimizer, scheduler),
+        )
+        hf_trainer.remove_callback(INTEGRATION_TO_CALLBACK["codecarbon"])
+
+        # Fine tune
+        hf_trainer.train()
+
+        # Save model
+        if save_model:
+            hf_trainer.save_model(os.path.join(self.out_dir, "ttm_model"))
+
+    def validate(
+        self,
+        valid_dataset: TrainingDataset,
+        batch_size: int = 64,
+    ):
+        """(Optionally) Validate.
+
+        Args:
+            valid_dataset (TrainingDataset): Validation dataset.
+            batch_size (int, optional): Batch size. Defaults to 64.
+
+        Returns:
+            float: Validation loss.
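+
+        Example (illustrative sketch, not a prescribed workflow; assumes `valid_ds`
+        is a GluonTS `TrainingDataset` and `predictor` was built as
+        `TTMGluonTSPredictor(context_length=512, prediction_length=96)`):
+
+            val_loss = predictor.validate(valid_ds, batch_size=32)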
+ """ + valid_dataset_scaled = self._process_time_series(valid_dataset) + if self.scale: + if self.scaler is None: + self.scaler = StandardScalingGluonTSDataset() + self.scaler.fit(valid_dataset_scaled) + + valid_dataset_scaled = self.scaler.transform(valid_dataset_scaled) + else: + valid_dataset_scaled = valid_dataset + + temp_dir = tempfile.mkdtemp() + dset_valid = TorchDatasetFromGluonTSTrainingDataset( + valid_dataset_scaled, + self.context_length, + self.prediction_length, + last_window_only=True, + ) + + # hf_trainer + hf_trainer = Trainer( + model=self.ttm, + args=TrainingArguments( + output_dir=temp_dir, + per_device_eval_batch_size=batch_size, + seed=self.random_seed, + report_to="none", + eval_accumulation_steps=10, + ), + ) + + # evaluate = zero-shot performance + print("+" * 20, "Zero-shot Test Loss", "+" * 20) + zeroshot_output = hf_trainer.predict(dset_valid) + print(zeroshot_output) + return zeroshot_output["eval_loss"] + + def predict( + self, + test_data_input: InputDataset, + batch_size: int = 64, + ): + """Predict. + + Args: + test_data_input (InputDataset): Test input dataset. + batch_size (int, optional): Batch size. Defaults to 64. + + Returns: + float: Eval loss. + """ + # Standard scale + if self.scale: + # We do not truncate the initial NaNs during testing since it sometimes + # results in extremely short length, and inference fails. + # Hence, in the current implementation the initial NaNs will be converted + # to zeros. + # (not used currently) test_data_input_scaled = self._process_time_series(test_data_input) + + # A new Standard Scaler is defined + # Note: Issue with using the train scaler directly...number of series mismatch! + test_data_input_scaled = copy.deepcopy(test_data_input) + scaler = StandardScalingGluonTSDataset() + scaler.fit(test_data_input_scaled) + test_data_input_scaled = scaler.transform(test_data_input_scaled) + else: + test_data_input_scaled = test_data_input + + while True: + try: + # Generate forecast samples + forecast_samples = [] + for batch in tqdm(batcher(test_data_input_scaled, batch_size=batch_size)): + batch_ttm = {} + adjusted_batch_raw = [] + for idx, entry in enumerate(batch): + # univariate array of shape (time,) + # multivariate array of shape (var, time) + # TTM supports multivariate time series + if len(entry["target"].shape) == 1: + entry["target"] = entry["target"].reshape(1, -1) + + entry_context_length = entry["target"].shape[1] + num_channels = entry["target"].shape[0] + # Pad + if entry_context_length < self.ttm.config.context_length: + padding = torch.zeros( + ( + num_channels, + self.ttm.config.context_length - entry_context_length, + ) + ) + adjusted_entry = torch.cat((padding, torch.tensor(impute_series(entry["target"]))), dim=1) + # observed_mask[idx, :, :(ttm.config.context_length - entry_context_length)] = 0 + # Truncate + elif entry_context_length > self.ttm.config.context_length: + adjusted_entry = torch.tensor( + impute_series(entry["target"][:, -self.ttm.config.context_length :]) + ) + # Take full context + else: + adjusted_entry = torch.tensor(impute_series(entry["target"])) + adjusted_batch_raw.append(adjusted_entry) + + # For TTM channel dimension comes at the end + batch_ttm["past_values"] = torch.stack(adjusted_batch_raw).permute(0, 2, 1).to(self.device) + + if self.prediction_length > TTM_MAX_FORECAST_HORIZON: + recursive_steps = int(np.ceil(self.prediction_length / self.ttm.config.prediction_length)) + predict_outputs = torch.empty(len(batch), 0, num_channels).to(self.device) + with torch.no_grad(): 
+ for i in range(recursive_steps): + model_outputs = self.ttm(**batch_ttm) + batch_ttm["past_values"] = torch.cat( + [ + batch_ttm["past_values"], + model_outputs["prediction_outputs"], + ], + dim=1, + )[:, -self.ttm.config.context_length :, :] + predict_outputs = torch.cat( + [ + predict_outputs, + model_outputs["prediction_outputs"][:, : self.ttm.config.prediction_length, :], + ], + dim=1, + ) + predict_outputs = predict_outputs[:, : self.prediction_length, :] + else: + model_outputs = self.ttm(**batch_ttm) + predict_outputs = model_outputs.prediction_outputs + + # Accumulate all forecasts + forecast_samples.append(predict_outputs.detach().cpu().numpy()) + + # list to np.ndarray + forecast_samples = np.concatenate(forecast_samples) + + if self.scale: + # inverse scale + forecast_samples = scaler.inverse_transform(forecast_samples) + + if forecast_samples.shape[2] == 1: + forecast_samples = np.squeeze(forecast_samples, axis=2) + break + except torch.cuda.OutOfMemoryError: + print(f"OutOfMemoryError at batch_size {batch_size}, reducing to {batch_size // 2}") + batch_size //= 2 + + # Convert forecast samples into gluonts SampleForecast objects + # Array of size (num_samples, prediction_length) (1D case) or + # (num_samples, prediction_length, target_dim) (multivariate case) + sample_forecasts = [] + for item, ts in zip(forecast_samples, test_data_input): + forecast_start_date = ts["start"] + len(ts["target"]) + sample_forecasts.append( + SampleForecast( + item_id=ts["item_id"], + samples=np.expand_dims(item, axis=0), + start_date=forecast_start_date, + ) + ) + + if self.out_dir is None: + self.out_dir = tempfile.mkdtemp() + + if self.plot_test_forecast and self.prediction_length <= TTM_MAX_FORECAST_HORIZON: + # Create torch dataset for plotting + torch_dset_test = TorchDatasetFromGluonTSTestDataset( + gluon_test_input=test_data_input, + gluon_test_label=self.test_data_label, + seq_len=self.ttm.config.context_length, + forecast_len=self.prediction_length, + ) + # Plot random samples + plot_predictions( + dset=torch_dset_test, + model=self.ttm, + plot_dir=f"{self.out_dir}/{self.ds_name}_{self.term}", + channel=0, + plot_context=int(0.5 * self.prediction_length), + ) + + return sample_forecasts diff --git a/tsfm_public/toolkit/ddp_utils.py b/tsfm_public/toolkit/ddp_utils.py new file mode 100644 index 00000000..c414602f --- /dev/null +++ b/tsfm_public/toolkit/ddp_utils.py @@ -0,0 +1,24 @@ +import os +from datetime import timedelta + +import torch + + +def init_ddp(timeout=600): + local_rank = int(os.environ.get("LOCAL_RANK")) + world_size = int(os.environ.get("WORLD_SIZE")) + rank = int(os.environ.get("RANK")) + + torch.cuda.set_device(local_rank) + torch.distributed.init_process_group( + "nccl", + init_method="env://", + world_size=world_size, + rank=rank, + timeout=timedelta(seconds=timeout), + ) + + +def is_rank_0(): + rank = torch.distributed.get_rank() + return rank == 0 diff --git a/tsfm_public/toolkit/get_model.py b/tsfm_public/toolkit/get_model.py index 48cf7022..12e28179 100644 --- a/tsfm_public/toolkit/get_model.py +++ b/tsfm_public/toolkit/get_model.py @@ -115,6 +115,17 @@ def get_model( {SUPPORTED_LENGTHS[model_path_type]}" ) + # Choose closest context length + available_context_lens = sorted(SUPPORTED_LENGTHS[model_path_type]["CL"], reverse=True) + selected_context_length = None + for cl in available_context_lens: + if cl <= context_length: + selected_context_length = cl + break + if selected_context_length is None: + raise ValueError(f"Requested context length is too short. 
Requested = {context_length}.\n\
+            Available lengths for model_type = {model_path_type} are: {available_context_lens}.")
+
     if freq_prefix_tuning is None:
         # Default model preference (freq / nofreq)
         if model_path_type == 1 or model_path_type == 2:
             # for granite use nofreq models
@@ -135,11 +146,11 @@ def get_model(
     try:
         if model_path_type == 1 or model_path_type == 2:
             ttm_model_revision = model_revisions["ibm-granite-models"][
-                f"r{model_path_type}-{context_length}-{selected_prediction_length}-{freq_prefix}"
+                f"r{model_path_type}-{selected_context_length}-{selected_prediction_length}-{freq_prefix}"
             ]["revision"]
         elif model_path_type == 3:
             ttm_model_revision = model_revisions["research-use-models"][
-                f"r2-{context_length}-{selected_prediction_length}-{freq_prefix}"
+                f"r2-{selected_context_length}-{selected_prediction_length}-{freq_prefix}"
             ]["revision"]
         else:
             raise Exception(
@@ -149,9 +160,10 @@ def get_model(
         raise ValueError(
             f"Model not found, possibly because of wrong context_length. Supported context lengths (CL) and forecast/prediction lengths (FL) for Model Card: {model_path} are {SUPPORTED_LENGTHS[model_path_type]}"
         )
+    else:
+        prediction_filter_length = prediction_length
 
     # Load model
-
     model = TinyTimeMixerForPrediction.from_pretrained(
         model_path,
         revision=ttm_model_revision,
diff --git a/tsfm_public/toolkit/gluonts_data_wrapper.py b/tsfm_public/toolkit/gluonts_data_wrapper.py
new file mode 100644
index 00000000..e96043d1
--- /dev/null
+++ b/tsfm_public/toolkit/gluonts_data_wrapper.py
@@ -0,0 +1,245 @@
+import bisect
+from typing import Union
+
+import numpy as np
+import torch
+from gluonts.dataset.split import InputDataset, LabelDataset, TrainingDataset
+from gluonts.itertools import batcher
+from gluonts.transform.feature import LastValueImputation
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+
+def impute_series(target):
+    if np.isnan(target).any():
+        target = target.copy()
+        if len(target.shape) == 2:
+            for i in range(target.shape[0]):
+                target[i, ...] = LastValueImputation()(target[i, ...])
+        elif len(target.shape) == 1:
+            target = LastValueImputation()(target)
+        else:
+            raise Exception("Only 1D and 2D arrays are accepted by the impute_series() function.")
+    return target
+
+
+def np_to_torch(np):
+    if np.dtype == "float" or np.dtype == "float32":
+        return torch.from_numpy(np).float()
+    elif np.dtype == "int":
+        return torch.from_numpy(np)
+
+
+def _torch(*nps):
+    return tuple(np_to_torch(x) for x in nps)
+
+
+class StandardScalingGluonTSDataset:
+    """
+    TTM works best on standard scaled data, especially if fewshot
+    finetuning is being performed.
+    We can utilize the entire available context to do that.
+    This is a global scaling operation done independently on
+    each channel.
+    """
+
+    def __init__(self) -> None:
+        self.mean = []
+        self.std = []
+
+    def fit(self, train_data: Union[TrainingDataset, InputDataset]):
+        """Calculate the statistics on the historical train data.
+
+        Args:
+            train_data (Union[TrainingDataset, InputDataset]): Iterator with
+                each series of shape [num_channels, seq_len] for multivariate
+                and [seq_len] for univariate.
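+
+        Example (illustrative sketch, mirroring the accompanying tests; assumes
+        `dataset` is a GluonTS dataset and `prediction_length` is the horizon):
+
+            train_data, test_template = split(dataset, offset=-prediction_length)
+            scaler = StandardScalingGluonTSDataset()
+            scaler.fit(train_data)
+            train_scaled = scaler.transform(train_data)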
+ """ + for batch in tqdm(batcher(train_data, batch_size=1)): + if batch[0]["target"].ndim == 1: + batch[0]["target"] = batch[0]["target"].reshape(1, -1) # [1, seq_len] + self.mean.append(np.mean(impute_series(batch[0]["target"]), axis=1).reshape(-1, 1)) + std = np.std(impute_series(batch[0]["target"]), axis=1).reshape(-1, 1) + for i in range(std.shape[0]): + if std[i] == 0: + std[i] = 1 + self.std.append(std) + + def transform(self, data: Union[TrainingDataset, InputDataset]): + """Apply scaler using calculated statistics. + + Args: + data (Union[TrainingDataset, InputDataset]): Iterator with + each series of shape [num_channels, seq_len] for multivariate + and [seq_len] for univariate. + + Returns: + Iternator: With each series transformed. + """ + assert len(self.mean) > 0 + assert len(self.std) > 0 + + out = list(data) + for i, _ in tqdm(enumerate(out)): + out[i]["target"] = (impute_series(out[i]["target"]) - self.mean[i]) / (self.std[i]) + return iter(out) + + def inverse_transform(self, data: np.ndarray) -> np.ndarray: + """Inverse transform, and bring data to original scale. + + Args: + data (np.ndarray): Forecast output of shape [batch, seq_len, num_channels] + + Raises: + Exception: If NaN is found in the forecast. + + Returns: + np.ndarray: Of shape [batch, seq_len, num_channels]. + """ + out = np.zeros(data.shape) + for i in tqdm(range((data.shape[0]))): + out[i, ...] = data[i, ...] * (self.std[i].T) + self.mean[i].T + if np.isnan(out[i, ...]).any(): + raise Exception("NaN found in forecast!") + return out + + +class TorchDatasetFromGluonTSTrainingDataset(Dataset): + def __init__( + self, + gluon_dataset: TrainingDataset, + seq_len: int, + forecast_len: int, + last_window_only=False, + ): + """Wrapper to create pytorch `Dataset` from GluonTS dataset. + + Args: + gluon_dataset (TrainingDataset): GluonTS dataset. + seq_len (int): Context length. + forecast_len (int): Forecast horizon. + last_window_only (bool, optional): If True, only last window will be processed. Defaults to False. 
+ """ + # assert seq_len > forecast_len, f'sequence lenght {seq_len} has to be strictly greater than forecast length {forecast_len}' + self.seq_len = seq_len + self.forecast_len = forecast_len + self.X = list(gluon_dataset) + self.last_window_only = last_window_only + self.stride = 1 # TODO: support other strides + + # handle univariate series, and nans + for i, _ in enumerate(self.X): + if len(self.X[i]["target"].shape) == 1: + self.X[i]["target"] = self.X[i]["target"].reshape(1, -1) + + # Nan imputation + self.X[i]["target"] = impute_series(self.X[i]["target"]) + + # pad zeros if needed + if self.X[i]["target"].shape[1] < self.seq_len + self.forecast_len: + pad = np.zeros( + ( + self.X[i]["target"].shape[0], + self.seq_len + self.forecast_len - self.X[i]["target"].shape[1] + 1, + ) + ) + self.X[i]["target"] = np.concatenate((pad, self.X[i]["target"]), axis=1) + + # get shape + if not self.last_window_only: + self.cumulative_sizes = self.cumsum(self.X) + + def cumsum(self, list_data): + """ + list_data: list of numpy array of shape [channels x len] + """ + list_len, sum_ = [], 0 + for i, elm in enumerate(list_data): + data = elm["target"] + len_ = data.shape[1] - self.seq_len - self.forecast_len + 1 + list_len.append(len_ + sum_) + sum_ += len_ + return list_len + + def __len__(self): + if self.last_window_only: + return len(self.X) # = num of series + else: + return self.cumulative_sizes[-1] // self.stride + + def __getitem__(self, idx): + if self.last_window_only: + seq_x = self.X[idx]["target"][:, -(self.seq_len + self.forecast_len) : -self.forecast_len] + seq_y = self.X[idx]["target"][:, -(self.forecast_len) :] + else: + idx = idx * self.stride + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + series_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if series_idx == 0: + time_id = idx + else: + time_id = idx - self.cumulative_sizes[series_idx - 1] + seq_x = self.X[series_idx]["target"][:, time_id : time_id + self.seq_len] + seq_y = self.X[series_idx]["target"][ + :, time_id + self.seq_len : time_id + self.seq_len + self.forecast_len + ] + + # return torch.from_numpy(seq_x.astype(np.float)).float() + seq_x, seq_y = _torch(seq_x, seq_y) + + return_output = { + "past_values": seq_x.T, + "future_values": seq_y.T, + } + + return return_output + + +class TorchDatasetFromGluonTSTestDataset(Dataset): + def __init__( + self, + gluon_test_input: InputDataset, + gluon_test_label: LabelDataset, + seq_len: int, + forecast_len: int, + ): + """Wrapper to create pytorch `Dataset` from GluonTS dataset. + + Args: + gluon_dataset (TrainingDataset): GluonTS dataset. + seq_len (int): Context length. + forecast_len (int): Forecast horizon. + last_window_only (bool, optional): If True, only last window will be processed. Defaults to False. 
+ """ + # assert seq_len > forecast_len, f'sequence lenght {seq_len} has to be strictly greater than forecast length {forecast_len}' + self.seq_len = seq_len + self.forecast_len = forecast_len + self.X = list(gluon_test_input) + self.Y = list(gluon_test_label) + + def __len__(self): + return len(self.Y) + + def __getitem__(self, idx): + seq_x = self.X[idx]["target"] + seq_y = self.Y[idx]["target"] + + if len(seq_x.shape) == 1: + seq_x = seq_x.reshape(1, -1) + seq_y = seq_y.reshape(1, -1) + + if seq_x.shape[1] < self.seq_len: + pad = np.zeros((seq_x.shape[0], self.seq_len - seq_x.shape[1])) + seq_x = np.concatenate((pad, seq_x), axis=1) + + seq_x, seq_y = _torch(seq_x[:, -self.seq_len :], seq_y[:, : self.forecast_len]) + + return_output = { + "past_values": seq_x.T, + "future_values": seq_y.T, + } + + return return_output diff --git a/tsfm_public/toolkit/visualization.py b/tsfm_public/toolkit/visualization.py index 658a469e..93f6623c 100644 --- a/tsfm_public/toolkit/visualization.py +++ b/tsfm_public/toolkit/visualization.py @@ -309,7 +309,7 @@ def plot_predictions( if k in signature_keys: random_samples[k] = torch.stack([dset[i][k] for i in indices]).to(device=device) output = model(**random_samples) - predictions_subset = output.prediction_outputs[:, :, channel].squeeze().cpu().numpy() + predictions_subset = output.prediction_outputs[:, :, channel].cpu().numpy() prediction_length = predictions_subset.shape[1] using_pipeline = False plot_test_data = True From 3598090ba2dca854d9af8226e1b9e814b3471026 Mon Sep 17 00:00:00 2001 From: Arindam Jati Date: Wed, 4 Dec 2024 00:28:44 -0500 Subject: [PATCH 05/23] enable truncation of context len in ttm --- pyproject.toml | 5 +++-- .../models/tinytimemixer/modeling_tinytimemixer.py | 9 +++++++++ tsfm_public/toolkit/gluonts_data_wrapper.py | 14 ++------------ 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 60a7c490..d21fb799 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ packages = ["tsfm_public", "tsfmhfdemos"] [project.optional-dependencies] -all = ["tsfm_public[notebooks,testing,dev]"] +all = ["tsfm_public[notebooks,external,testing,dev]"] notebooks = [ "jupyter", @@ -42,7 +42,8 @@ notebooks = [ "kaleido", "tensorboard", ] -testing = ["pytest", "tsfm_public[notebooks]", "parameterized"] +external = ["tsfm_public[notebooks]", "gluonts"] +testing = ["pytest", "tsfm_public[external]", "parameterized"] dev = ["pre-commit", "tsfm_public[testing]", "ruff==0.4.4"] # ogv deployments will already have jupyter diff --git a/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py b/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py index 2cf77f35..7ec8651e 100644 --- a/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py +++ b/tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py @@ -1802,6 +1802,15 @@ def forward( Returns: """ + if past_values.dim() != 3: + raise ValueError( + "`past_values` must have 3 dimensions of shape `(batch_size, sequence_length, num_input_channels)`." 
+            )
+        if past_values.shape[1] > self.config.context_length:
+            past_values = past_values[:, -self.config.context_length :, :]
+        elif past_values.shape[1] < self.config.context_length:
+            raise ValueError("Context length in `past_values` is shorter than TTM context_length.")
+
         if self.loss == "mse":
             loss = nn.MSELoss(reduction="mean")
         elif self.loss == "mae":
diff --git a/tsfm_public/toolkit/gluonts_data_wrapper.py b/tsfm_public/toolkit/gluonts_data_wrapper.py
index e96043d1..40e58c1c 100644
--- a/tsfm_public/toolkit/gluonts_data_wrapper.py
+++ b/tsfm_public/toolkit/gluonts_data_wrapper.py
@@ -2,13 +2,14 @@
 from typing import Union
 
 import numpy as np
-import torch
 from gluonts.dataset.split import InputDataset, LabelDataset, TrainingDataset
 from gluonts.itertools import batcher
 from gluonts.transform.feature import LastValueImputation
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
+from tsfm_public.toolkit.dataset import _torch
+
 
 def impute_series(target):
@@ -23,17 +24,6 @@ def impute_series(target):
     return target
 
 
-def np_to_torch(np):
-    if np.dtype == "float" or np.dtype == "float32":
-        return torch.from_numpy(np).float()
-    elif np.dtype == "int":
-        return torch.from_numpy(np)
-
-
-def _torch(*nps):
-    return tuple(np_to_torch(x) for x in nps)
-
-
 class StandardScalingGluonTSDataset:

From 176e7d1c5f117b144bd8645bbeb05c628af7cdde Mon Sep 17 00:00:00 2001
From: Wesley Gifford <79663411+wgifford@users.noreply.github.com>
Date: Wed, 4 Dec 2024 12:55:27 -0500
Subject: [PATCH 06/23] fix issues with future exogenous

---
 .../toolkit/time_series_forecasting_pipeline.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/tsfm_public/toolkit/time_series_forecasting_pipeline.py b/tsfm_public/toolkit/time_series_forecasting_pipeline.py
index aafe1eb3..8c3431db 100644
--- a/tsfm_public/toolkit/time_series_forecasting_pipeline.py
+++ b/tsfm_public/toolkit/time_series_forecasting_pipeline.py
@@ -372,7 +372,7 @@ def preprocess(self, time_series, **kwargs) -> Dict[str, Union[GenericTensor, Li
                 # future data needs some values for targets, but they are unused
                 future_time_series[target_columns] = 0
                 future_time_series = self.feature_extractor.preprocess(future_time_series)
-                future_time_series.drop(columns=target_columns)
+                future_time_series = future_time_series.drop(columns=target_columns)
 
                 time_series = pd.concat((time_series, future_time_series), axis=0)
         else:
@@ -424,13 +424,16 @@ def postprocess(self, input, **kwargs):
         # name the predictions of target columns
         # outputs should only have size equal to target columns
 
+        # only allow adding ground truth when not exploding
+        add_known_ground_truth = kwargs["add_known_ground_truth"] if not kwargs["explode_forecasts"] else False
+
         prediction_columns = []
         for i, c in enumerate(kwargs["target_columns"]):
-            prediction_columns.append(f"{c}_prediction" if kwargs["add_known_ground_truth"] else c)
+            prediction_columns.append(f"{c}_prediction" if add_known_ground_truth else c)
             out[prediction_columns[-1]] = input[model_output_key][:, :, i].numpy().tolist()
 
         # provide the ground truth values for the targets
         # when future is unknown, we will have augmented the provided dataframe with NaN values to cover the future
-        if kwargs["add_known_ground_truth"]:
+        if add_known_ground_truth:
             for i, c in enumerate(kwargs["target_columns"]):
                 ground_truth = input["future_values"][:, :, i].numpy()
                 missing = ~input["future_observed_mask"][:, :, i].numpy()
@@ -478,7
+481,7 @@ def postprocess(self, input, **kwargs): # inverse scale if we have a feature extractor if self.feature_extractor is not None and kwargs["inverse_scale_outputs"]: out = self.feature_extractor.inverse_scale_targets(out) - if kwargs["add_known_ground_truth"]: + if add_known_ground_truth: out = self.feature_extractor.inverse_scale_targets(out, suffix="_prediction") return out From 3859359ebdd6fda0ca6d186c8c9d9221942cbb7d Mon Sep 17 00:00:00 2001 From: Arindam Jati Date: Wed, 4 Dec 2024 13:02:33 -0500 Subject: [PATCH 07/23] force_return in get_model --- .../gluonts/ttm_gluonts_predictor.py | 5 ++-- tsfm_public/toolkit/get_model.py | 25 ++++++++++++++++--- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/tsfm_public/models/tinytimemixer/gluonts/ttm_gluonts_predictor.py b/tsfm_public/models/tinytimemixer/gluonts/ttm_gluonts_predictor.py index 1f94a621..b5e68fe7 100644 --- a/tsfm_public/models/tinytimemixer/gluonts/ttm_gluonts_predictor.py +++ b/tsfm_public/models/tinytimemixer/gluonts/ttm_gluonts_predictor.py @@ -83,7 +83,6 @@ def __init__( upper_bound_fewshot_samples (bool, optional): If True, number of x% fewshot will be upper-bounded to FEWSHOT_MAX_NUM_SAMPLES. Defaults to True. """ - self.context_length = context_length self.prediction_length = prediction_length self.test_data_label = test_data_label self.scale = scale @@ -105,10 +104,12 @@ def __init__( # Call get_model() function to load TTM model automatically. self.ttm = get_model( model_path, - context_length=self.context_length, + context_length=context_length, prediction_length=min(self.prediction_length, TTM_MAX_FORECAST_HORIZON), + force_return=True, **kwargs, ).to(self.device) + self.context_length = self.ttm.config.context_length def _process_time_series(self, dataset: TrainingDataset) -> List: """ diff --git a/tsfm_public/toolkit/get_model.py b/tsfm_public/toolkit/get_model.py index 12e28179..bb76c145 100644 --- a/tsfm_public/toolkit/get_model.py +++ b/tsfm_public/toolkit/get_model.py @@ -48,6 +48,7 @@ def get_model( context_length: int = None, prediction_length: int = None, freq_prefix_tuning: bool = None, + force_return: bool = False, **kwargs, ): """ @@ -102,7 +103,14 @@ def get_model( elif prediction_length <= 720: selected_prediction_length = 720 else: - raise ValueError("Currently supported maximum prediction_length = 720") + if force_return: + selected_prediction_length = 720 + LOGGER.warning( + "The requested forecast horizon is greater than the maximum supported horizon (720).\n\ + Returning TTM model with horizon 720 since `force_return=True`." + ) + else: + raise ValueError("Currently supported maximum prediction_length = 720") LOGGER.info(f"Selected prediction_length = {selected_prediction_length}") @@ -121,10 +129,21 @@ def get_model( for cl in available_context_lens: if cl <= context_length: selected_context_length = cl + if cl < context_length: + LOGGER.warning( + f"Selecting TTM context length ({selected_context_length}) < Requested context length ({context_length} since exact match was not found.)" + ) break if selected_context_length is None: - raise ValueError(f"Requested context length is too short. Requested = {context_length}.\n\ - Available lengths for model_type = {model_path_type} are: {available_context_lens}.") + if force_return: + selected_context_length = available_context_lens[-1] + LOGGER.warning(f"Requested context length is too short. 
Requested = {context_length}.\n\ + Available lengths for model_type = {model_path_type} are: {available_context_lens}.\n\ + Returning the shortest context length model possible since `force_return=True`.") + else: + raise ValueError(f"Requested context length is too short. Requested = {context_length}.\n\ + Available lengths for model_type = {model_path_type} are: {available_context_lens}.\n\ + To return the shortest context length model possible, set `force_return=True`.") if freq_prefix_tuning is None: # Default model preference (freq / nofreq) From beba82532231b04d0db109a0562e4e80ca5d9e55 Mon Sep 17 00:00:00 2001 From: Arindam Jati Date: Thu, 5 Dec 2024 02:10:32 -0500 Subject: [PATCH 08/23] code moved to extras folder outside tsfm_public --- .../gluonts/data}/gluonts_data_wrapper.py | 0 .../gluonts/models/tinytimemixer}/__init__.py | 2 +- .../models/tinytimemixer}/ttm_gluonts_predictor.py | 10 +++++++--- {tsfm_public/toolkit => extras/utils}/ddp_utils.py | 0 4 files changed, 8 insertions(+), 4 deletions(-) rename {tsfm_public/toolkit => extras/gluonts/data}/gluonts_data_wrapper.py (100%) rename {tsfm_public/models/tinytimemixer/gluonts => extras/gluonts/models/tinytimemixer}/__init__.py (52%) rename {tsfm_public/models/tinytimemixer/gluonts => extras/gluonts/models/tinytimemixer}/ttm_gluonts_predictor.py (99%) rename {tsfm_public/toolkit => extras/utils}/ddp_utils.py (100%) diff --git a/tsfm_public/toolkit/gluonts_data_wrapper.py b/extras/gluonts/data/gluonts_data_wrapper.py similarity index 100% rename from tsfm_public/toolkit/gluonts_data_wrapper.py rename to extras/gluonts/data/gluonts_data_wrapper.py diff --git a/tsfm_public/models/tinytimemixer/gluonts/__init__.py b/extras/gluonts/models/tinytimemixer/__init__.py similarity index 52% rename from tsfm_public/models/tinytimemixer/gluonts/__init__.py rename to extras/gluonts/models/tinytimemixer/__init__.py index 67f0d162..65da48d0 100644 --- a/tsfm_public/models/tinytimemixer/gluonts/__init__.py +++ b/extras/gluonts/models/tinytimemixer/__init__.py @@ -1,4 +1,4 @@ -from tsfm_public.models.tinytimemixer.gluonts.ttm_gluonts_predictor import ( +from .ttm_gluonts_predictor import ( TTM_AVAILABLE_CONTEXTS, TTM_MAX_FORECAST_HORIZON, TTMGluonTSPredictor, diff --git a/tsfm_public/models/tinytimemixer/gluonts/ttm_gluonts_predictor.py b/extras/gluonts/models/tinytimemixer/ttm_gluonts_predictor.py similarity index 99% rename from tsfm_public/models/tinytimemixer/gluonts/ttm_gluonts_predictor.py rename to extras/gluonts/models/tinytimemixer/ttm_gluonts_predictor.py index b5e68fe7..dfd9040b 100644 --- a/tsfm_public/models/tinytimemixer/gluonts/ttm_gluonts_predictor.py +++ b/extras/gluonts/models/tinytimemixer/ttm_gluonts_predictor.py @@ -5,6 +5,7 @@ import copy import math import os +import sys import tempfile from typing import List @@ -27,14 +28,17 @@ count_parameters, ) from tsfm_public.toolkit.get_model import get_model -from tsfm_public.toolkit.gluonts_data_wrapper import ( +from tsfm_public.toolkit.lr_finder import optimal_lr_finder +from tsfm_public.toolkit.visualization import plot_predictions + + +sys.path.append(os.path.realpath("../../../")) +from extras.gluonts.data.gluonts_data_wrapper import ( StandardScalingGluonTSDataset, TorchDatasetFromGluonTSTestDataset, TorchDatasetFromGluonTSTrainingDataset, impute_series, ) -from tsfm_public.toolkit.lr_finder import optimal_lr_finder -from tsfm_public.toolkit.visualization import plot_predictions logger = logging.get_logger(__name__) diff --git a/tsfm_public/toolkit/ddp_utils.py 
b/extras/utils/ddp_utils.py similarity index 100% rename from tsfm_public/toolkit/ddp_utils.py rename to extras/utils/ddp_utils.py From c0b03eb508c64c08510e0b65cd76fe4e0646c5db Mon Sep 17 00:00:00 2001 From: Arindam Jati Date: Thu, 5 Dec 2024 03:06:12 -0500 Subject: [PATCH 09/23] tests moved --- .../gluonts/tests}/test_gluonts_data_wrapper.py | 7 ++++++- .../gluonts/tests}/test_gluonts_predictor.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) rename {tests/toolkit => extras/gluonts/tests}/test_gluonts_data_wrapper.py (97%) rename {tests/models/tinytimemixer => extras/gluonts/tests}/test_gluonts_predictor.py (92%) diff --git a/tests/toolkit/test_gluonts_data_wrapper.py b/extras/gluonts/tests/test_gluonts_data_wrapper.py similarity index 97% rename from tests/toolkit/test_gluonts_data_wrapper.py rename to extras/gluonts/tests/test_gluonts_data_wrapper.py index 8f74855f..0fe609d0 100644 --- a/tests/toolkit/test_gluonts_data_wrapper.py +++ b/extras/gluonts/tests/test_gluonts_data_wrapper.py @@ -3,13 +3,18 @@ """Tests get_model""" +import os +import sys + import numpy as np import pandas as pd import pytest from gluonts.dataset.common import ListDataset from gluonts.dataset.split import split -from tsfm_public.toolkit.gluonts_data_wrapper import ( + +sys.path.append(os.path.realpath("../../../")) +from extras.gluonts.data.gluonts_data_wrapper import ( StandardScalingGluonTSDataset, TorchDatasetFromGluonTSTestDataset, TorchDatasetFromGluonTSTrainingDataset, diff --git a/tests/models/tinytimemixer/test_gluonts_predictor.py b/extras/gluonts/tests/test_gluonts_predictor.py similarity index 92% rename from tests/models/tinytimemixer/test_gluonts_predictor.py rename to extras/gluonts/tests/test_gluonts_predictor.py index 6e8dbb56..1f9f4e58 100644 --- a/tests/models/tinytimemixer/test_gluonts_predictor.py +++ b/extras/gluonts/tests/test_gluonts_predictor.py @@ -3,12 +3,17 @@ """Tests get_model""" +import os +import sys + import numpy as np import pandas as pd import pytest from gluonts.dataset.common import ListDataset -from tsfm_public.models.tinytimemixer.gluonts import TTMGluonTSPredictor + +sys.path.append(os.path.realpath("../../")) +from extras.gluonts.models.tinytimemixer import TTMGluonTSPredictor @pytest.fixture(scope="module") From 3df4e68725140ef795fad23a665bc08816f6defb Mon Sep 17 00:00:00 2001 From: Arindam Jati Date: Thu, 5 Dec 2024 12:30:13 -0500 Subject: [PATCH 10/23] gift srcs removed, get_model updated --- extras/gluonts/data/gluonts_data_wrapper.py | 235 -------- .../gluonts/models/tinytimemixer/__init__.py | 5 - .../tinytimemixer/ttm_gluonts_predictor.py | 527 ------------------ .../tests/test_gluonts_data_wrapper.py | 129 ----- .../gluonts/tests/test_gluonts_predictor.py | 57 -- extras/utils/ddp_utils.py | 24 - tests/toolkit/test_get_model.py | 19 +- tsfm_public/toolkit/get_model.py | 46 +- 8 files changed, 37 insertions(+), 1005 deletions(-) delete mode 100644 extras/gluonts/data/gluonts_data_wrapper.py delete mode 100644 extras/gluonts/models/tinytimemixer/__init__.py delete mode 100644 extras/gluonts/models/tinytimemixer/ttm_gluonts_predictor.py delete mode 100644 extras/gluonts/tests/test_gluonts_data_wrapper.py delete mode 100644 extras/gluonts/tests/test_gluonts_predictor.py delete mode 100644 extras/utils/ddp_utils.py diff --git a/extras/gluonts/data/gluonts_data_wrapper.py b/extras/gluonts/data/gluonts_data_wrapper.py deleted file mode 100644 index 40e58c1c..00000000 --- a/extras/gluonts/data/gluonts_data_wrapper.py +++ /dev/null @@ -1,235 +0,0 @@ -import 
bisect -from typing import Union - -import numpy as np -from gluonts.dataset.split import InputDataset, LabelDataset, TrainingDataset -from gluonts.itertools import batcher -from gluonts.transform.feature import LastValueImputation -from torch.utils.data import Dataset -from tqdm import tqdm - -from tsfm_public.toolkit.dataset import _torch - - -def impute_series(target): - if np.isnan(target).any(): - target = target.copy() - if len(target.shape) == 2: - for i in range(target.shape[0]): - target[i, ...] = LastValueImputation()(target[i, ...]) - elif len(target.shape) == 1: - target = LastValueImputation()(target) - else: - raise Exception("Only 1D and 2D arrays are accepted by the impute_series() function.") - return target - - -class StandardScalingGluonTSDataset: - """ - TTM works best on standard scaled data, especially if fewshot - finetuning is being performed. - We can utilize the entire available context to do that. - This is a global sclaing operation done independently on - each channel. - """ - - def __init__(self) -> None: - self.mean = [] - self.std = [] - - def fit(self, train_data: Union[TrainingDataset, InputDataset]): - """Calculate the statistics on the historical train data. - - Args: - train_data (Union[TrainingDataset, InputDataset]): Iterator with - each series of shape [num_channels, seq_len] for multivariate - and [seq_len] for univariate. - """ - for batch in tqdm(batcher(train_data, batch_size=1)): - if batch[0]["target"].ndim == 1: - batch[0]["target"] = batch[0]["target"].reshape(1, -1) # [1, seq_len] - self.mean.append(np.mean(impute_series(batch[0]["target"]), axis=1).reshape(-1, 1)) - std = np.std(impute_series(batch[0]["target"]), axis=1).reshape(-1, 1) - for i in range(std.shape[0]): - if std[i] == 0: - std[i] = 1 - self.std.append(std) - - def transform(self, data: Union[TrainingDataset, InputDataset]): - """Apply scaler using calculated statistics. - - Args: - data (Union[TrainingDataset, InputDataset]): Iterator with - each series of shape [num_channels, seq_len] for multivariate - and [seq_len] for univariate. - - Returns: - Iternator: With each series transformed. - """ - assert len(self.mean) > 0 - assert len(self.std) > 0 - - out = list(data) - for i, _ in tqdm(enumerate(out)): - out[i]["target"] = (impute_series(out[i]["target"]) - self.mean[i]) / (self.std[i]) - return iter(out) - - def inverse_transform(self, data: np.ndarray) -> np.ndarray: - """Inverse transform, and bring data to original scale. - - Args: - data (np.ndarray): Forecast output of shape [batch, seq_len, num_channels] - - Raises: - Exception: If NaN is found in the forecast. - - Returns: - np.ndarray: Of shape [batch, seq_len, num_channels]. - """ - out = np.zeros(data.shape) - for i in tqdm(range((data.shape[0]))): - out[i, ...] = data[i, ...] * (self.std[i].T) + self.mean[i].T - if np.isnan(out[i, ...]).any(): - raise Exception("NaN found in forecast!") - return out - - -class TorchDatasetFromGluonTSTrainingDataset(Dataset): - def __init__( - self, - gluon_dataset: TrainingDataset, - seq_len: int, - forecast_len: int, - last_window_only=False, - ): - """Wrapper to create pytorch `Dataset` from GluonTS dataset. - - Args: - gluon_dataset (TrainingDataset): GluonTS dataset. - seq_len (int): Context length. - forecast_len (int): Forecast horizon. - last_window_only (bool, optional): If True, only last window will be processed. Defaults to False. 
- """ - # assert seq_len > forecast_len, f'sequence lenght {seq_len} has to be strictly greater than forecast length {forecast_len}' - self.seq_len = seq_len - self.forecast_len = forecast_len - self.X = list(gluon_dataset) - self.last_window_only = last_window_only - self.stride = 1 # TODO: support other strides - - # handle univariate series, and nans - for i, _ in enumerate(self.X): - if len(self.X[i]["target"].shape) == 1: - self.X[i]["target"] = self.X[i]["target"].reshape(1, -1) - - # Nan imputation - self.X[i]["target"] = impute_series(self.X[i]["target"]) - - # pad zeros if needed - if self.X[i]["target"].shape[1] < self.seq_len + self.forecast_len: - pad = np.zeros( - ( - self.X[i]["target"].shape[0], - self.seq_len + self.forecast_len - self.X[i]["target"].shape[1] + 1, - ) - ) - self.X[i]["target"] = np.concatenate((pad, self.X[i]["target"]), axis=1) - - # get shape - if not self.last_window_only: - self.cumulative_sizes = self.cumsum(self.X) - - def cumsum(self, list_data): - """ - list_data: list of numpy array of shape [channels x len] - """ - list_len, sum_ = [], 0 - for i, elm in enumerate(list_data): - data = elm["target"] - len_ = data.shape[1] - self.seq_len - self.forecast_len + 1 - list_len.append(len_ + sum_) - sum_ += len_ - return list_len - - def __len__(self): - if self.last_window_only: - return len(self.X) # = num of series - else: - return self.cumulative_sizes[-1] // self.stride - - def __getitem__(self, idx): - if self.last_window_only: - seq_x = self.X[idx]["target"][:, -(self.seq_len + self.forecast_len) : -self.forecast_len] - seq_y = self.X[idx]["target"][:, -(self.forecast_len) :] - else: - idx = idx * self.stride - if idx < 0: - if -idx > len(self): - raise ValueError("absolute value of index should not exceed dataset length") - idx = len(self) + idx - series_idx = bisect.bisect_right(self.cumulative_sizes, idx) - if series_idx == 0: - time_id = idx - else: - time_id = idx - self.cumulative_sizes[series_idx - 1] - seq_x = self.X[series_idx]["target"][:, time_id : time_id + self.seq_len] - seq_y = self.X[series_idx]["target"][ - :, time_id + self.seq_len : time_id + self.seq_len + self.forecast_len - ] - - # return torch.from_numpy(seq_x.astype(np.float)).float() - seq_x, seq_y = _torch(seq_x, seq_y) - - return_output = { - "past_values": seq_x.T, - "future_values": seq_y.T, - } - - return return_output - - -class TorchDatasetFromGluonTSTestDataset(Dataset): - def __init__( - self, - gluon_test_input: InputDataset, - gluon_test_label: LabelDataset, - seq_len: int, - forecast_len: int, - ): - """Wrapper to create pytorch `Dataset` from GluonTS dataset. - - Args: - gluon_dataset (TrainingDataset): GluonTS dataset. - seq_len (int): Context length. - forecast_len (int): Forecast horizon. - last_window_only (bool, optional): If True, only last window will be processed. Defaults to False. 
- """ - # assert seq_len > forecast_len, f'sequence lenght {seq_len} has to be strictly greater than forecast length {forecast_len}' - self.seq_len = seq_len - self.forecast_len = forecast_len - self.X = list(gluon_test_input) - self.Y = list(gluon_test_label) - - def __len__(self): - return len(self.Y) - - def __getitem__(self, idx): - seq_x = self.X[idx]["target"] - seq_y = self.Y[idx]["target"] - - if len(seq_x.shape) == 1: - seq_x = seq_x.reshape(1, -1) - seq_y = seq_y.reshape(1, -1) - - if seq_x.shape[1] < self.seq_len: - pad = np.zeros((seq_x.shape[0], self.seq_len - seq_x.shape[1])) - seq_x = np.concatenate((pad, seq_x), axis=1) - - seq_x, seq_y = _torch(seq_x[:, -self.seq_len :], seq_y[:, : self.forecast_len]) - - return_output = { - "past_values": seq_x.T, - "future_values": seq_y.T, - } - - return return_output diff --git a/extras/gluonts/models/tinytimemixer/__init__.py b/extras/gluonts/models/tinytimemixer/__init__.py deleted file mode 100644 index 65da48d0..00000000 --- a/extras/gluonts/models/tinytimemixer/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .ttm_gluonts_predictor import ( - TTM_AVAILABLE_CONTEXTS, - TTM_MAX_FORECAST_HORIZON, - TTMGluonTSPredictor, -) diff --git a/extras/gluonts/models/tinytimemixer/ttm_gluonts_predictor.py b/extras/gluonts/models/tinytimemixer/ttm_gluonts_predictor.py deleted file mode 100644 index dfd9040b..00000000 --- a/extras/gluonts/models/tinytimemixer/ttm_gluonts_predictor.py +++ /dev/null @@ -1,527 +0,0 @@ -# Copyright contributors to the TSFM project -# -"""Tools for building TTM Predictor that works with GluonTS datasets""" - -import copy -import math -import os -import sys -import tempfile -from typing import List - -import numpy as np -import pandas as pd -import torch -from gluonts.dataset.split import InputDataset, LabelDataset, TrainingDataset -from gluonts.itertools import batcher -from gluonts.model.forecast import SampleForecast -from torch.optim import AdamW -from torch.optim.lr_scheduler import OneCycleLR -from torch.utils.data import ConcatDataset, Subset -from tqdm.auto import tqdm -from transformers import EarlyStoppingCallback, Trainer, TrainingArguments -from transformers.integrations import INTEGRATION_TO_CALLBACK -from transformers.utils import logging - -from tsfm_public import ( - TrackingCallback, - count_parameters, -) -from tsfm_public.toolkit.get_model import get_model -from tsfm_public.toolkit.lr_finder import optimal_lr_finder -from tsfm_public.toolkit.visualization import plot_predictions - - -sys.path.append(os.path.realpath("../../../")) -from extras.gluonts.data.gluonts_data_wrapper import ( - StandardScalingGluonTSDataset, - TorchDatasetFromGluonTSTestDataset, - TorchDatasetFromGluonTSTrainingDataset, - impute_series, -) - - -logger = logging.get_logger(__name__) - -# TTM Constants -TTM_MAX_FORECAST_HORIZON = 720 -TTM_AVAILABLE_CONTEXTS = [1536, 1024, 512] - -# Fewshot max allowed number of samples -# For example, if 5% few-shot for a dataset exceeds this number, -# this `FEWSHOT_MAX_NUM_SAMPLES` upper bound will be used. 
-FEWSHOT_MAX_NUM_SAMPLES = 500_000 - - -class TTMGluonTSPredictor: - """Wrapper to TTM that can be directly trained, validated, and tested with GluonTS datasets.""" - - def __init__( - self, - context_length: int, - prediction_length: int, - model_path: str = "ibm-granite/granite-timeseries-ttm-r2", - test_data_label: LabelDataset = None, # provide this for plotting - scale: bool = False, - random_seed: int = 42, - term: str = None, - ds_name: str = None, - out_dir: str = None, - plot_test_forecast: bool = False, - upper_bound_fewshot_samples: bool = True, - **kwargs, - ): - """Initialize a TTMGluonTSPredictor object. - - Args: - context_length (int): Context length. - prediction_length (int): Forecast length. - model_path (str, optional): TTM Model path.. Defaults to "ibm-granite/granite-timeseries-ttm-r2". - test_data_label (LabelDataset, optional): Test data label object. Only used for plotting. Defaults to None. - scale (bool, optional): To scale the data or not. Defaults to False. (Recommended to set to `True` for fine-tuning workflow.) - random_seed (int, optional): Seed. Defaults to 42. - term (str, optional): Term (short/medium/long). Defaults to None. - ds_name (str, optional): Dataset name. Only used for plotting. Defaults to None. - out_dir (str, optional): Out directory. Defaults to None. - plot_test_forecast (bool, optional): Whether to plot forecasts. Defaults to False. - upper_bound_fewshot_samples (bool, optional): If True, number of x% fewshot will be upper-bounded - to FEWSHOT_MAX_NUM_SAMPLES. Defaults to True. - """ - self.prediction_length = prediction_length - self.test_data_label = test_data_label - self.scale = scale - self.scaler = None - self.random_seed = random_seed - self.term = term - self.ds_name = ds_name - self.out_dir = out_dir - self.plot_test_forecast = plot_test_forecast - self.upper_bound_fewshot_samples = upper_bound_fewshot_samples - - self.device = "cuda" if torch.cuda.is_available() else "cpu" - - if "dropout" in kwargs and kwargs["dropout"] is None: - del kwargs["dropout"] - if "head_dropout" in kwargs and kwargs["head_dropout"] is None: - del kwargs["head_dropout"] - - # Call get_model() function to load TTM model automatically. - self.ttm = get_model( - model_path, - context_length=context_length, - prediction_length=min(self.prediction_length, TTM_MAX_FORECAST_HORIZON), - force_return=True, - **kwargs, - ).to(self.device) - self.context_length = self.ttm.config.context_length - - def _process_time_series(self, dataset: TrainingDataset) -> List: - """ - Processes a time series by truncating initial NaNs and forward filling intermittent NaNs. - Returns a new truncated dataset, and does not modify the original one. - - Args: - dataset (TrainingDataset): Every series of of shape [channels, length]. - - Returns: - List: Processed time series, each of shape [channels, truncated_length]. 
- """ - truncated_dataset = list(copy.deepcopy(dataset)) - - for i, item in enumerate(truncated_dataset): - data = item["target"] - if len(data.shape) == 1: - data = data.reshape(1, -1) # [channels, length] - - # Step 1: Determine the longest stretch of initial NaNs across all channels - valid_mask = ~np.isnan(data) # Mask of valid (non-NaN) values - - if valid_mask.all(): - continue # Continue if no NaN - - first_valid = np.argmax(valid_mask.any(axis=0)) # First col with any valid value across channels - data = data[:, first_valid:] # Truncate cols before the first valid col - - # Step 2: Perform forward fill for NaNs - df = pd.DataFrame(data.T, columns=range(data.shape[0])) - df = df.ffill(axis=0) - - data = df.values.T - if data.shape[0] == 1: # [1, truncated_length] - data = data.reshape(-1) # [lentruncated_lengthgth] - - truncated_dataset[i]["target"] = data - - return truncated_dataset - - def train( - self, - train_dataset: TrainingDataset, - valid_dataset: TrainingDataset, - batch_size: int = 64, - freeze_backbone: bool = True, - learning_rate: float = None, - num_epochs: int = 30, - num_workers: int = 8, - fewshot_fraction: int = 1.0, - use_valid_from_train: bool = True, - save_model: bool = False, - ): - """Finetune the TTM. - - Args: - train_dataset (TrainingDataset): Training dataset. - valid_dataset (TrainingDataset): Validation dataset. - batch_size (int, optional): Batch size. Defaults to 64. - freeze_backbone (bool, optional): To freeze TTM backbone. Defaults to True. - learning_rate (float, optional): Learning rate. Defaults to None. - num_epochs (int, optional): Number of epochs. Defaults to 30. - num_workers (int, optional): Number of workers. Defaults to 8. - fewshot_fraction (int, optional): Few-shot fraction. Defaults to 1.0. - use_valid_from_train (bool, optional): Utilize unused train data for validation. Defaults to True. - save_model (bool, optional): Save model to `self.out_dir`. Defaults to False. - - Raises: - ValueError: _description_ - """ - train_dataset_scaled = self._process_time_series(train_dataset) - valid_dataset_scaled = self._process_time_series(valid_dataset) - - # Standard scale - if self.scale: - self.scaler = StandardScalingGluonTSDataset() - self.scaler.fit(train_dataset_scaled) - train_dataset_scaled = self.scaler.transform(train_dataset_scaled) - valid_dataset_scaled = self.scaler.transform(valid_dataset_scaled) - - temp_dir = tempfile.mkdtemp() - dset_train = TorchDatasetFromGluonTSTrainingDataset( - train_dataset_scaled, self.context_length, self.prediction_length - ) - - dset_valid_from_train = None - if fewshot_fraction < 1.0: - # Choose randomly - rng = np.random.default_rng(seed=self.random_seed) - if self.upper_bound_fewshot_samples: - list_size = min(int(fewshot_fraction * len(dset_train)), FEWSHOT_MAX_NUM_SAMPLES) - else: - list_size = int(fewshot_fraction * len(dset_train)) - - lst_fewshot_indx = rng.integers( - low=0, - high=len(dset_train), - size=list_size, - ) - - logger.info(f"Length of orginal train set = {len(dset_train)}") - org_dset_train = copy.deepcopy(dset_train) - dset_train = Subset(org_dset_train, lst_fewshot_indx) - logger.info(f"Length of {fewshot_fraction*100} % train set = {len(dset_train)}") - - if len(dset_train) < 1: - raise ValueError( - f"Data too small for finetuning in fewshot {fewshot_fraction*100}%. Resulting in 0 samples." 
- ) - - if use_valid_from_train: - all_indx = list(range(0, len(org_dset_train))) - valid_indx = list(set(all_indx) - set(lst_fewshot_indx)) - - # we don't use a huge validation set - valid_size = min(len(dset_train), len(org_dset_train) - len(dset_train)) - - valid_indx = np.random.choice(valid_indx, valid_size, replace=False) - dset_valid_from_train = Subset(org_dset_train, valid_indx) - - dset_valid = TorchDatasetFromGluonTSTrainingDataset( - valid_dataset_scaled, - self.context_length, - self.prediction_length, - last_window_only=True, - ) - - if dset_valid_from_train is not None: - dset_valid = ConcatDataset((dset_valid_from_train, dset_valid)) - - self.train_num_samples = len(dset_train) - self.valid_num_samples = len(dset_valid) - - if freeze_backbone: - print( - "Number of params before freezing backbone", - count_parameters(self.ttm), - ) - - # Freeze the backbone of the model - for param in self.ttm.backbone.parameters(): - param.requires_grad = False - - # Count params - print( - "Number of params after freezing the backbone", - count_parameters(self.ttm), - ) - - # Find optimal learning rate - # Use with caution: Set it manually if the suggested learning rate is not suitable - if learning_rate is None: - learning_rate, self.ttm = optimal_lr_finder( - self.ttm, - dset_train, - batch_size=batch_size, - ) - print("OPTIMAL SUGGESTED LEARNING RATE =", learning_rate) - - print(f"Using learning rate = {learning_rate}") - finetune_forecast_args = TrainingArguments( - output_dir=os.path.join(temp_dir, "output"), - overwrite_output_dir=True, - learning_rate=learning_rate, - num_train_epochs=num_epochs, - do_eval=True, - evaluation_strategy="epoch", - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - dataloader_num_workers=num_workers, - report_to="none", - save_strategy="epoch", - logging_strategy="epoch", - save_total_limit=1, - logging_dir=os.path.join(temp_dir, "logs"), # Make sure to specify a logging directory - load_best_model_at_end=True, # Load the best model when training ends - metric_for_best_model="eval_loss", # Metric to monitor for early stopping - greater_is_better=False, # For loss - seed=self.random_seed, - ) - - # Create the early stopping callback - early_stopping_callback = EarlyStoppingCallback( - early_stopping_patience=5, # Number of epochs with no improvement after which to stop - early_stopping_threshold=1e-5, # Minimum improvement required to consider as improvement - ) - tracking_callback = TrackingCallback() - - # Optimizer and scheduler - optimizer = AdamW(self.ttm.parameters(), lr=learning_rate) - scheduler = OneCycleLR( - optimizer, - learning_rate, - epochs=num_epochs, - steps_per_epoch=math.ceil(len(dset_train) / (batch_size)), - ) - - hf_trainer = Trainer( - model=self.ttm, - args=finetune_forecast_args, - train_dataset=dset_train, - eval_dataset=dset_valid, - callbacks=[early_stopping_callback, tracking_callback], - optimizers=(optimizer, scheduler), - ) - hf_trainer.remove_callback(INTEGRATION_TO_CALLBACK["codecarbon"]) - - # Fine tune - hf_trainer.train() - - # Save model - if save_model: - hf_trainer.save_model(os.path.join(self.out_dir, "ttm_model")) - - def validate( - self, - valid_dataset: TrainingDataset, - batch_size: int = 64, - ): - """(Optionally) Validate. - - Args: - valid_dataset (TrainingDataset): Validation dataset. - batch_size (int, optional): Batch size. Defaults to 64. - - Returns: - flat: Validation loss. 
- """ - valid_dataset_scaled = self._process_time_series(valid_dataset) - if self.scale: - if self.scaler is None: - self.scaler = StandardScalingGluonTSDataset() - self.scaler.fit(valid_dataset_scaled) - - valid_dataset_scaled = self.scaler.transform(valid_dataset_scaled) - else: - valid_dataset_scaled = valid_dataset - - temp_dir = tempfile.mkdtemp() - dset_valid = TorchDatasetFromGluonTSTrainingDataset( - valid_dataset_scaled, - self.context_length, - self.prediction_length, - last_window_only=True, - ) - - # hf_trainer - hf_trainer = Trainer( - model=self.ttm, - args=TrainingArguments( - output_dir=temp_dir, - per_device_eval_batch_size=batch_size, - seed=self.random_seed, - report_to="none", - eval_accumulation_steps=10, - ), - ) - - # evaluate = zero-shot performance - print("+" * 20, "Zero-shot Test Loss", "+" * 20) - zeroshot_output = hf_trainer.predict(dset_valid) - print(zeroshot_output) - return zeroshot_output["eval_loss"] - - def predict( - self, - test_data_input: InputDataset, - batch_size: int = 64, - ): - """Predict. - - Args: - test_data_input (InputDataset): Test input dataset. - batch_size (int, optional): Batch size. Defaults to 64. - - Returns: - float: Eval loss. - """ - # Standard scale - if self.scale: - # We do not truncate the initial NaNs during testing since it sometimes - # results in extremely short length, and inference fails. - # Hence, in the current implementation the initial NaNs will be converted - # to zeros. - # (not used currently) test_data_input_scaled = self._process_time_series(test_data_input) - - # A new Standard Scaler is defined - # Note: Issue with using the train scaler directly...number of series mismatch! - test_data_input_scaled = copy.deepcopy(test_data_input) - scaler = StandardScalingGluonTSDataset() - scaler.fit(test_data_input_scaled) - test_data_input_scaled = scaler.transform(test_data_input_scaled) - else: - test_data_input_scaled = test_data_input - - while True: - try: - # Generate forecast samples - forecast_samples = [] - for batch in tqdm(batcher(test_data_input_scaled, batch_size=batch_size)): - batch_ttm = {} - adjusted_batch_raw = [] - for idx, entry in enumerate(batch): - # univariate array of shape (time,) - # multivariate array of shape (var, time) - # TTM supports multivariate time series - if len(entry["target"].shape) == 1: - entry["target"] = entry["target"].reshape(1, -1) - - entry_context_length = entry["target"].shape[1] - num_channels = entry["target"].shape[0] - # Pad - if entry_context_length < self.ttm.config.context_length: - padding = torch.zeros( - ( - num_channels, - self.ttm.config.context_length - entry_context_length, - ) - ) - adjusted_entry = torch.cat((padding, torch.tensor(impute_series(entry["target"]))), dim=1) - # observed_mask[idx, :, :(ttm.config.context_length - entry_context_length)] = 0 - # Truncate - elif entry_context_length > self.ttm.config.context_length: - adjusted_entry = torch.tensor( - impute_series(entry["target"][:, -self.ttm.config.context_length :]) - ) - # Take full context - else: - adjusted_entry = torch.tensor(impute_series(entry["target"])) - adjusted_batch_raw.append(adjusted_entry) - - # For TTM channel dimension comes at the end - batch_ttm["past_values"] = torch.stack(adjusted_batch_raw).permute(0, 2, 1).to(self.device) - - if self.prediction_length > TTM_MAX_FORECAST_HORIZON: - recursive_steps = int(np.ceil(self.prediction_length / self.ttm.config.prediction_length)) - predict_outputs = torch.empty(len(batch), 0, num_channels).to(self.device) - with torch.no_grad(): 
- for i in range(recursive_steps): - model_outputs = self.ttm(**batch_ttm) - batch_ttm["past_values"] = torch.cat( - [ - batch_ttm["past_values"], - model_outputs["prediction_outputs"], - ], - dim=1, - )[:, -self.ttm.config.context_length :, :] - predict_outputs = torch.cat( - [ - predict_outputs, - model_outputs["prediction_outputs"][:, : self.ttm.config.prediction_length, :], - ], - dim=1, - ) - predict_outputs = predict_outputs[:, : self.prediction_length, :] - else: - model_outputs = self.ttm(**batch_ttm) - predict_outputs = model_outputs.prediction_outputs - - # Accumulate all forecasts - forecast_samples.append(predict_outputs.detach().cpu().numpy()) - - # list to np.ndarray - forecast_samples = np.concatenate(forecast_samples) - - if self.scale: - # inverse scale - forecast_samples = scaler.inverse_transform(forecast_samples) - - if forecast_samples.shape[2] == 1: - forecast_samples = np.squeeze(forecast_samples, axis=2) - break - except torch.cuda.OutOfMemoryError: - print(f"OutOfMemoryError at batch_size {batch_size}, reducing to {batch_size // 2}") - batch_size //= 2 - - # Convert forecast samples into gluonts SampleForecast objects - # Array of size (num_samples, prediction_length) (1D case) or - # (num_samples, prediction_length, target_dim) (multivariate case) - sample_forecasts = [] - for item, ts in zip(forecast_samples, test_data_input): - forecast_start_date = ts["start"] + len(ts["target"]) - sample_forecasts.append( - SampleForecast( - item_id=ts["item_id"], - samples=np.expand_dims(item, axis=0), - start_date=forecast_start_date, - ) - ) - - if self.out_dir is None: - self.out_dir = tempfile.mkdtemp() - - if self.plot_test_forecast and self.prediction_length <= TTM_MAX_FORECAST_HORIZON: - # Create torch dataset for plotting - torch_dset_test = TorchDatasetFromGluonTSTestDataset( - gluon_test_input=test_data_input, - gluon_test_label=self.test_data_label, - seq_len=self.ttm.config.context_length, - forecast_len=self.prediction_length, - ) - # Plot random samples - plot_predictions( - dset=torch_dset_test, - model=self.ttm, - plot_dir=f"{self.out_dir}/{self.ds_name}_{self.term}", - channel=0, - plot_context=int(0.5 * self.prediction_length), - ) - - return sample_forecasts diff --git a/extras/gluonts/tests/test_gluonts_data_wrapper.py b/extras/gluonts/tests/test_gluonts_data_wrapper.py deleted file mode 100644 index 0fe609d0..00000000 --- a/extras/gluonts/tests/test_gluonts_data_wrapper.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright contributors to the TSFM project -# - -"""Tests get_model""" - -import os -import sys - -import numpy as np -import pandas as pd -import pytest -from gluonts.dataset.common import ListDataset -from gluonts.dataset.split import split - - -sys.path.append(os.path.realpath("../../../")) -from extras.gluonts.data.gluonts_data_wrapper import ( - StandardScalingGluonTSDataset, - TorchDatasetFromGluonTSTestDataset, - TorchDatasetFromGluonTSTrainingDataset, -) - - -@pytest.fixture(scope="module") -def gluonts_data(): - # Step 1: Define the multivariate time series data - num_time_series = 3 # Number of time series - length = 50 # Length of each time series - num_variables = 2 # Number of variables (dimensions) per time series - - # Create random multivariate time series data - time_series_data = [ - { - "item_id": f"ts{i+1}", - "start": pd.Timestamp("2024-01-01"), # Start time for each series - "target": np.random.rand(num_variables, length), # 2D array: (num_variables, length) - } - for i in range(num_time_series) - ] - - # Step 2: Create the 
ListDataset - freq = "D" # Daily frequency - dataset = ListDataset( - time_series_data, - freq=freq, - one_dim_target=False, - ) - return dataset - - -def test_gluonts_standard_scaling(gluonts_data): - dataset = gluonts_data - - # Split the dataset into train and test - prediction_length = 10 - train_dataset, test_template = split(dataset, offset=-prediction_length) - - # Test shapes - for entry in train_dataset: - assert entry["target"].shape == (2, 40) - - test_dataset = test_template.generate_instances( - prediction_length=prediction_length, - ) - test_dataset_input = test_dataset.input - test_dataset_label = test_dataset.label - # Test shapes - for entry in test_dataset_input: - assert entry["target"].shape == (2, 40) - for entry in test_dataset_label: - assert entry["target"].shape == (2, 10) - - # Test scaler - scaler = StandardScalingGluonTSDataset() - scaler.fit(train_dataset) - train_dataset_scaled = scaler.transform(train_dataset) - test_dataset_scaled = scaler.transform(test_dataset_input) - - # Test scaling - for entry in train_dataset_scaled: - np.testing.assert_almost_equal(entry["target"].mean(axis=1), np.array([0.0, 0.0]), decimal=4) - np.testing.assert_almost_equal(entry["target"].std(axis=1), np.array([1.0, 1.0]), decimal=4) - - for entry in test_dataset_scaled: - np.testing.assert_almost_equal(entry["target"].mean(axis=1), np.array([0.0, 0.0]), decimal=4) - np.testing.assert_almost_equal(entry["target"].std(axis=1), np.array([1.0, 1.0]), decimal=4) - - # inverse - test_label_scaled = scaler.transform(test_dataset_label) - Y = [] - for entry in test_label_scaled: - Y.append(entry["target"].T) - Y = np.array(Y) - Y_inv = scaler.inverse_transform(Y) - - Y_org = [] - for entry in test_dataset_label: - Y_org.append(entry["target"].T) - Y_org = np.array(Y_org) - - np.testing.assert_almost_equal(Y_inv.mean(), Y_org.mean(), decimal=4) - - -def test_pytorch_data_wrappers(gluonts_data): - dataset = gluonts_data - - # Split the dataset into train and test - prediction_length = 10 - train_dataset, test_template = split(dataset, offset=-prediction_length) - test_dataset = test_template.generate_instances( - prediction_length=prediction_length, - ) - test_dataset_input = test_dataset.input - test_dataset_label = test_dataset.label - - torch_train_dset = TorchDatasetFromGluonTSTrainingDataset(train_dataset, seq_len=20, forecast_len=5) - assert torch_train_dset[1]["past_values"].shape == (20, 2) - assert torch_train_dset[1]["future_values"].shape == (5, 2) - - torch_train_dset = TorchDatasetFromGluonTSTrainingDataset(train_dataset, seq_len=35, forecast_len=5) - assert torch_train_dset[1]["past_values"].shape == (35, 2) - assert torch_train_dset[1]["future_values"].shape == (5, 2) - assert len(torch_train_dset) == 3 - - torch_test_dset = TorchDatasetFromGluonTSTestDataset( - gluon_test_input=test_dataset_input, gluon_test_label=test_dataset_label, seq_len=20, forecast_len=5 - ) - assert torch_test_dset[0]["past_values"].shape == (20, 2) - assert torch_test_dset[0]["future_values"].shape == (5, 2) diff --git a/extras/gluonts/tests/test_gluonts_predictor.py b/extras/gluonts/tests/test_gluonts_predictor.py deleted file mode 100644 index 1f9f4e58..00000000 --- a/extras/gluonts/tests/test_gluonts_predictor.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright contributors to the TSFM project -# - -"""Tests get_model""" - -import os -import sys - -import numpy as np -import pandas as pd -import pytest -from gluonts.dataset.common import ListDataset - - -sys.path.append(os.path.realpath("../../")) 
-from extras.gluonts.models.tinytimemixer import TTMGluonTSPredictor - - -@pytest.fixture(scope="module") -def gluonts_data_with_nan(): - # Step 1: Define the multivariate time series data - num_time_series = 3 # Number of time series - num_variables = 2 # Number of variables (dimensions) per time series - - # Create random multivariate time series data - time_series_data = [ - { - "item_id": f"ts{i+1}", - "start": pd.Timestamp("2024-01-01"), # Start time for each series - "target": np.concatenate( - ( - np.array([[np.nan, np.nan, np.nan, np.nan], [0, 1, np.nan, 2]]), - np.random.rand(num_variables, 600), - np.array([[np.nan, np.nan, np.nan, np.nan], [np.nan, 1, np.nan, 2]]), - np.random.rand(num_variables, 4), - ), - axis=1, - ), # 2D array: (num_variables, length) - } - for i in range(num_time_series) - ] - - # Step 2: Create the ListDataset - freq = "D" # Daily frequency - dataset = ListDataset( - time_series_data, - freq=freq, - one_dim_target=False, - ) - return dataset - - -def test_ttm_gluonts_predictor(gluonts_data_with_nan): - dataset = gluonts_data_with_nan - predictor = TTMGluonTSPredictor(context_length=512, prediction_length=96) - forecasts = predictor.predict(dataset) - assert forecasts[0].samples.shape == (1, 96, 2) diff --git a/extras/utils/ddp_utils.py b/extras/utils/ddp_utils.py deleted file mode 100644 index c414602f..00000000 --- a/extras/utils/ddp_utils.py +++ /dev/null @@ -1,24 +0,0 @@ -import os -from datetime import timedelta - -import torch - - -def init_ddp(timeout=600): - local_rank = int(os.environ.get("LOCAL_RANK")) - world_size = int(os.environ.get("WORLD_SIZE")) - rank = int(os.environ.get("RANK")) - - torch.cuda.set_device(local_rank) - torch.distributed.init_process_group( - "nccl", - init_method="env://", - world_size=world_size, - rank=rank, - timeout=timedelta(seconds=timeout), - ) - - -def is_rank_0(): - rank = torch.distributed.get_rank() - return rank == 0 diff --git a/tests/toolkit/test_get_model.py b/tests/toolkit/test_get_model.py index 9d377169..76eda986 100644 --- a/tests/toolkit/test_get_model.py +++ b/tests/toolkit/test_get_model.py @@ -68,7 +68,22 @@ def test_get_model(): assert model.config.context_length == cl mp = "ibm/ttm-research-r2" - for cl in range(512, 5000, 500): + for cl in range(1, 2000, 500): + for fl in range(1, 900, 90): + model = get_model(model_path=mp, context_length=cl, prediction_length=fl) + if model.config.prediction_filter_length is not None: + assert model.config.prediction_filter_length == fl + + mp = "ibm-granite/granite-timeseries-ttm-r2" + for cl in range(1, 2000, 500): + for fl in range(1, 900, 90): + model = get_model(model_path=mp, context_length=cl, prediction_length=fl) + if model.config.prediction_filter_length is not None: + assert model.config.prediction_filter_length == fl + + mp = "ibm-granite/granite-timeseries-ttm-r1" + for cl in range(512, 2000, 500): for fl in range(1, 720, 90): model = get_model(model_path=mp, context_length=cl, prediction_length=fl) - assert model.config.prediction_filter_length == fl + if model.config.prediction_filter_length is not None: + assert model.config.prediction_filter_length == fl diff --git a/tsfm_public/toolkit/get_model.py b/tsfm_public/toolkit/get_model.py index bb76c145..ace81075 100644 --- a/tsfm_public/toolkit/get_model.py +++ b/tsfm_public/toolkit/get_model.py @@ -48,7 +48,7 @@ def get_model( context_length: int = None, prediction_length: int = None, freq_prefix_tuning: bool = None, - force_return: bool = False, + force_return: bool = True, **kwargs, ): """ @@ 
-94,33 +94,27 @@ def get_model( with open(os.path.join(config_dir, "ttm.yaml"), "r") as file: model_revisions = yaml.safe_load(file) - if prediction_length <= 96: - selected_prediction_length = 96 - elif prediction_length <= 192: - selected_prediction_length = 192 - elif prediction_length <= 336: - selected_prediction_length = 336 - elif prediction_length <= 720: - selected_prediction_length = 720 - else: + max_supported_horizon = SUPPORTED_LENGTHS[model_path_type]["FL"][-1] + if prediction_length > max_supported_horizon: if force_return: - selected_prediction_length = 720 + selected_prediction_length = max_supported_horizon LOGGER.warning( - "The requested forecast horizon is greater than the maximum supported horizon (720).\n\ - Returning TTM model with horizon 720 since `force_return=True`." + f"The requested forecast horizon is greater than the maximum supported horizon ({max_supported_horizon}). Returning TTM model with horizon {max_supported_horizon} since `force_return=True`." ) else: - raise ValueError("Currently supported maximum prediction_length = 720") + raise ValueError(f"Currently supported maximum prediction_length = {max_supported_horizon}") + else: + for h in SUPPORTED_LENGTHS[model_path_type]["FL"]: + if prediction_length <= h: + selected_prediction_length = h + break - LOGGER.info(f"Selected prediction_length = {selected_prediction_length}") + LOGGER.info(f"Selected TTM `prediction_length` = {selected_prediction_length}") - if selected_prediction_length != prediction_length: + if selected_prediction_length > prediction_length: prediction_filter_length = prediction_length LOGGER.warning( - f"Requested `prediction_length` ({prediction_length}) is not exactly equal to any of the available TTM prediction lengths.\n\ - Hence, TTM will forecast using the `prediction_filter_length` argument to provide the requested prediction length.\n\ - Supported context lengths (CL) and forecast/prediction lengths (FL) for Model Card: {model_path} are\n\ - {SUPPORTED_LENGTHS[model_path_type]}" + f"Requested `prediction_length` ({prediction_length}) is not exactly equal to any of the available TTM prediction lengths. Hence, TTM will forecast using the `prediction_filter_length` argument to provide the requested prediction length. Supported context lengths (CL) and forecast/prediction lengths (FL) for Model Card: {model_path} are {SUPPORTED_LENGTHS[model_path_type]}" ) # Choose closest context length @@ -137,13 +131,13 @@ def get_model( if selected_context_length is None: if force_return: selected_context_length = available_context_lens[-1] - LOGGER.warning(f"Requested context length is too short. Requested = {context_length}.\n\ - Available lengths for model_type = {model_path_type} are: {available_context_lens}.\n\ - Returning the shortest context length model possible since `force_return=True`.") + LOGGER.warning( + f"Requested context length is too short. Requested = {context_length}. Available lengths for model_type = {model_path_type} are: {available_context_lens}. Returning the shortest context length model possible since `force_return=True`. Data needs to be handled properly, and it can affect the performance!" + ) else: - raise ValueError(f"Requested context length is too short. Requested = {context_length}.\n\ - Available lengths for model_type = {model_path_type} are: {available_context_lens}.\n\ - To return the shortest context length model possible, set `force_return=True`.") + raise ValueError( + f"Requested context length is too short. Requested = {context_length}. 
Available lengths for model_type = {model_path_type} are: {available_context_lens}. To return the shortest context length model possible, set `force_return=True`." + ) if freq_prefix_tuning is None: # Default model preference (freq / nofreq) From 69ed4fd3d22e2125c434a03fa210b305079d942e Mon Sep 17 00:00:00 2001 From: Arindam Jati Date: Thu, 5 Dec 2024 12:42:51 -0500 Subject: [PATCH 11/23] revert toml and visualization functions --- pyproject.toml | 5 ++--- tsfm_public/toolkit/visualization.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d21fb799..60a7c490 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ packages = ["tsfm_public", "tsfmhfdemos"] [project.optional-dependencies] -all = ["tsfm_public[notebooks,external,testing,dev]"] +all = ["tsfm_public[notebooks,testing,dev]"] notebooks = [ "jupyter", @@ -42,8 +42,7 @@ notebooks = [ "kaleido", "tensorboard", ] -external = ["tsfm_public[notebooks]", "gluonts"] -testing = ["pytest", "tsfm_public[external]", "parameterized"] +testing = ["pytest", "tsfm_public[notebooks]", "parameterized"] dev = ["pre-commit", "tsfm_public[testing]", "ruff==0.4.4"] # ogv deployments will already have jupyter diff --git a/tsfm_public/toolkit/visualization.py b/tsfm_public/toolkit/visualization.py index 93f6623c..658a469e 100644 --- a/tsfm_public/toolkit/visualization.py +++ b/tsfm_public/toolkit/visualization.py @@ -309,7 +309,7 @@ def plot_predictions( if k in signature_keys: random_samples[k] = torch.stack([dset[i][k] for i in indices]).to(device=device) output = model(**random_samples) - predictions_subset = output.prediction_outputs[:, :, channel].cpu().numpy() + predictions_subset = output.prediction_outputs[:, :, channel].squeeze().cpu().numpy() prediction_length = predictions_subset.shape[1] using_pipeline = False plot_test_data = True From 62ea27fb18e7a23ad98ef9cabef9e28e1fc825df Mon Sep 17 00:00:00 2001 From: Stuart Siegel <12914116+ssiegel95@users.noreply.github.com> Date: Thu, 5 Dec 2024 20:52:03 -0500 Subject: [PATCH 12/23] add optional verbose payload dumps --- services/inference/tests/test_inference_lib.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/services/inference/tests/test_inference_lib.py b/services/inference/tests/test_inference_lib.py index 8a61c55a..a93b38aa 100644 --- a/services/inference/tests/test_inference_lib.py +++ b/services/inference/tests/test_inference_lib.py @@ -2,6 +2,9 @@ # import copy +import json +import os +import tempfile from datetime import timedelta import numpy as np @@ -86,8 +89,17 @@ def _basic_result_checks(results: PredictOutput, df: pd.DataFrame): def test_forecast_with_good_data(ts_data_base: pd.DataFrame, forecasting_input_base: ForecastingInferenceInput): input = forecasting_input_base + model_id = input.model_id df = copy.deepcopy(ts_data_base) input.data = df.to_dict(orient="list") + + # useful for generating sample payload files + if int(os.environ.get("TSFM_TESTS_DO_VERBOSE_DUMPS", "0")) == 1: + with open(f"{tempfile.gettempdir()}/{model_id}.payload.json", "w") as out: + foo = copy.deepcopy(df) + foo["date"] = foo["date"].apply(lambda x: x.isoformat()) + json.dump(foo.to_dict(orient="list"), out) + runtime: InferenceRuntime = InferenceRuntime(config=config) po: PredictOutput = runtime.forecast(input=input) results = pd.DataFrame.from_dict(po.results[0]) From bfa6535687b13291b829fe7e2b85af9d5f0fd3ee Mon Sep 17 00:00:00 2001 From: Arindam Jati Date: Thu, 5 Dec 2024 22:02:10 -0500 Subject: [PATCH 13/23] 
exception -> valueerror --- tsfm_public/toolkit/get_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tsfm_public/toolkit/get_model.py b/tsfm_public/toolkit/get_model.py index ace81075..bfb9a71f 100644 --- a/tsfm_public/toolkit/get_model.py +++ b/tsfm_public/toolkit/get_model.py @@ -148,8 +148,8 @@ def get_model( else: freq_prefix = None else: - raise Exception( - "In current implementation, set freq_prefix_tuning to None for automatic model selection accordingly.." + raise ValueError( + "In the current implementation, set `freq_prefix_tuning` to None for automatic model selection accordingly." ) if freq_prefix_tuning: freq_prefix = "freq" From 8354f6a920c9d76b8ef59c3b51f6410bc1599abb Mon Sep 17 00:00:00 2001 From: Stuart Siegel <12914116+ssiegel95@users.noreply.github.com> Date: Fri, 6 Dec 2024 09:14:45 -0500 Subject: [PATCH 14/23] we can't resolve to a single directory here, need to scan them in load --- services/inference/tsfminference/__init__.py | 21 ++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/services/inference/tsfminference/__init__.py b/services/inference/tsfminference/__init__.py index 93c7af2c..6de6f50c 100644 --- a/services/inference/tsfminference/__init__.py +++ b/services/inference/tsfminference/__init__.py @@ -38,7 +38,20 @@ ) # use TSFM_MODEL_DIR preferentially. If not set, use HF_HOME or the system tempdir if that's not set. -TSFM_MODEL_DIR: Path = Path(os.environ.get("TSFM_MODEL_DIR", os.environ.get("HF_HOME", tempfile.gettempdir()))) - -if not TSFM_MODEL_DIR.exists(): - raise Exception(f"TSFM_MODEL_DIR {TSFM_MODEL_DIR} does not exist.") +TSFM_MODEL_DIR: str = os.environ.get("TSFM_MODEL_DIR", os.environ.get("HF_HOME", tempfile.gettempdir())) + +# basic checks +# make sure at least one of them is a valid directory +# make sure it's readable as well +_amodeldir_found = next( + ( + adir + for adir in (Path(p) for p in TSFM_MODEL_DIR.split(":")) + if adir.exists() and adir.is_dir() and os.access(adir, os.R_OK) + ), + None, +) +if not _amodeldir_found and not TSFM_ALLOW_LOAD_FROM_HF_HUB: + raise Exception( + f"None of the values given in TSFM_MODEL_DIR {TSFM_MODEL_DIR} are an existing and readable directory." + ) From 96e5bb446a2b280cf5bf2f804ed562a0254ac363 Mon Sep 17 00:00:00 2001 From: Stuart Siegel <12914116+ssiegel95@users.noreply.github.com> Date: Fri, 6 Dec 2024 09:15:24 -0500 Subject: [PATCH 15/23] add additional directory to TSFM_MODEL_DIR --- services/inference/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/services/inference/Makefile b/services/inference/Makefile index e8838bd5..c22abcc0 100644 --- a/services/inference/Makefile +++ b/services/inference/Makefile @@ -16,7 +16,7 @@ create_prometheus_metrics_dir: start_service_local: create_prometheus_metrics_dir boilerplate PROMETHEUS_MULTIPROC_DIR=./prometheus_metrics \ TSFM_PYTHON_LOGGING_LEVEL="ERROR" \ - TSFM_MODEL_DIR=./mytest-tsfm \ + TSFM_MODEL_DIR=./foobaz:./mytest-tsfm \ TSFM_ALLOW_LOAD_FROM_HF_HUB=1 \ python -m gunicorn \ -w 1 \ From 7d8d3be56275d44bc8c1fc018548338e3e32467c Mon Sep 17 00:00:00 2001 From: "Stuart A. 
Siegel" <12914116+ssiegel95@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:30:13 -0500 Subject: [PATCH 16/23] model path resolver --- services/inference/tsfminference/dirutil.py | 41 +++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 services/inference/tsfminference/dirutil.py diff --git a/services/inference/tsfminference/dirutil.py b/services/inference/tsfminference/dirutil.py new file mode 100644 index 00000000..2c746108 --- /dev/null +++ b/services/inference/tsfminference/dirutil.py @@ -0,0 +1,41 @@ +# THIS FILE IS COPIED FROM THE BOILERPLATE DIRECTORY, +# DO NOT EDIT IT OR YOU WILL LOSE YOUR CHANGES. +# MAKE CHANGES IN THE TOP LEVEL services/boilerplate DIRECTORY. +# +# +# +"""Utilities for directory operations.""" + +import os +from pathlib import Path + + +def resolve_model_path(search_path: str, model_id: str) -> Path: + """Find the first path under search_path for model_id. All entries in + search_path must be: + * an existing directory + * must be readable by the current process + + Args: + search_path (str): A unix-like ":" separated list of directories such a "dir1:dir2" + model_id (str): a model_id (which is really just a subdirectory under dir1 or dir2) + + Returns: + Path: the first matching path, None if no path is fount. + """ + + _amodeldir_found = next( + ( + adir + for adir in (Path(p) for p in search_path.split(":")) + if adir.exists() + and adir.is_dir() + and os.access(adir, os.R_OK) + and (adir / model_id).exists() + and os.access(adir / model_id, os.R_OK) + ), + None, + ) + if not _amodeldir_found: + return None + return _amodeldir_found / model_id From 82ab9871b32424e580b8497fb3368f5c8aa7964d Mon Sep 17 00:00:00 2001 From: "Stuart A. Siegel" <12914116+ssiegel95@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:32:05 -0500 Subject: [PATCH 17/23] ignore prometheus metrics dir --- services/inference/.gitignore | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/services/inference/.gitignore b/services/inference/.gitignore index e3e53154..922cc22c 100644 --- a/services/inference/.gitignore +++ b/services/inference/.gitignore @@ -1,3 +1 @@ -# These version placeholders will be replaced later during substitution. -__version__ = "0.0.0" -__version_tuple__ = (0, 0, 0) +prometheus_metrics From 1a1b75ad5b347b2da0b27aa96ef70728d613c035 Mon Sep 17 00:00:00 2001 From: "Stuart A. Siegel" <12914116+ssiegel95@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:32:23 -0500 Subject: [PATCH 18/23] model dir resolver --- services/boilerplate/dirutil.py | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 services/boilerplate/dirutil.py diff --git a/services/boilerplate/dirutil.py b/services/boilerplate/dirutil.py new file mode 100644 index 00000000..c2a68c4d --- /dev/null +++ b/services/boilerplate/dirutil.py @@ -0,0 +1,35 @@ +"""Utilities for directory operations.""" + +import os +from pathlib import Path + + +def resolve_model_path(search_path: str, model_id: str) -> Path: + """Find the first path under search_path for model_id. All entries in + search_path must be: + * an existing directory + * must be readable by the current process + + Args: + search_path (str): A unix-like ":" separated list of directories such a "dir1:dir2" + model_id (str): a model_id (which is really just a subdirectory under dir1 or dir2) + + Returns: + Path: the first matching path, None if no path is fount. 
+ """ + + _amodeldir_found = next( + ( + adir + for adir in (Path(p) for p in search_path.split(":")) + if adir.exists() + and adir.is_dir() + and os.access(adir, os.R_OK) + and (adir / model_id).exists() + and os.access(adir / model_id, os.R_OK) + ), + None, + ) + if not _amodeldir_found: + return None + return _amodeldir_found / model_id From 06eb16b7e405d60788c93653dac584807c34bd5b Mon Sep 17 00:00:00 2001 From: "Stuart A. Siegel" <12914116+ssiegel95@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:32:37 -0500 Subject: [PATCH 19/23] use model path resolver --- services/inference/tsfminference/inference.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/services/inference/tsfminference/inference.py b/services/inference/tsfminference/inference.py index 7a50650c..650e464e 100644 --- a/services/inference/tsfminference/inference.py +++ b/services/inference/tsfminference/inference.py @@ -17,6 +17,7 @@ from . import TSFM_ALLOW_LOAD_FROM_HF_HUB, TSFM_MODEL_DIR from .constants import API_VERSION from .dataframe_checks import check +from .dirutil import resolve_model_path from .errors import error_message from .inference_payloads import ForecastingInferenceInput, ForecastingMetadataInput, PredictOutput from .service_handler import ForecastingServiceHandler @@ -47,8 +48,8 @@ def add_routes(self, app): app.include_router(self.router) def _modelspec(self, model_id: str): - model_path = TSFM_MODEL_DIR / model_id - if not model_path.exists(): + model_path = resolve_model_path(TSFM_MODEL_DIR, model_id) + if not model_path: raise HTTPException(status_code=404, detail=f"model {model_id} not found.") handler, e = ForecastingServiceHandler.load(model_id=model_id, model_path=model_path) if handler.handler_config: @@ -84,16 +85,16 @@ def forecast(self, input: ForecastingInferenceInput): return answer def _forecast_common(self, input_payload: ForecastingInferenceInput) -> PredictOutput: - model_path = TSFM_MODEL_DIR / input_payload.model_id + model_path = resolve_model_path(TSFM_MODEL_DIR, input_payload.model_id) - if not model_path.is_dir(): + if not model_path: LOGGER.info(f"Could not find model at path: {model_path}") if TSFM_ALLOW_LOAD_FROM_HF_HUB: model_path = input_payload.model_id LOGGER.info(f"Using HuggingFace Hub: {model_path}") else: return None, RuntimeError( - f"Could not load model {input_payload.model_id} from {TSFM_MODEL_DIR.as_posix()}. If trying to load directly from the HuggingFace Hub please ensure that `TSFM_ALLOW_LOAD_FROM_HF_HUB=1`" + f"Could not load model {input_payload.model_id} from {TSFM_MODEL_DIR}. If trying to load directly from the HuggingFace Hub please ensure that `TSFM_ALLOW_LOAD_FROM_HF_HUB=1`" ) handler, e = ForecastingServiceHandler.load(model_id=input_payload.model_id, model_path=model_path) From eddab4e96a6a403724f919163b4d0b167dee6077 Mon Sep 17 00:00:00 2001 From: "Stuart A. 
Siegel" <12914116+ssiegel95@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:32:46 -0500 Subject: [PATCH 20/23] test model path resolver --- services/tests/test_dirutil.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 services/tests/test_dirutil.py diff --git a/services/tests/test_dirutil.py b/services/tests/test_dirutil.py new file mode 100644 index 00000000..a9d9414f --- /dev/null +++ b/services/tests/test_dirutil.py @@ -0,0 +1,17 @@ +import os +import tempfile +from pathlib import Path + +from tsfminference.dirutil import resolve_model_path + + +def test_resolve_model_path(): + with tempfile.TemporaryDirectory() as dir1: + with tempfile.TemporaryDirectory() as dir2: + dirpath = f"{dir1}:{dir2}" + os.mkdir(Path(dir1) / "amodel") + assert resolve_model_path(dirpath, "amodel") == Path(dir1) / "amodel" + assert resolve_model_path(dirpath, "foobar") is None + assert resolve_model_path("fzbatt:zap", "amodel") is None + os.mkdir(Path(dir2) / "anewmodel") + assert resolve_model_path(dirpath, "anewmodel") == Path(dir2) / "anewmodel" From f45c0b738ac614b3599eacc0abdefe3054cd39ce Mon Sep 17 00:00:00 2001 From: Stuart Siegel <12914116+ssiegel95@users.noreply.github.com> Date: Mon, 9 Dec 2024 14:54:06 -0500 Subject: [PATCH 21/23] boilerplate code --- services/inference/tsfminference/dirutil.py | 41 --------------------- 1 file changed, 41 deletions(-) delete mode 100644 services/inference/tsfminference/dirutil.py diff --git a/services/inference/tsfminference/dirutil.py b/services/inference/tsfminference/dirutil.py deleted file mode 100644 index 2c746108..00000000 --- a/services/inference/tsfminference/dirutil.py +++ /dev/null @@ -1,41 +0,0 @@ -# THIS FILE IS COPIED FROM THE BOILERPLATE DIRECTORY, -# DO NOT EDIT IT OR YOU WILL LOSE YOUR CHANGES. -# MAKE CHANGES IN THE TOP LEVEL services/boilerplate DIRECTORY. -# -# -# -"""Utilities for directory operations.""" - -import os -from pathlib import Path - - -def resolve_model_path(search_path: str, model_id: str) -> Path: - """Find the first path under search_path for model_id. All entries in - search_path must be: - * an existing directory - * must be readable by the current process - - Args: - search_path (str): A unix-like ":" separated list of directories such a "dir1:dir2" - model_id (str): a model_id (which is really just a subdirectory under dir1 or dir2) - - Returns: - Path: the first matching path, None if no path is fount. 
- """ - - _amodeldir_found = next( - ( - adir - for adir in (Path(p) for p in search_path.split(":")) - if adir.exists() - and adir.is_dir() - and os.access(adir, os.R_OK) - and (adir / model_id).exists() - and os.access(adir / model_id, os.R_OK) - ), - None, - ) - if not _amodeldir_found: - return None - return _amodeldir_found / model_id From 1e67a11d397e7aa7a683e741555e660f8bdccffe Mon Sep 17 00:00:00 2001 From: Stuart Siegel <12914116+ssiegel95@users.noreply.github.com> Date: Mon, 9 Dec 2024 14:54:47 -0500 Subject: [PATCH 22/23] ignore dirutil.py --- services/inference/tsfminference/.gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/services/inference/tsfminference/.gitignore b/services/inference/tsfminference/.gitignore index 15667315..fcaf21f6 100644 --- a/services/inference/tsfminference/.gitignore +++ b/services/inference/tsfminference/.gitignore @@ -2,3 +2,4 @@ inference_payloads.py errors.py hfutil.py dataframe_checks.py +dirutil.py From 35870fdbb36ec71be0082f32ee03b1b5350abdcb Mon Sep 17 00:00:00 2001 From: Stuart Siegel <12914116+ssiegel95@users.noreply.github.com> Date: Mon, 9 Dec 2024 15:05:37 -0500 Subject: [PATCH 23/23] automate maintenance of .gitignore --- services/finetuning/Makefile | 3 +++ services/finetuning/tsfmfinetuning/.gitignore | 7 +++++-- services/inference/Makefile | 3 +++ services/inference/tsfminference/.gitignore | 7 ++++--- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/services/finetuning/Makefile b/services/finetuning/Makefile index 0688d4b2..d0223145 100644 --- a/services/finetuning/Makefile +++ b/services/finetuning/Makefile @@ -2,10 +2,13 @@ CONTAINER_BUILDER ?= docker # copies boilerplate code to suitable locations boilerplate: + rm tsfmfinetuning/.gitignore || true + echo "# THIS FILE IS AUTOMATICALLY GENERATED, YOUR CHANGES WILL BE OVERWRITTEN" > tsfmfinetuning/.gitignore for f in ../boilerplate/*.py; do \ echo $$f; \ cat ../boilerplate/warning.txt > tsfmfinetuning/$$(basename $$f); \ cat $$f>>tsfmfinetuning/$$(basename $$f); \ + echo $$(basename $$f) >> tsfmfinetuning/.gitignore; \ done image: diff --git a/services/finetuning/tsfmfinetuning/.gitignore b/services/finetuning/tsfmfinetuning/.gitignore index a6f2b475..fc766639 100644 --- a/services/finetuning/tsfmfinetuning/.gitignore +++ b/services/finetuning/tsfmfinetuning/.gitignore @@ -1,3 +1,6 @@ -inference_payloads.py -hfutil.py +# THIS FILE IS AUTOMATICALLY GENERATED, YOUR CHANGES WILL BE OVERWRITTEN +dataframe_checks.py +dirutil.py errors.py +hfutil.py +inference_payloads.py diff --git a/services/inference/Makefile b/services/inference/Makefile index c22abcc0..cbc95df3 100644 --- a/services/inference/Makefile +++ b/services/inference/Makefile @@ -2,10 +2,13 @@ CONTAINER_BUILDER ?= docker # copies boilerplate code to suitable locations boilerplate: + rm tsfminference/.gitignore || true + echo "# THIS FILE IS AUTOMATICALLY GENERATED, YOUR CHANGES WILL BE OVERWRITTEN" > tsfminference/.gitignore for f in ../boilerplate/*.py; do \ echo $$f; \ cat ../boilerplate/warning.txt > tsfminference/$$(basename $$f); \ cat $$f>>tsfminference/$$(basename $$f); \ + echo $$(basename $$f) >> tsfminference/.gitignore; \ done create_prometheus_metrics_dir: diff --git a/services/inference/tsfminference/.gitignore b/services/inference/tsfminference/.gitignore index fcaf21f6..fc766639 100644 --- a/services/inference/tsfminference/.gitignore +++ b/services/inference/tsfminference/.gitignore @@ -1,5 +1,6 @@ -inference_payloads.py -errors.py -hfutil.py +# THIS FILE IS AUTOMATICALLY GENERATED, 
YOUR CHANGES WILL BE OVERWRITTEN dataframe_checks.py dirutil.py +errors.py +hfutil.py +inference_payloads.py
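
Note on the TTMGluonTSPredictor removed in PATCH 10: for horizons beyond TTM_MAX_FORECAST_HORIZON (720), its predict() rolled the model forward autoregressively, appending each forecast chunk to the context until the requested length was covered. A condensed sketch of that loop follows, assuming `ttm` is a loaded TTM model whose outputs expose "prediction_outputs" as in the removed code:

import math

import torch


def rolled_forecast(ttm, past_values, prediction_length):
    # past_values: [batch, context_length, num_channels], as assembled in predict()
    steps = math.ceil(prediction_length / ttm.config.prediction_length)
    chunks = []
    with torch.no_grad():
        for _ in range(steps):
            preds = ttm(past_values=past_values)["prediction_outputs"]
            chunks.append(preds)
            # slide the window: append the forecast, keep the latest context_length points
            past_values = torch.cat([past_values, preds], dim=1)[:, -ttm.config.context_length :, :]
    # trim the overshoot produced by the final step
    return torch.cat(chunks, dim=1)[:, :prediction_length, :]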
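
Note on the get_model() changes in PATCH 10: the requested prediction_length is now rounded up to the nearest available TTM head, capped at the card's maximum when force_return=True, and any surplus is trimmed via prediction_filter_length. A minimal sketch of that rule, reusing the 96/192/336/720 buckets from the replaced code as an illustrative stand-in for SUPPORTED_LENGTHS[model_path_type]["FL"]:

SUPPORTED_FL = [96, 192, 336, 720]  # illustrative; real values come from SUPPORTED_LENGTHS


def select_horizon(prediction_length, force_return=True):
    max_fl = SUPPORTED_FL[-1]
    if prediction_length > max_fl:
        if not force_return:
            raise ValueError(f"Currently supported maximum prediction_length = {max_fl}")
        selected = max_fl  # cap at the largest available head
    else:
        selected = next(h for h in SUPPORTED_FL if prediction_length <= h)  # round up
    # trim the surplus at inference time; None means the head already matches
    prediction_filter_length = prediction_length if selected > prediction_length else None
    return selected, prediction_filter_length


# e.g. select_horizon(100) -> (192, 100); select_horizon(900) -> (720, None)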
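
Note on PATCHES 14, 16, 18 and 19 taken together: model lookup now walks the ":"-separated TSFM_MODEL_DIR search path and only falls back to the HuggingFace Hub when TSFM_ALLOW_LOAD_FROM_HF_HUB permits it. A minimal sketch of the combined flow using resolve_model_path() from above; the environment handling is simplified and handler loading is omitted:

import os

from tsfminference.dirutil import resolve_model_path

TSFM_MODEL_DIR = os.environ.get("TSFM_MODEL_DIR", "./foobaz:./mytest-tsfm")
TSFM_ALLOW_LOAD_FROM_HF_HUB = int(os.environ.get("TSFM_ALLOW_LOAD_FROM_HF_HUB", "0")) == 1


def locate(model_id):
    model_path = resolve_model_path(TSFM_MODEL_DIR, model_id)
    if model_path:  # first readable <dir>/<model_id> on the search path wins
        return model_path
    if TSFM_ALLOW_LOAD_FROM_HF_HUB:
        return model_id  # treat the id as a HuggingFace Hub reference
    raise RuntimeError(
        f"Could not load model {model_id} from {TSFM_MODEL_DIR}. If trying to load directly "
        "from the HuggingFace Hub please ensure that `TSFM_ALLOW_LOAD_FROM_HF_HUB=1`"
    )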