PD03 committed on
Commit
188d3e2
·
verified ·
1 Parent(s): af09c0d

Update agentic_sourcing_ppo_sap_colab.py

Browse files
Files changed (1) hide show
  1. agentic_sourcing_ppo_sap_colab.py +455 -1
agentic_sourcing_ppo_sap_colab.py CHANGED
@@ -1 +1,455 @@
1
- 1234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ agentic_sourcing_ppo_sap_colab.py - MODIFIED FOR STREAMLIT WITH OPENAI API
3
+ --------------------------------------------------------------------------
4
+ Agentic sourcing flow (smolagents) using YOUR Stable-Baselines3 PPO model
5
+ as a tool. The agent gathers suppliers + market inputs, calls the PPO for
6
+ allocations, builds a PO, then calls a SAP mock tool, and STOPS.
7
+
8
+ CHANGES FOR STREAMLIT COMPATIBILITY:
9
+ - Uses OpenAI API (requires OPENAI_API_KEY secret)
10
+ - Model saved in root folder as supplier_selection_ppo_gymnasium.pkl
11
+ - Added error handling for missing dependencies
12
+ - Made imports more robust for web deployment
13
+ """
14
+
15
+ # ===================== STREAMLIT COMPATIBILITY SETUP =====================
16
+ import os
17
# Use the OpenAI API by default - make sure OPENAI_API_KEY is set in the
# Hugging Face Spaces secrets.
# setdefault (instead of plain assignment) keeps an externally supplied
# USE_RANDOM_MODEL value, so the environment lookup that defines USE_RANDOM
# further down can actually be overridden by the deployment.
os.environ.setdefault("USE_RANDOM_MODEL", "0")  # "0" enables OpenAI API usage

# Model artifact lives in the root folder under this fixed filename.
MODEL_PATH = "./supplier_selection_ppo_gymnasium.pkl"
22
+
23
+ # ===================== ORIGINAL IMPORTS WITH ERROR HANDLING =====================
24
+ import json, time, pickle
25
+ import numpy as np
26
+ import pandas as pd
27
+
28
+ # Try to import smolagents - if not available, create mock versions
29
try:
    from smolagents import tool, CodeAgent
except ImportError:
    # smolagents is optional: fall back to inert stand-ins so the module
    # still imports and main() can run in demo mode.
    print("Warning: smolagents not available. Using mock implementations.")
    SMOLAGENTS_AVAILABLE = False

    def tool(func):
        """Mock decorator: hand the function back unchanged."""
        return func

    class CodeAgent:
        """Mock agent that records its tools/model and returns a canned reply."""

        def __init__(self, tools, model, add_base_tools=False, max_steps=7):
            # add_base_tools and max_steps are accepted only for signature
            # compatibility; the mock does not store or use them.
            self.tools, self.model = tools, model

        def run(self, goal):
            """Ignore the goal and return a fixed demo payload."""
            return dict(status="mock", message="This is a demo version")
else:
    SMOLAGENTS_AVAILABLE = True
47
+
48
+ # Try to import stable-baselines3 - if not available, create mock
49
try:
    from stable_baselines3 import PPO
except ImportError:
    # stable-baselines3 is optional: provide a stand-in exposing PPO.load().
    print("Warning: stable-baselines3 not available. Using mock PPO.")
    SB3_AVAILABLE = False

    class PPO:
        """Mock replacement offering only the static load() entry point."""

        @staticmethod
        def load(path):
            """Ignore *path* and return a mock policy with a predict() method."""

            class MockPPO:
                def predict(self, obs, deterministic=True):
                    # The observation has 8 leading market features followed
                    # by 6 features per supplier; recover the supplier count.
                    supplier_count = (len(obs) - 8) // 6
                    # One random logit per supplier; `deterministic` is ignored.
                    return np.random.normal(0, 1, supplier_count), None

            return MockPPO()
else:
    SB3_AVAILABLE = True
67
+
68
+ # ===================== ORIGINAL CONFIG (modified paths) =====================
69
SUPPLIERS_CSV = None       # or path to your CSV
BASELINE_DEMAND = 1000
DEMAND_MULT = 1.0
VOLATILITY = "medium"      # "low"|"medium"|"high"
PRICE_MULT = 1.0
AUTO_ALIGN = True          # pad/truncate PPO action to #suppliers if needed

# USE_RANDOM_MODEL defaults to "0" (OpenAI API). Integer strings keep their
# original meaning (non-zero -> True); non-integer values such as "true" or
# an empty string no longer crash the import with ValueError.
_use_random_raw = os.environ.get("USE_RANDOM_MODEL", "0").strip().lower()
try:
    USE_RANDOM = bool(int(_use_random_raw))
except ValueError:
    USE_RANDOM = _use_random_raw in {"true", "yes", "on"}
76
+
77
+ # ===================== ORIGINAL HELPER FUNCTIONS (unchanged) =====================
78
# Ordinal positions used when one-hot encoding the volatility and demand levels.
VOL_MAP = dict(low=0, medium=1, high=2)
DEM_MAP = dict(zip(("low", "medium", "high"), range(3)))
80
+
81
def _one_hot(idx: int, n: int) -> list:
    """Return a length-``n`` list of floats with 1.0 at ``idx`` and 0.0 elsewhere.

    Raises:
        IndexError: if ``idx`` is outside ``[0, n)``. Negative indices are
            rejected explicitly; plain list indexing would silently wrap
            around and set the wrong slot.
    """
    if not 0 <= idx < n:
        raise IndexError(f"one-hot index {idx} out of range for length {n}")
    vec = [0.0] * n
    vec[idx] = 1.0
    return vec
83
+
84
def _demand_level(m: float) -> str:
    """Bucket a demand multiplier into "low", "medium" or "high".

    Values below 0.93 are "low", values above 1.07 are "high", and everything
    else (including both boundary values) is "medium".
    """
    if m < 0.93:
        return "low"
    if m > 1.07:
        return "high"
    return "medium"
86
+
87
def _softmax(x: np.ndarray) -> np.ndarray:
    """Return the softmax of ``x`` as a float32 array.

    Args:
        x: Logits as an ndarray or any array-like (e.g. a list).

    Returns:
        float32 array of the same shape. An empty input yields an empty
        array instead of raising from ``max()`` on a zero-size array.
    """
    logits = np.asarray(x, dtype=np.float64)
    if logits.size == 0:
        return np.zeros(logits.shape, dtype=np.float32)
    # Subtract the max for numerical stability; this builds a new array, so
    # the caller's input is never modified.
    weights = np.exp(logits - logits.max())
    # The small epsilon in the denominator is kept from the original formula.
    return (weights / (weights.sum() + 1e-8)).astype(np.float32)
90
+
91
def _build_obs(volatility: str, demand_mult: float, price_mult: float, suppliers_df: pd.DataFrame) -> np.ndarray:
    """Assemble the flat float32 observation vector fed to the PPO policy.

    Layout:
        [vol_onehot(3), dem_onehot(3), price_mult, demand_mult,
         then per supplier: cost/150, quality, delivery, financial_risk,
         esg, base_capacity_share]

    The demand one-hot is derived from ``demand_mult`` via ``_demand_level``.
    An unknown ``volatility`` key raises KeyError from the VOL_MAP lookup.
    """
    level = _demand_level(demand_mult)
    features = (
        _one_hot(VOL_MAP[volatility], 3)
        + _one_hot(DEM_MAP[level], 3)
        + [float(price_mult), float(demand_mult)]
    )

    # Columns copied through unchanged, in observation order, after the cost.
    passthrough = (
        "current_quality",
        "current_delivery",
        "financial_risk",
        "esg",
        "base_capacity_share",
    )
    for _, row in suppliers_df.iterrows():
        # Unit cost is scaled by the fixed divisor 150.
        features.append(float(row["base_cost_per_unit"]) / 150.0)
        features.extend(float(row[col]) for col in passthrough)

    return np.asarray(features, dtype=np.float32)
112
+
113
+ # ===================== MODEL CACHE (modified for Streamlit) =====================
114
# Last successfully loaded model, how it was loaded, and from which path.
_MODEL_CACHE = {"obj": None, "backend": None, "path": None}


class _MockPPOModel:
    """Heuristic stand-in policy used when no model file exists.

    Defined at module level (not inside ``_load_model``) so that instances
    can be pickled; classes local to a function cannot be.
    """

    def predict(self, obs, deterministic=True):
        """Score each supplier from the observation and return ``(logits, None)``.

        The observation is read as 8 leading market features followed by 6
        features per supplier (normalized cost, quality, delivery,
        financial_risk, esg, capacity). Capacity is not used in the score.
        The result is fully deterministic and does not touch NumPy's global
        random state.
        """
        flat = np.asarray(obs, dtype=np.float64).reshape(-1)
        n_suppliers = max((flat.size - 8) // 6, 0)
        feats = flat[8:8 + 6 * n_suppliers].reshape(n_suppliers, 6)
        cost_norm = feats[:, 0]  # already cost / 150
        quality = feats[:, 1]
        delivery = feats[:, 2]
        financial_risk = feats[:, 3]
        esg = feats[:, 4]

        # Weighted blend: higher quality/delivery/esg and lower risk/cost win.
        scores = (quality * 0.3 + delivery * 0.25 + esg * 0.2 +
                  (1.0 - financial_risk) * 0.15 + (1.0 - cost_norm) * 0.1)
        # Scale up so the downstream softmax separates suppliers more sharply.
        return scores * 5.0, None


def _load_model(path: str):
    """Load a policy exposing ``.predict(obs)`` from ``path`` and cache it.

    Order of attempts:
      1. If ``path`` does not exist, build a heuristic fallback model, try to
         persist it at ``path`` (best effort) and return it.
      2. If stable-baselines3 is importable, try ``PPO.load(path)``.
      3. Otherwise / on failure, unpickle the file and accept any object that
         has a ``predict`` attribute.

    Args:
        path: Location of the model artifact.

    Returns:
        The loaded (or fallback) model; the same object is stored in
        ``_MODEL_CACHE``.

    Raises:
        FileNotFoundError: if the file exists but no loader could handle it.
    """
    if not os.path.exists(path):
        print(f"Model file not found at {path}. Creating fallback model...")
        fallback = _MockPPOModel()

        # Persisting is best effort: a failed write must not prevent the
        # in-memory fallback from being used, and must not leave a truncated
        # file behind that would break every later load attempt.
        try:
            os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
            with open(path, "wb") as f:
                pickle.dump(fallback, f)
        except (OSError, pickle.PicklingError) as e:
            print(f"Could not save fallback model to {path}: {e}")
            try:
                os.remove(path)
            except OSError:
                pass  # nothing was created, or it cannot be removed

        _MODEL_CACHE.update(obj=fallback, backend="mock", path=path)
        return fallback

    # Try SB3 .zip/.pkl (SB3) first:
    if SB3_AVAILABLE:
        try:
            m = PPO.load(path)
            _MODEL_CACHE.update(obj=m, backend="sb3-ppo", path=path)
            print(f"Successfully loaded SB3 PPO model from {path}")
            return m
        except Exception as e:
            print(f"Failed to load as SB3 PPO: {e}")

    # Generic pickle fallback (must expose .predict).
    # SECURITY NOTE: pickle.load executes arbitrary code from the file; only
    # point this at model files from a trusted source.
    try:
        with open(path, "rb") as f:
            obj = pickle.load(f)
        if hasattr(obj, "predict"):
            _MODEL_CACHE.update(obj=obj, backend="pickle", path=path)
            print(f"Successfully loaded pickled model from {path}")
            return obj
        else:
            raise ValueError("Loaded object doesn't have .predict method")
    except Exception as e:
        print(f"Failed to load pickled model: {e}")

    raise FileNotFoundError(f"MODEL_PATH not found/unsupported: {path}")
183
+
184
def _get_model():
    """Return the cached model, loading it from MODEL_PATH when needed.

    The cache is reused only if it holds a model that was loaded from the
    current MODEL_PATH; otherwise the model is (re)loaded via _load_model.
    """
    cached = _MODEL_CACHE["obj"]
    if cached is not None and _MODEL_CACHE["path"] == MODEL_PATH:
        return cached
    return _load_model(MODEL_PATH)
188
+
189
+ # ===================== TOOLS (unchanged functionality) =====================
190
# NOTE(review): the docstring is kept in its original Args/Returns layout;
# the smolagents @tool decorator presumably derives the tool description
# from it - confirm before rewording.
@tool
def check_model_tool(model_path: str) -> dict:
    """Check if PPO model file is available and loadable.
    Args:
        model_path (str): Path to PPO artifact (.zip preferred; .pkl with .predict allowed).
    Returns:
        dict: {"ok": bool, "message": str}
    """
    try:
        _load_model(model_path)
    except Exception as err:
        # Any loader failure is reported to the caller rather than raised.
        return {"ok": False, "message": f"Model not loadable: {err}"}
    return {"ok": True, "message": "Model loaded successfully"}
203
+
204
@tool
def suppliers_from_csv(csv_path: str) -> dict:
    """Load suppliers from a CSV file.
    Args:
        csv_path (str): Path to a CSV containing the required supplier columns.
    Returns:
        dict: {"suppliers": list[dict]} where each dict has keys:
            name, base_cost_per_unit, current_quality, current_delivery,
            financial_risk, esg, base_capacity_share
    """
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"CSV not found: {csv_path}")

    frame = pd.read_csv(csv_path)
    frame = frame.reset_index(drop=True)

    # Every one of these columns must be present; extra columns pass through.
    required = (
        "name",
        "base_cost_per_unit",
        "current_quality",
        "current_delivery",
        "financial_risk",
        "esg",
        "base_capacity_share",
    )
    absent = []
    for column in required:
        if column not in frame.columns:
            absent.append(column)
    if absent:
        raise ValueError(f"CSV missing columns: {absent}")

    records = frame.to_dict(orient="records")
    return {"suppliers": records}
222
+
223
@tool
def suppliers_synthetic(n: int = 6, seed: int = 123) -> dict:
    """Generate a synthetic supplier table.
    Args:
        n (int): Number of suppliers.
        seed (int): Random seed.
    Returns:
        dict: {"suppliers": list[dict]} with keys listed in suppliers_from_csv.
    """
    rng = np.random.default_rng(int(seed))
    count = int(n)

    # Columns are drawn in this fixed order so that a given seed always
    # produces the same table.
    columns = {}
    columns["name"] = [f"Supplier_{k}" for k in range(1, count + 1)]
    columns["base_cost_per_unit"] = rng.normal(100, 8, count).clip(70, 130)
    columns["current_quality"] = rng.uniform(0.85, 0.99, count)
    columns["current_delivery"] = rng.uniform(0.88, 0.99, count)
    columns["financial_risk"] = rng.uniform(0.02, 0.12, count)
    columns["esg"] = rng.uniform(0.65, 0.95, count)
    columns["base_capacity_share"] = rng.uniform(0.18, 0.40, count)

    table = pd.DataFrame(columns)
    return {"suppliers": table.to_dict(orient="records")}
243
+
244
@tool
def market_signal(volatility: str, price_multiplier: float, demand_multiplier: float) -> dict:
    """Return a market snapshot.
    Args:
        volatility (str): "low"|"medium"|"high".
        price_multiplier (float): e.g., 1.05 for +5%.
        demand_multiplier (float): e.g., 1.10 for +10%.
    Returns:
        dict: {"volatility": str, "price_multiplier": float, "demand_multiplier": float}
    """
    # Validate with an explicit exception: an `assert` would be stripped when
    # Python runs with -O, letting an invalid volatility through.
    if volatility not in {"low", "medium", "high"}:
        raise ValueError("volatility must be low|medium|high")
    return {
        "volatility": volatility,
        "price_multiplier": float(price_multiplier),
        "demand_multiplier": float(demand_multiplier),
    }
260
+
261
def _coerce_flag(value) -> bool:
    """Interpret a bool-like value that may arrive as a string such as "false"."""
    if isinstance(value, str):
        # bool("false") would be True, so textual flags are parsed explicitly.
        return value.strip().lower() not in {"", "0", "false", "no", "off"}
    return bool(value)


@tool
def rl_recommend_tool(market_and_suppliers: dict) -> dict:
    """Call the PPO policy for allocations. Returns an error dict if model missing.
    Args:
        market_and_suppliers (dict): Fields:
            - volatility (str)
            - price_multiplier (float)
            - demand_multiplier (float)
            - baseline_demand (int)
            - suppliers (list[dict]) with keys:
                name, base_cost_per_unit, current_quality, current_delivery,
                financial_risk, esg, base_capacity_share
            - auto_align_actions (bool, optional): Auto pad/truncate action to #suppliers.
    Returns:
        dict: {
            "strategy": str | "error",
            "allocations": [{"supplier": str, "share": float}] | [],
            "demand_units": float
        }
    """
    try:
        vol = market_and_suppliers["volatility"]
        price_mult = float(market_and_suppliers["price_multiplier"])
        demand_mult = float(market_and_suppliers["demand_multiplier"])
        baseline = int(market_and_suppliers["baseline_demand"])
        # Defaults to True when the key is absent; string values such as
        # "false" are honoured instead of being treated as truthy.
        auto_align = _coerce_flag(market_and_suppliers.get("auto_align_actions", True))
        df = pd.DataFrame(market_and_suppliers["suppliers"])

        needed = ["name", "base_cost_per_unit", "current_quality", "current_delivery",
                  "financial_risk", "esg", "base_capacity_share"]
        missing = [c for c in needed if c not in df.columns]
        if missing:
            return {"strategy": "error", "allocations": [], "demand_units": 0.0,
                    "error": f"Suppliers missing columns: {missing}"}

        obs = _build_obs(vol, demand_mult, price_mult, df)
        model = _get_model()
        action, _ = model.predict(obs, deterministic=True)
        action = np.asarray(action, dtype=np.float32).reshape(-1)

        n_sup = len(df)
        if action.size != n_sup:
            if not auto_align:
                return {"strategy": "error", "allocations": [], "demand_units": 0.0,
                        "error": f"Action length {action.size} != #suppliers {n_sup}"}
            if action.size > n_sup:
                action = action[:n_sup]
            else:
                # Too few logits: repeat the last one for the extra suppliers.
                action = np.pad(action, (0, n_sup - action.size), mode="edge")

        alloc = _softmax(action)
        # Label the strategy by how many suppliers receive more than a 1% share.
        k = int((alloc > 1e-2).sum())
        strategy = "single" if k == 1 else ("dual" if k == 2 else "multi")
        demand_units = float(baseline * demand_mult)

        # Pair names with shares positionally so the result does not depend
        # on the DataFrame's index labels.
        names = df["name"].tolist()
        return {
            "strategy": strategy,
            "allocations": [{"supplier": name, "share": float(share)}
                            for name, share in zip(names, alloc)],
            "demand_units": round(demand_units, 2),
        }
    except Exception as e:
        # Tool boundary: report every failure as an error dict, never raise.
        return {"strategy": "error", "allocations": [], "demand_units": 0.0,
                "error": f"PPO predict error: {e}"}
321
+
322
@tool
def sap_create_po_mock(po: dict) -> dict:
    """MOCK: Create a Purchase Order (does NOT call SAP).
    Args:
        po (dict): PO JSON with a "lines" list like:
            [{"supplier": str, "quantity": float}, ...]
    Returns:
        dict: {"PurchaseOrder": str, "message": str, "echo": dict}
    """
    # Fake PO number: "45" followed by the last six digits of the current
    # Unix time in seconds, zero-padded.
    suffix = int(time.time()) % 1_000_000
    po_number = "45{:06d}".format(suffix)
    return {
        "PurchaseOrder": po_number,
        "message": "MOCK ONLY — nothing was sent to SAP.",
        "echo": po,
    }
333
+
334
+ # ===================== LLM SETUP (OpenAI API enabled) =====================
335
def get_model():
    """Pick the LLM object the agent uses to plan and call tools.

    Selection order:
      1. RandomModel, when USE_RANDOM is set and smolagents is available.
      2. LiteLLMModel (model id from LITELLM_MODEL, default "gpt-4o-mini"),
         when USE_RANDOM is not set, smolagents is available and
         OPENAI_API_KEY is present in the environment.
      3. RandomModel as a fallback when smolagents is available.
      4. A local MockRandomModel that returns a fixed demo string.
    """

    def _try_random(banner):
        # Returns None when RandomModel cannot be imported from smolagents.
        try:
            from smolagents import RandomModel
            print(banner)
            return RandomModel()
        except ImportError:
            return None

    if USE_RANDOM and SMOLAGENTS_AVAILABLE:
        picked = _try_random("Using RandomModel for agent reasoning")
        if picked is not None:
            return picked

    if SMOLAGENTS_AVAILABLE and not USE_RANDOM:
        try:
            # A missing key is routed through the generic handler below so it
            # falls back the same way as any other initialization failure.
            if not os.environ.get("OPENAI_API_KEY"):
                print("Warning: OPENAI_API_KEY not found in environment. Using fallback model.")
                raise ValueError("No OpenAI API key")

            from smolagents import LiteLLMModel
            model_id = os.environ.get("LITELLM_MODEL", "gpt-4o-mini")
            print(f"Using OpenAI model: {model_id}")
            return LiteLLMModel(model_id=model_id)
        except ImportError:
            print("LiteLLMModel not available, falling back to RandomModel")
        except Exception as err:
            print(f"Failed to initialize OpenAI model: {err}, falling back to RandomModel")

    # Fallback options
    if SMOLAGENTS_AVAILABLE:
        picked = _try_random("Using RandomModel as fallback")
        if picked is not None:
            return picked

    # Final fallback - a minimal stand-in that always answers with demo text.
    class MockRandomModel:
        def generate(self, prompt, max_tokens=500):
            return "This is a demo response from the mock model."

        def __call__(self, messages, **kwargs):
            return "This is a demo response from the mock model."

    print("Using MockRandomModel as final fallback")
    return MockRandomModel()
384
+
385
+ # ===================== MAIN FUNCTIONS (unchanged) =====================
386
def build_goal() -> str:
    """Build the fixed 5-step instruction prompt for the agent.

    The prompt is assembled from the module-level configuration
    (SUPPLIERS_CSV, VOLATILITY, PRICE_MULT, DEMAND_MULT, MODEL_PATH,
    BASELINE_DEMAND, AUTO_ALIGN). It includes a fallback path for when the
    PPO model check fails and an explicit STOP after step 5.

    Returns:
        str: The goal text passed to the agent's run() method.
    """
    if SUPPLIERS_CSV:
        # json.dumps yields a double-quoted, escaped literal, so paths that
        # contain backslashes or quotes survive being pasted into agent code.
        suppliers_step = f'Call suppliers_from_csv(csv_path={json.dumps(str(SUPPLIERS_CSV))}) -> SUPS'
    else:
        suppliers_step = 'Call suppliers_synthetic(n=6, seed=123) -> SUPS'

    # Rendered as True/False: lowercase true/false are not valid Python
    # literals for an agent that writes Python code.
    auto_align_literal = bool(AUTO_ALIGN)

    # NOTE(review): the fallback allocation below uses raw base_capacity_share
    # values as shares; they are not normalized to sum to 1 - confirm intent.
    return f"""
You are a sourcing ops agent. Follow these steps EXACTLY and STOP after step 5.
1) {suppliers_step}
2) Call market_signal(volatility="{VOLATILITY}", price_multiplier={PRICE_MULT}, demand_multiplier={DEMAND_MULT}) -> MKT
3) Call check_model_tool(model_path="{MODEL_PATH}") -> MC
- If MC.ok is False:
# Fallback: use capacity shares to allocate and SKIP the RL step.
Set REC = {{
"strategy": "multi",
"allocations": [{{"supplier": s.name, "share": s.base_capacity_share}} for s in SUPS.suppliers],
"demand_units": {BASELINE_DEMAND} * {DEMAND_MULT}
}}
Else:
Call rl_recommend_tool(market_and_suppliers={{
"volatility": MKT.volatility,
"price_multiplier": MKT.price_multiplier,
"demand_multiplier": MKT.demand_multiplier,
"baseline_demand": {BASELINE_DEMAND},
"suppliers": SUPS.suppliers,
"auto_align_actions": {auto_align_literal}
}}) -> REC
4) Build a PO JSON named PO_JSON:
{{
"lines": [{{"supplier": item.supplier if hasattr(item, "supplier") else item["supplier"],
"quantity": round((REC.demand_units if hasattr(REC, "demand_units") else REC["demand_units"]) *
(item.share if hasattr(item, "share") else item["share"]), 2)}}
for item in (REC.allocations if hasattr(REC, "allocations") else REC["allocations"])]
}}
5) Call sap_create_po_mock(po=PO_JSON) and RETURN ITS JSON AS THE FINAL ANSWER.
DO NOT add extra text. DO NOT run any more steps. STOP AFTER THIS.
"""
427
+
428
def main():
    """Build the agent, run the sourcing goal once, and return its output.

    Any exception raised while constructing or running the agent is caught,
    printed, and reported as ``{"error": ..., "status": "failed"}``.
    """
    toolbox = [
        check_model_tool,
        suppliers_from_csv,
        suppliers_synthetic,
        market_signal,
        rl_recommend_tool,
        sap_create_po_mock,
    ]

    try:
        llm = get_model()
        # max_steps is a safety cap on how long the agent may keep going.
        agent = CodeAgent(tools=toolbox, model=llm, add_base_tools=False, max_steps=7)
        result = agent.run(build_goal())
        print(result)
        return result
    except Exception as exc:
        print(f"Agent execution failed: {exc}")
        return {"error": str(exc), "status": "failed"}
453
+
454
+ if __name__ == "__main__":
455
+ main()