johnbridges commited on
Commit
67ec8f1
·
1 Parent(s): 628cf4f
Files changed (1) hide show
  1. timesfs_backend.py +223 -0
timesfs_backend.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # timesfm_backend.py
2
+ import time, logging
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+
7
+ try:
8
+ # If you install an official TimesFM package later, we’ll try to use it.
9
+ # (e.g., `pip install timesfm` if/when available)
10
+ import timesfm as tsm # type: ignore
11
+ except Exception:
12
+ tsm = None # graceful fallback
13
+
14
+ try:
15
+ # Optional: pull weights from HF if you want local inference
16
+ # pip install huggingface_hub
17
+ from huggingface_hub import snapshot_download
18
+ except Exception:
19
+ snapshot_download = None # optional
20
+
21
+ from backends_base import ImagesBackend # to mirror structure; not used here
22
+ from config import settings
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # --------------------------------------------------------------------------------------
27
+ # Config
28
+ # --------------------------------------------------------------------------------------
29
+ MODEL_ID = getattr(settings, "LlmHFModelID", None) or "google/timesfm-2.5-200m-pytorch"
30
+ DEFAULT_HORIZON = 24 # sensible default if caller omits
31
+ DEFAULT_FREQ = "H" # hour
32
+ ALLOW_GPU = True
33
+
34
+ # --------------------------------------------------------------------------------------
35
+ # Helpers
36
+ # --------------------------------------------------------------------------------------
37
+ def _pick_device() -> str:
38
+ if ALLOW_GPU and torch.cuda.is_available():
39
+ return "cuda"
40
+ return "cpu"
41
+
42
+ def _pick_dtype(device: str) -> torch.dtype:
43
+ # FP16 on CUDA, FP32 on CPU by default (safe and simple)
44
+ if device != "cpu":
45
+ return torch.float16
46
+ return torch.float32
47
+
48
+ def _as_1d_float_tensor(series: List[float], device: str, dtype: torch.dtype) -> torch.Tensor:
49
+ t = torch.tensor(series, dtype=torch.float32) # keep input parse stable
50
+ return t.to(device=device, dtype=dtype)
51
+
52
+ # --------------------------------------------------------------------------------------
53
+ # Fallback forecaster (naive)
54
+ # --------------------------------------------------------------------------------------
55
+ def _naive_forecast(x: torch.Tensor, horizon: int) -> torch.Tensor:
56
+ """
57
+ Very simple fallback: repeat the last observed value for H steps.
58
+ Ensures the backend returns a forecast even without TimesFM installed.
59
+ """
60
+ last = x[-1] if x.numel() > 0 else torch.tensor(0.0, device=x.device, dtype=x.dtype)
61
+ return last.repeat(horizon).to(dtype=x.dtype, device=x.device)
62
+
63
+ # --------------------------------------------------------------------------------------
64
+ # Backend
65
+ # --------------------------------------------------------------------------------------
66
+ class TimesFMBackend:
67
+ """
68
+ Minimal forecasting backend. Input request (dict) shape:
69
+
70
+ {
71
+ "series": [float, ...], # required
72
+ "horizon": 48, # optional (default 24)
73
+ "freq": "H", # optional (default "H")
74
+ "normalize": true, # optional
75
+ "model_id": "google/...", # optional override
76
+ "use_gpu": true/false # optional
77
+ }
78
+
79
+ Output (dict):
80
+ {
81
+ "id": "tsfcst-...",
82
+ "object": "timeseries.forecast",
83
+ "created": 1234567890,
84
+ "model": "<model_id>",
85
+ "horizon": H,
86
+ "freq": "H",
87
+ "forecast": [float, ...],
88
+ "backend": "timesfm",
89
+ "note": "fallback-naive" # only when naive path used
90
+ }
91
+ """
92
+
93
+ def __init__(self) -> None:
94
+ self._model = None
95
+ self._model_id = MODEL_ID
96
+ self._device = _pick_device()
97
+ self._dtype = _pick_dtype(self._device)
98
+ logger.info(f"[timesfm] init: model_id={self._model_id} device={self._device} dtype={self._dtype}")
99
+
100
+ # ---------- model load (best-effort) ----------
101
+ def _ensure_model(self, model_id: Optional[str] = None) -> None:
102
+ if self._model is not None and (not model_id or model_id == self._model_id):
103
+ return
104
+
105
+ want_id = model_id or self._model_id
106
+ self._model_id = want_id
107
+
108
+ if tsm is None:
109
+ logger.warning("[timesfm] timesfm package not available; using naive fallback")
110
+ self._model = None
111
+ return
112
+
113
+ # If the library provides a from_pretrained, use it; else attempt HF snapshot and custom load.
114
+ model = None
115
+ try:
116
+ if hasattr(tsm, "TimesFM") and hasattr(tsm.TimesFM, "from_pretrained"):
117
+ logger.info(f"[timesfm] loading via TimesFM.from_pretrained('{want_id}')")
118
+ model = tsm.TimesFM.from_pretrained(want_id) # type: ignore[attr-defined]
119
+ else:
120
+ # Manual path: download and let user wire loading code for their saved format
121
+ if snapshot_download is None:
122
+ raise RuntimeError("huggingface_hub not installed; cannot pull weights")
123
+ logger.info(f"[timesfm] snapshot_download('{want_id}')")
124
+ local_dir = snapshot_download(repo_id=want_id)
125
+ # TODO: Replace with actual load for the repo format if needed.
126
+ # Placeholder: try to import a generic torch file if present.
127
+ logger.warning(f"[timesfm] no direct loader available; using naive fallback. weights at {local_dir}")
128
+ model = None
129
+ except Exception as e:
130
+ logger.warning(f"[timesfm] failed to load model '{want_id}': {e}. Falling back to naive.")
131
+ model = None
132
+
133
+ self._model = model
134
+ if model is not None:
135
+ try:
136
+ self._model.to(self._device) # type: ignore[operator]
137
+ except Exception:
138
+ pass
139
+ logger.info("[timesfm] model ready on %s", self._device)
140
+
141
+ # ---------- public API ----------
142
+ async def forecast(self, request: Dict[str, Any]) -> Dict[str, Any]:
143
+ """
144
+ Async to match your other backends. Returns a single, non-streaming result dict.
145
+ """
146
+ # parse inputs
147
+ model_id = request.get("model") or request.get("model_id") or self._model_id
148
+ series = request.get("series")
149
+ horizon = int(request.get("horizon") or DEFAULT_HORIZON)
150
+ freq = request.get("freq") or DEFAULT_FREQ
151
+ normalize = bool(request.get("normalize") or False)
152
+ use_gpu = request.get("use_gpu")
153
+ if use_gpu is not None:
154
+ self._device = "cuda" if (use_gpu and torch.cuda.is_available()) else "cpu"
155
+ self._dtype = _pick_dtype(self._device)
156
+
157
+ if not isinstance(series, (list, tuple)) or not all(isinstance(v, (int, float)) for v in series):
158
+ raise ValueError("request['series'] must be a list of numbers")
159
+
160
+ # ensure model (or fallback)
161
+ self._ensure_model(model_id)
162
+
163
+ # tensorize
164
+ x = _as_1d_float_tensor(list(series), self._device, self._dtype)
165
+
166
+ # optional normalization (z-score)
167
+ mu: Optional[torch.Tensor] = None
168
+ sigma: Optional[torch.Tensor] = None
169
+ if normalize and x.numel() > 1:
170
+ mu = x.mean()
171
+ sigma = x.std(unbiased=False).clamp_min(1e-6)
172
+ x_norm = (x - mu) / sigma
173
+ else:
174
+ x_norm = x
175
+
176
+ # run forecast
177
+ note = None
178
+ if self._model is None:
179
+ y_hat = _naive_forecast(x_norm, horizon)
180
+ note = "fallback-naive"
181
+ else:
182
+ try:
183
+ # Preferred path if the library supports it:
184
+ if hasattr(self._model, "forecast"):
185
+ y_hat = self._model.forecast(x_norm.unsqueeze(0), horizon=horizon) # type: ignore[attr-defined]
186
+ # Shape handling: [B, H] -> 1D
187
+ if isinstance(y_hat, (list, tuple)):
188
+ y_hat = torch.tensor(y_hat, device=x_norm.device, dtype=x_norm.dtype)
189
+ if isinstance(y_hat, torch.Tensor) and y_hat.dim() == 2:
190
+ y_hat = y_hat[0]
191
+ elif not isinstance(y_hat, torch.Tensor):
192
+ y_hat = torch.tensor(y_hat, device=x_norm.device, dtype=x_norm.dtype)
193
+ else:
194
+ # If no forecast method, fallback
195
+ y_hat = _naive_forecast(x_norm, horizon)
196
+ note = "fallback-naive"
197
+ except Exception as e:
198
+ logger.warning(f"[timesfm] forecast failed on model path: {e}. Using naive fallback.")
199
+ y_hat = _naive_forecast(x_norm, horizon)
200
+ note = "fallback-naive"
201
+
202
+ # denormalize
203
+ if normalize and mu is not None and sigma is not None:
204
+ y_hat = y_hat * sigma + mu
205
+
206
+ # move to cpu list
207
+ forecast = y_hat.detach().float().cpu().tolist()
208
+
209
+ rid = f"tsfcst-{int(time.time())}"
210
+ now = int(time.time())
211
+ resp = {
212
+ "id": rid,
213
+ "object": "timeseries.forecast",
214
+ "created": now,
215
+ "model": self._model_id,
216
+ "horizon": horizon,
217
+ "freq": freq,
218
+ "forecast": forecast,
219
+ "backend": "timesfm",
220
+ }
221
+ if note:
222
+ resp["note"] = note
223
+ return resp