Update app.py
app.py CHANGED
@@ -21,22 +21,26 @@ def _round_df(df: pd.DataFrame, places: int = 2) -> pd.DataFrame:
     return out
 
 # ---------- Tool 1: Forecast ----------
+from smolagents import tool
+import json, pandas as pd, numpy as np
+
 @tool
-def forecast_tool(horizon_months: int = 1, use_demo: bool = True, history_csv_path: str = "") -> str:
+def forecast_tool(horizon_months: int = 1, use_demo: bool = True, history_csv_path: str = "",
+                  use_covariates: bool = False) -> str:
     """
-    Forecast monthly demand
+    Forecast monthly demand using a GLOBAL N-HiTS model (fast & accurate).
 
     Args:
-        horizon_months (int): Number of future months to forecast (>=1).
-        use_demo (bool): If True,
-        history_csv_path (str): Optional CSV
+        horizon_months (int): Number of future months to forecast (>=1).
+        use_demo (bool): If True, generates synthetic history for FG100/FG200.
+        history_csv_path (str): Optional CSV with columns [product_id,date,qty,(optional extra covariates...)].
+        use_covariates (bool): If True and extra numeric columns exist, use them as past covariates
+                               (for future effects you must provide future values too).
 
     Returns:
-        str: JSON
+        str: JSON list of {"product_id","period_start","forecast_qty"} for the next horizon_months.
     """
-
-
-    # 1) History
+    # --- 1) Load data in long form ---
     if use_demo or not history_csv_path:
         rng = pd.date_range("2023-01-01", periods=24, freq="MS")
         rows = []
@@ -49,22 +53,75 @@ def forecast_tool(horizon_months: int = 1, use_demo: bool = True, history_csv_pa
         df = pd.DataFrame(rows)
     else:
         df = pd.read_csv(history_csv_path)
-        assert {"product_id",
+        assert {"product_id","date","qty"} <= set(df.columns), "CSV must have product_id,date,qty"
         df["date"] = pd.to_datetime(df["date"], errors="coerce")
         df = df.dropna(subset=["date"])
         df["qty"] = pd.to_numeric(df["qty"], errors="coerce").fillna(0.0)
 
-    #
-
-
+    # Ensure proper monthly frequency per SKU
+    df = df.copy()
+    df["product_id"] = df["product_id"].astype(str)
+
+    # --- 2) Build Darts series (GLOBAL model across SKUs) ---
+    from darts import TimeSeries
+    series_list = []
+    past_cov_list = []  # optional
+
+    extra_cols = [c for c in df.columns if c not in ["product_id","date","qty"]]
+    # keep only numeric covariates (categoricals must be pre-encoded)
+    num_covs = [c for c in extra_cols if pd.api.types.is_numeric_dtype(df[c])]
+
     for pid, g in df.groupby("product_id"):
-
-
-
-
-
-
-
+        g = (g.set_index("date")
+              .sort_index()
+              .resample("MS")
+              .agg({**{"qty":"sum"}, **{c:"last" for c in num_covs}})
+              .fillna(method="ffill")
+              .fillna(0.0))
+        y = TimeSeries.from_dataframe(g.reset_index(), time_col="date", value_cols="qty", freq="MS")
+        series_list.append(y)
+
+        if use_covariates and num_covs:
+            pc = TimeSeries.from_dataframe(g.reset_index(), time_col="date", value_cols=num_covs, freq="MS")
+            past_cov_list.append(pc)
+        else:
+            past_cov_list.append(None)
+
+    # --- 3) Train N-HiTS (fast settings) ---
+    from darts.models import NHiTSModel
+
+    H = max(1, int(horizon_months))
+    # keep chunk length small for short histories; model is global
+    input_chunk = max(6, min(12, min(len(s) for s in series_list) - 1)) if series_list else 12
+
+    model = NHiTSModel(
+        input_chunk_length=input_chunk,
+        output_chunk_length=min(H, 3),  # can roll to reach H
+        n_epochs=60,                    # keep fast; tune up if needed
+        batch_size=32,
+        random_state=0,
+        dropout=0.0,
+    )
+
+    if use_covariates and any(pc is not None for pc in past_cov_list):
+        model.fit(series=series_list, past_covariates=past_cov_list, verbose=False)
+    else:
+        model.fit(series=series_list, verbose=False)
+
+    # --- 4) Predict per SKU and return JSON ---
+    out = []
+    for pid, s, pc in zip(df["product_id"].unique(), series_list, past_cov_list):
+        if use_covariates and pc is not None:
+            pred = model.predict(n=H, series=s, past_covariates=pc)
+        else:
+            pred = model.predict(n=H, series=s)
+        for ts, val in zip(pred.time_index, pred.values().flatten()):
+            out.append({
+                "product_id": str(pid),
+                "period_start": pd.Timestamp(ts).strftime("%Y-%m-%d"),
+                "forecast_qty": float(val)
+            })
+
     return json.dumps(out)
 
 # ---------- Tool 2: Optimize (LP) ----------
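For reference, a minimal sketch of how the updated tool could be smoke-tested outside the agent. It assumes app.py imports cleanly without launching the UI and that darts and smolagents are installed; the three-month horizon and the sample_history.csv path are illustrative assumptions, not part of this change.

# Illustrative smoke test (assumption: app.py is importable as a module and the
# smolagents @tool wrapper remains directly callable with keyword arguments).
import json
import pandas as pd
from app import forecast_tool

# Demo mode: synthetic FG100/FG200 history, three months ahead.
raw = forecast_tool(horizon_months=3, use_demo=True)
fc = pd.DataFrame(json.loads(raw))
print(fc)  # columns: product_id, period_start, forecast_qty

# CSV mode (hypothetical file name): the CSV must provide product_id, date, qty;
# extra numeric columns can be used as past covariates.
# raw = forecast_tool(horizon_months=3, use_demo=False,
#                     history_csv_path="sample_history.csv", use_covariates=True)

Parsing the returned JSON into a DataFrame mirrors the Returns section of the docstring: one row per product per forecasted month.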