Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/models/foundation/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
members:
- Sundial

::: timecopilot.models.foundation.t0
options:
members:
- T0

::: timecopilot.models.foundation.tabpfn
options:
members:
Expand Down
1 change: 1 addition & 0 deletions docs/model-hub.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ TimeCopilot provides a unified interface to state-of-the-art foundation models f
- [Moirai](api/models/foundation/models.md#timecopilot.models.foundation.moirai) ([arXiv:2402.02592](https://arxiv.org/abs/2402.02592))
- [PatchTST-FM](api/models/foundation/models.md#timecopilot.models.foundation.patchtst_fm) ([arXiv:2602.06909](https://arxiv.org/abs/2602.06909))
- [Sundial](api/models/foundation/models.md#timecopilot.models.foundation.sundial) ([arXiv:2502.00816](https://arxiv.org/abs/2502.00816))
- [T0](api/models/foundation/models.md#timecopilot.models.foundation.t0) ([model card](https://huggingface.co/theforecastingcompany/t0-alpha))
- [TabPFN](api/models/foundation/models.md#timecopilot.models.foundation.tabpfn) ([arXiv:2501.02945](https://arxiv.org/abs/2501.02945))
- [TiRex](api/models/foundation/models.md#timecopilot.models.foundation.tirex) ([arXiv:2505.23719](https://arxiv.org/abs/2505.23719))
- [TimeGPT](api/models/foundation/models.md#timecopilot.models.foundation.timegpt) ([arXiv:2310.03589](https://arxiv.org/abs/2310.03589))
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ dependencies = [
"statsforecast>=2.0.2",
"tabpfn-time-series==1.0.3 ; python_full_version < '3.13'",
"tensorboard>=2.20.0",
"tfc-t0>=0.1.2 ; python_full_version >= '3.11' and python_full_version < '3.14'",
"timecopilot-chronos-forecasting>=0.2.1",
"timecopilot-granite-tsfm>=0.1.2",
"timecopilot-timesfm>=0.2.1",
Expand Down
5 changes: 5 additions & 0 deletions tests/models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ def disable_mps_session(monkeypatch):

models.append(TiRex())

# Register T0 only on interpreters inside the window this guard checks
# (Python >= 3.11 and < 3.14).
if (3, 11) <= sys.version_info < (3, 14):
    from timecopilot.models.foundation.t0 import T0

    # NOTE(review): the small context_length/batch_size are presumably chosen
    # to keep this shared model list cheap to run — confirm.
    models.append(T0(context_length=256, batch_size=2))

if sys.version_info < (3, 13):
from tabpfn_time_series import TabPFNMode

Expand Down
10 changes: 10 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ def test_tirex_import_fails():
assert "requires Python >= 3.11" in str(excinfo.value)


@pytest.mark.skipif(
    sys.version_info >= (3, 11) and sys.version_info < (3, 14),
    reason="T0 requires Python >= 3.11 and < 3.14",
)
def test_t0_import_fails():
    """Importing T0 on an unsupported interpreter raises a descriptive error.

    The test is skipped on Python >= 3.11 and < 3.14; on any other version
    the import itself must fail with an `ImportError` whose message names
    the supported range.
    """
    expected_fragment = "requires Python >= 3.11 and < 3.14"
    with pytest.raises(ImportError) as excinfo:
        from timecopilot.models.foundation.t0 import T0  # noqa: F401
    error_text = str(excinfo.value)
    assert expected_fragment in error_text


@pytest.mark.skipif(
sys.version_info < (3, 13),
reason="Sundial requires Python < 3.13",
Expand Down
195 changes: 195 additions & 0 deletions timecopilot/models/foundation/t0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import sys
from collections.abc import Iterator
from contextlib import contextmanager

if sys.version_info < (3, 11) or sys.version_info >= (3, 14):
raise ImportError("T0 requires Python >= 3.11 and < 3.14")

import numpy as np
import pandas as pd
import torch
from t0 import T0Forecaster
from tqdm import tqdm

from ..utils.forecaster import Forecaster, QuantileConverter
from .utils import TimeSeriesDataset


class T0(Forecaster):
    """
    T0 is an open-weights time series foundation model from
    [The Forecasting Company](https://theforecastingcompany.com/). It is a
    decoder-style patch transformer that alternates time and covariate
    attention layers, producing probabilistic multi-horizon quantile
    forecasts. It decodes up to 1,024 timesteps in a single forward pass and
    falls back on autoregressive rollout for longer horizons. See the
    [model card](https://huggingface.co/theforecastingcompany/t0-alpha)
    for more details.
    """

    def __init__(
        self,
        repo_id: str = "theforecastingcompany/t0-alpha",
        context_length: int = 4096,
        batch_size: int = 16,
        alias: str = "T0",
    ):
        # ruff: noqa: E501
        """
        Args:
            repo_id (str, optional): The Hugging Face Hub model ID or local path to
                load the T0 model from. Defaults to "theforecastingcompany/t0-alpha".
                See the full list of models at
                [Hugging Face](https://huggingface.co/theforecastingcompany).
            context_length (int, optional): Maximum context length (input window
                size) for the model. Series longer than this are truncated to the
                most recent `context_length` observations. Defaults to 4096.
            batch_size (int, optional): Batch size to use for inference. Defaults
                to 16. Adjust based on available memory.
            alias (str, optional): Name to use for the model in output DataFrames
                and logs. Defaults to "T0".

        Notes:
            **Requirements:**

            - T0 requires Python 3.11 to 3.13 (via the
              [`tfc-t0`](https://pypi.org/project/tfc-t0/) package).

            **Available models:**

            | Model ID                                                                                          | Parameters |
            | ------------------------------------------------------------------------------------------------- | ---------- |
            | [`theforecastingcompany/t0-alpha`](https://huggingface.co/theforecastingcompany/t0-alpha)         | ~102M      |

            **Resources:**

            - HuggingFace: [theforecastingcompany/t0-alpha](https://huggingface.co/theforecastingcompany/t0-alpha)
            - Platform: [Retrocast](https://app.retrocast.com/)

            **Technical Details:**

            - The model is loaded onto the best available device (CUDA GPU if
              available, otherwise CPU).
            - T0 predicts 5 quantile knots (0.1, 0.25, 0.5, 0.75, 0.9); the
              median (0.5) is used as the point forecast and other requested
              quantiles are obtained by linear interpolation across the knots.
            - NaN values in the context are treated as missing observations.
            - T0 natively supports past and known-future covariates through its
              `predict` API; this integration currently exposes the univariate
              path only.
        """
        self.repo_id = repo_id
        self.context_length = context_length
        self.batch_size = batch_size
        self.alias = alias
        # The device is picked once at construction time.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    @contextmanager
    def _get_model(self) -> Iterator[T0Forecaster]:
        """Load the T0 model for the duration of a `with` block.

        The model is loaded from `self.repo_id`, moved to `self.device` and
        switched to eval mode. When the block exits (normally or through an
        exception) the local reference is dropped and the CUDA allocator
        cache is emptied.

        Yields:
            T0Forecaster: The loaded model, ready for inference.
        """
        model = T0Forecaster.from_pretrained(self.repo_id).to(self.device).eval()
        try:
            yield model
        finally:
            del model
            torch.cuda.empty_cache()

    def _to_context(self, batch: list[torch.Tensor]) -> torch.Tensor:
        """Left-pad a ragged batch with NaN (treated as missing by T0).

        Every series is truncated to its most recent observations so that no
        row is longer than `self.context_length`, then right-aligned in a
        float32 tensor whose unused leading positions are NaN.

        Args:
            batch (list[torch.Tensor]): One 1-D tensor per series; lengths
                may differ.

        Returns:
            torch.Tensor: Tensor of shape `(len(batch), max_len)` where
                `max_len` is the longest series length in the batch, capped
                at `self.context_length`.
        """
        max_len = min(
            max(len(ts) for ts in batch),
            self.context_length,
        )
        context = torch.full(
            (len(batch), max_len),
            float("nan"),
            dtype=torch.float32,
        )
        for idx, ts in enumerate(batch):
            keep = min(len(ts), max_len)
            if keep == 0:
                # Nothing to copy: the row stays all-NaN. A negative-length
                # slice (`-0:`) would select the whole row and fail here.
                continue
            context[idx, max_len - keep :] = ts[len(ts) - keep :].to(
                dtype=torch.float32
            )
        return context

    def forecast(
        self,
        df: pd.DataFrame,
        h: int,
        freq: str | None = None,
        level: list[int | float] | None = None,
        quantiles: list[float] | None = None,
    ) -> pd.DataFrame:
        """Generate forecasts for time series data using the model.

        This method produces point forecasts and, optionally, prediction
        intervals or quantile forecasts. The input DataFrame can contain one
        or multiple time series in stacked (long) format.

        Args:
            df (pd.DataFrame):
                DataFrame containing the time series to forecast. It must
                include as columns:

                    - "unique_id": an ID column to distinguish multiple series.
                    - "ds": a time column indicating timestamps or periods.
                    - "y": a target column with the observed values.

            h (int):
                Forecast horizon specifying how many future steps to predict.
            freq (str, optional):
                Frequency of the time series (e.g. "D" for daily, "M" for
                monthly). See [Pandas frequency aliases](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)
                for valid values. If not provided, the frequency will be
                inferred from the data.
            level (list[int | float], optional):
                Confidence levels for prediction intervals, expressed as
                percentages (e.g. [80, 95]). If provided, the returned
                DataFrame will include lower and upper interval columns for
                each specified level.
            quantiles (list[float], optional):
                List of quantiles to forecast, expressed as floats between 0
                and 1. Should not be used simultaneously with `level`. When
                provided, the output DataFrame will contain additional columns
                named in the format "model-q-{percentile}", where {percentile}
                = 100 × quantile value. Quantiles the model wasn't trained on
                are linearly interpolated across its fixed knots.

        Returns:
            pd.DataFrame:
                DataFrame containing forecast results. Includes:

                    - point forecasts for each timestamp and series.
                    - prediction intervals if `level` is specified.
                    - quantile forecasts if `quantiles` is specified.

                For multi-series data, the output retains the same unique
                identifiers as the input DataFrame.
        """
        freq = self._maybe_infer_freq(df, freq)
        qc = QuantileConverter(level=level, quantiles=quantiles)
        dataset = TimeSeriesDataset.from_df(df, batch_size=self.batch_size)
        fcst_df = dataset.make_future_dataframe(h=h, freq=freq)
        # T0 interpolates arbitrary quantile levels from its trained knots,
        # so the median and any user-requested quantiles come from one pass.
        pred_quantiles = sorted(set(qc.quantiles or []) | {0.5})
        # Position of each requested level along the last output axis,
        # computed once instead of searching the list for every column.
        quantile_idx = {q: i for i, q in enumerate(pred_quantiles)}
        median_idx = quantile_idx[0.5]
        fcsts: list[np.ndarray] = []
        with self._get_model() as model:
            for batch in tqdm(dataset):
                # NOTE(review): the context tensor is built on CPU; presumably
                # `predict` moves it to the model's device — confirm.
                out = model.predict(
                    self._to_context(batch),
                    horizon=h,
                    quantiles=pred_quantiles,
                )
                # shape: (batch, h, n_quantiles)
                fcsts.append(out.quantiles.cpu().numpy())
        fcsts_np = np.concatenate(fcsts, axis=0)
        # The median doubles as the point forecast.
        fcst_df[self.alias] = fcsts_np[..., median_idx].reshape(-1, 1)
        if qc.quantiles is not None:
            for q in qc.quantiles:
                fcst_df[f"{self.alias}-q-{int(q * 100)}"] = fcsts_np[
                    ..., quantile_idx[q]
                ].reshape(-1, 1)
            fcst_df = qc.maybe_convert_quantiles_to_level(
                fcst_df,
                models=[self.alias],
            )
        return fcst_df