PD03 commited on
Commit
840ed4e
·
verified ·
1 Parent(s): 748249e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -20
app.py CHANGED
@@ -21,22 +21,26 @@ def _round_df(df: pd.DataFrame, places: int = 2) -> pd.DataFrame:
21
  return out
22
 
23
  # ---------- Tool 1: Forecast ----------
 
 
 
24
  @tool
25
- def forecast_tool(horizon_months: int = 1, use_demo: bool = True, history_csv_path: str = "") -> str:
 
26
  """
27
- Forecast monthly demand for finished goods using Prophet (demo-friendly).
28
 
29
  Args:
30
- horizon_months (int): Number of future months to forecast (>=1). Defaults to 1.
31
- use_demo (bool): If True, generate synthetic history for FG100/FG200. Defaults to True.
32
- history_csv_path (str): Optional CSV path with columns [product_id,date,qty] to override demo.
 
 
33
 
34
  Returns:
35
- str: JSON string list of {"product_id": str, "period_start": "YYYY-MM-01", "forecast_qty": float}.
36
  """
37
- from prophet import Prophet
38
-
39
- # 1) History
40
  if use_demo or not history_csv_path:
41
  rng = pd.date_range("2023-01-01", periods=24, freq="MS")
42
  rows = []
@@ -49,22 +53,75 @@ def forecast_tool(horizon_months: int = 1, use_demo: bool = True, history_csv_pa
49
  df = pd.DataFrame(rows)
50
  else:
51
  df = pd.read_csv(history_csv_path)
52
- assert {"product_id", "date", "qty"} <= set(df.columns), "CSV must have product_id,date,qty"
53
  df["date"] = pd.to_datetime(df["date"], errors="coerce")
54
  df = df.dropna(subset=["date"])
55
  df["qty"] = pd.to_numeric(df["qty"], errors="coerce").fillna(0.0)
56
 
57
- # 2) Forecast per product
58
- out = []
59
- H = max(1, int(horizon_months))
 
 
 
 
 
 
 
 
 
 
60
  for pid, g in df.groupby("product_id"):
61
- s = (g.set_index("date")["qty"].resample("MS").sum().asfreq("MS").fillna(0.0))
62
- m = Prophet(yearly_seasonality=True, weekly_seasonality=False, daily_seasonality=False, n_changepoints=10)
63
- m.fit(pd.DataFrame({"ds": s.index, "y": s.values}))
64
- future = m.make_future_dataframe(periods=H, freq="MS", include_history=False)
65
- pred = m.predict(future)[["ds", "yhat"]]
66
- for _, r in pred.iterrows():
67
- out.append({"product_id": str(pid), "period_start": r["ds"].strftime("%Y-%m-%d"), "forecast_qty": float(r["yhat"])})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  return json.dumps(out)
69
 
70
  # ---------- Tool 2: Optimize (LP) ----------
 
21
  return out
22
 
23
  # ---------- Tool 1: Forecast ----------
24
+ from smolagents import tool
25
+ import json, pandas as pd, numpy as np
26
+
27
  @tool
28
+ def forecast_tool(horizon_months: int = 1, use_demo: bool = True, history_csv_path: str = "",
29
+ use_covariates: bool = False) -> str:
30
  """
31
+ Forecast monthly demand using a GLOBAL N-HiTS model (fast & accurate).
32
 
33
  Args:
34
+ horizon_months (int): Number of future months to forecast (>=1).
35
+ use_demo (bool): If True, generates synthetic history for FG100/FG200.
36
+ history_csv_path (str): Optional CSV with columns [product_id,date,qty,(optional extra covariates...)].
37
+ use_covariates (bool): If True and extra numeric columns exist, use them as past covariates
38
+ (for future effects you must provide future values too).
39
 
40
  Returns:
41
+ str: JSON list of {"product_id","period_start","forecast_qty"} for the next horizon_months.
42
  """
43
+ # --- 1) Load data in long form ---
 
 
44
  if use_demo or not history_csv_path:
45
  rng = pd.date_range("2023-01-01", periods=24, freq="MS")
46
  rows = []
 
53
  df = pd.DataFrame(rows)
54
  else:
55
  df = pd.read_csv(history_csv_path)
56
+ assert {"product_id","date","qty"} <= set(df.columns), "CSV must have product_id,date,qty"
57
  df["date"] = pd.to_datetime(df["date"], errors="coerce")
58
  df = df.dropna(subset=["date"])
59
  df["qty"] = pd.to_numeric(df["qty"], errors="coerce").fillna(0.0)
60
 
61
+ # Ensure proper monthly frequency per SKU
62
+ df = df.copy()
63
+ df["product_id"] = df["product_id"].astype(str)
64
+
65
+ # --- 2) Build Darts series (GLOBAL model across SKUs) ---
66
+ from darts import TimeSeries
67
+ series_list = []
68
+ past_cov_list = [] # optional
69
+
70
+ extra_cols = [c for c in df.columns if c not in ["product_id","date","qty"]]
71
+ # keep only numeric covariates (categoricals must be pre-encoded)
72
+ num_covs = [c for c in extra_cols if pd.api.types.is_numeric_dtype(df[c])]
73
+
74
  for pid, g in df.groupby("product_id"):
75
+ g = (g.set_index("date")
76
+ .sort_index()
77
+ .resample("MS")
78
+ .agg({**{"qty":"sum"}, **{c:"last" for c in num_covs}})
79
+ .fillna(method="ffill")
80
+ .fillna(0.0))
81
+ y = TimeSeries.from_dataframe(g.reset_index(), time_col="date", value_cols="qty", freq="MS")
82
+ series_list.append(y)
83
+
84
+ if use_covariates and num_covs:
85
+ pc = TimeSeries.from_dataframe(g.reset_index(), time_col="date", value_cols=num_covs, freq="MS")
86
+ past_cov_list.append(pc)
87
+ else:
88
+ past_cov_list.append(None)
89
+
90
+ # --- 3) Train N-HiTS (fast settings) ---
91
+ from darts.models import NHiTSModel
92
+
93
+ H = max(1, int(horizon_months))
94
+ # keep chunk length small for short histories; model is global
95
+ input_chunk = max(6, min(12, min(len(s) for s in series_list) - 1)) if series_list else 12
96
+
97
+ model = NHiTSModel(
98
+ input_chunk_length=input_chunk,
99
+ output_chunk_length=min(H, 3), # can roll to reach H
100
+ n_epochs=60, # keep fast; tune up if needed
101
+ batch_size=32,
102
+ random_state=0,
103
+ dropout=0.0,
104
+ )
105
+
106
+ if use_covariates and any(pc is not None for pc in past_cov_list):
107
+ model.fit(series=series_list, past_covariates=past_cov_list, verbose=False)
108
+ else:
109
+ model.fit(series=series_list, verbose=False)
110
+
111
+ # --- 4) Predict per SKU and return JSON ---
112
+ out = []
113
+ for pid, s, pc in zip(df["product_id"].unique(), series_list, past_cov_list):
114
+ if use_covariates and pc is not None:
115
+ pred = model.predict(n=H, series=s, past_covariates=pc)
116
+ else:
117
+ pred = model.predict(n=H, series=s)
118
+ for ts, val in zip(pred.time_index, pred.values().flatten()):
119
+ out.append({
120
+ "product_id": str(pid),
121
+ "period_start": pd.Timestamp(ts).strftime("%Y-%m-%d"),
122
+ "forecast_qty": float(val)
123
+ })
124
+
125
  return json.dumps(out)
126
 
127
  # ---------- Tool 2: Optimize (LP) ----------