This repository provides a PyTorch implementation of the Time Series Foundation Model (TimesFM), based on the original JAX model released by Google Research.
The purpose of this repository is to offer a PyTorch variant that can load checkpoints from the JAX version and be fine-tuned. Given the scope of my personal needs and available resources, this implementation includes only the components essential to running the model.
Dockerfiles are provided in timesfm_torch/docker for building a container that can run the JAX version of TimesFM. If you have trouble converting the JAX weights to PyTorch weights, try running the conversion inside this container. See timesfm_torch/docker/README.md for details.
Implemented:
- Loading JAX checkpoints into a PyTorch model with the same architecture.
- The core components that make up the TimesFM model.
- Output equivalence with the JAX version under specific conditions, up to numerical error.
Not implemented:
- Padding handling (inference assumes the input contains no padding).
- Support for variable context and horizon lengths (easy to add).
- Different frequency embeddings (also easy to add).
Difference from the JAX version:
- The mean and standard deviation used for normalization are computed across the entire time series rather than only the first patch.
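To make the normalization difference concrete, here is a minimal NumPy sketch of the two choices. It is not taken from either code base: the helper names, the `eps` guard, and `PATCH_LEN = 32` (assumed to be TimesFM's input patch length) are illustrative assumptions.

```python
import numpy as np

PATCH_LEN = 32  # assumption: TimesFM's input patch length


def normalize_whole_series(x, eps=1e-6):
    """Normalize with per-series statistics over the full context (this port's choice)."""
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + eps), mu, sigma


def normalize_first_patch(x, patch_len=PATCH_LEN, eps=1e-6):
    """Normalize with statistics taken from the first patch only."""
    first = x[..., :patch_len]
    mu = first.mean(axis=-1, keepdims=True)
    sigma = first.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + eps), mu, sigma


def denormalize(y, mu, sigma, eps=1e-6):
    """Map normalized values (e.g. model outputs) back to the original scale."""
    return y * (sigma + eps) + mu


# Example batch: 4 series of length 512, shape (batch_size, context_len).
x = np.random.default_rng(0).normal(loc=5.0, scale=3.0, size=(4, 512))
z_full, mu_full, sigma_full = normalize_whole_series(x)
z_first, mu_first, sigma_first = normalize_first_patch(x)
```

Whole-series statistics are less sensitive to an unrepresentative first patch, but they yield different normalized inputs than first-patch statistics, which can contribute to output differences between the two versions.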
Install the package using pip:
```bash
pip install -e .
```
Navigate to the utility directory and run the conversion script:
```bash
cd timesfm_torch/timesfm_torch/utils
python convert_ckpt.py
```
This generates PyTorch checkpoints in timesfm_torch/timesfm_torch/ckpt.
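At its core, such a conversion renames parameters and adapts array layouts. The sketch below illustrates only that idea: it assumes the JAX parameters have already been restored into a nested dict of NumPy arrays, and every parameter name in it is hypothetical (the real mapping lives in convert_ckpt.py). The one layout fact it relies on is that a `torch.nn.Linear` weight has shape `(out_features, in_features)`, whereas JAX dense kernels are commonly stored as `(in_features, out_features)`.

```python
import numpy as np


def flatten_params(tree, prefix=""):
    """Flatten a nested dict of arrays into {"a/b/c": array} form."""
    flat = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            flat.update(flatten_params(value, path))
        else:
            flat[path] = np.asarray(value)
    return flat


def convert(jax_params, name_map):
    """Build a PyTorch-style state dict (name -> NumPy array).

    name_map: {jax_path: (torch_name, transpose)}. Dense kernels stored as
    (in_features, out_features) need transpose=True to match the
    (out_features, in_features) layout of torch.nn.Linear.weight.
    """
    flat = flatten_params(jax_params)
    state = {}
    for jax_path, (torch_name, transpose) in name_map.items():
        arr = flat[jax_path]
        state[torch_name] = np.ascontiguousarray(arr.T) if transpose else arr
    return state


# Hypothetical example: a single dense layer with a kernel and a bias.
rng = np.random.default_rng(0)
jax_params = {"input_ff": {"linear": {"w": rng.normal(size=(64, 128)), "b": np.zeros(128)}}}
name_map = {
    "input_ff/linear/w": ("input_ff.weight", True),
    "input_ff/linear/b": ("input_ff.bias", False),
}
state = convert(jax_params, name_map)
```

Each resulting array could then be wrapped with `torch.from_numpy()` and loaded through the model's `load_state_dict()`.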
By default, the model loads the checkpoint during initialization; a checkpoint directory can also be passed explicitly to load_from_checkpoint(). The forward() method replicates the functionality of PatchedTimeSeriesDecoder.__call__() in the JAX version and keeps the same input and output shapes. Note that forward() does not handle padding and takes only the input time series.
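Because forward() does not handle padding, every row of the input must hold exactly `context_len` real observations. Below is a minimal sketch for building such a batch from series of varying length, assuming you want the most recent points and are willing to reject series that are too short. The helper `make_context_batch` is hypothetical and not part of this package.

```python
import numpy as np


def make_context_batch(series_list, context_len=512):
    """Stack the most recent `context_len` points of each series.

    forward() does not handle padding, so every row must contain exactly
    context_len real observations; shorter series are rejected.
    """
    rows = []
    for i, series in enumerate(series_list):
        arr = np.asarray(series, dtype=np.float32)
        if arr.ndim != 1:
            raise ValueError(f"series {i} must be 1-D, got shape {arr.shape}")
        if arr.shape[0] < context_len:
            raise ValueError(
                f"series {i} has {arr.shape[0]} points; need at least {context_len} "
                "(padding is not supported)"
            )
        rows.append(arr[-context_len:])  # keep the most recent points
    return np.stack(rows)  # shape: (batch_size, context_len)


batch = make_context_batch([np.arange(600), np.arange(512)], context_len=512)
print(batch.shape)  # (2, 512)
```

The resulting array can be converted with `torch.from_numpy(batch)` and moved to the model's device before calling the model.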
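```python
import torch
from timesfm_torch.model.timesfm import TimesFm

input_ts = torch.rand((32, 512)).to('cuda')  # Input shape: (batch_size, context_len)
timesfm = TimesFm(context_len=512)
timesfm.load_from_checkpoint(ckpt_dir="timesfm_torch/timesfm_torch/ckpt")
timesfm.to('cuda')  # keep the model on the same device as the input (TimesFm is used as a torch.nn.Module)
with torch.no_grad():  # inference only, no gradients needed
    output_ts = timesfm(input_ts)  # Output shape: (batch_size, patch_num, horizon_len, num_outputs)
```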
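forward() returns one prediction per input patch, so some post-processing is needed to obtain a forecast. The sketch below rests on two assumptions that this README does not confirm: the last patch's prediction covers the steps right after the context, and output channel 0 holds the point (mean) forecast. It also shows an outer loop that feeds predictions back in to reach horizons longer than `horizon_len`. `last_patch_forecast`, `rollout`, and the stand-in `naive_model` (with illustrative sizes) are hypothetical helpers, and NumPy arrays stand in for tensors so the sketch runs without a GPU.

```python
import numpy as np


def last_patch_forecast(output):
    """Point forecast for the steps after the context.

    output: (batch_size, patch_num, horizon_len, num_outputs).
    Assumes the last patch predicts the horizon right after the context
    and that channel 0 holds the point (mean) forecast.
    """
    return output[:, -1, :, 0]  # (batch_size, horizon_len)


def rollout(model_fn, context, horizon, context_len=512):
    """Extend the forecast beyond one horizon_len by feeding predictions back in.

    model_fn: hypothetical wrapper mapping a (batch_size, context_len) array
    to the 4-D output described above.
    """
    if horizon <= 0:
        raise ValueError("horizon must be positive")
    history = np.asarray(context, dtype=np.float32)
    if history.shape[1] < context_len:
        raise ValueError("context must hold at least context_len points (no padding support)")
    pieces, produced = [], 0
    while produced < horizon:
        # Feed only the most recent context_len points, as forward() expects.
        step = last_patch_forecast(model_fn(history[:, -context_len:]))
        pieces.append(step)
        history = np.concatenate([history, step], axis=1)
        produced += step.shape[1]
    return np.concatenate(pieces, axis=1)[:, :horizon]


def naive_model(x, patch_len=32, horizon_len=128, num_outputs=10):
    """Stand-in for the real model: repeats the last observed value everywhere."""
    b, c = x.shape
    last = x[:, -1][:, None, None, None]
    return np.broadcast_to(last, (b, c // patch_len, horizon_len, num_outputs)).copy()


ctx = np.random.default_rng(0).normal(size=(2, 512))
forecast = rollout(naive_model, ctx, horizon=300)
print(forecast.shape)  # (2, 300)
```

To drive the real model, `model_fn` could be a small wrapper that converts the array with `torch.from_numpy()`, calls the model under `torch.no_grad()`, and returns the result via `.cpu().numpy()`.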