F555 committed
Commit 4af55fb · verified · 1 Parent(s): 89a8302

Update app.py

Files changed (1)
  1. app.py +139 -74
app.py CHANGED
Old file (removed lines marked "-"):

@@ -1,25 +1,27 @@
  import torch
  import numpy as np
  import time
- import threading
  import uvicorn
  from fastmcp import FastMCP
  from ase import Atoms
  from ase.build import molecule
- import gradio as gr

- # --- 1. MCP SERVER SETUP ---
- mcp = FastMCP("RealMACE_Agent")

- # Global State
  STATE = {
      "model": None,
      "config": None,
-     "batch": None
  }

  # --- HELPER FUNCTIONS ---
  def get_mace_setup():
      try:
          from mace.models import ScaleShiftMACE
          from mace.data import AtomicData, Configuration
@@ -27,13 +29,16 @@ def get_mace_setup():
          from e3nn import o3
          return ScaleShiftMACE, AtomicData, Configuration, torch_geometric, o3
      except ImportError:
-         raise ImportError("MACE not installed. Run: pip install mace-torch")

  def create_dummy_batch(r_max=5.0):
      _, AtomicData, Configuration, torch_geometric, _ = get_mace_setup()
      mol = molecule("H2O")
      mol.info["energy"] = -14.0
      mol.arrays["forces"] = np.random.randn(3, 3) * 0.1
      config = Configuration(
          atomic_numbers=mol.get_atomic_numbers(),
          positions=mol.get_positions(),
@@ -42,6 +47,7 @@ def create_dummy_batch(r_max=5.0):
          pbc=np.array([False, False, False]),
          cell=np.eye(3) * 10.0
      )
      z_table = {1: 0, 8: 1}
      data_loader = torch_geometric.DataLoader(
          dataset=[AtomicData.from_config(config, z_table=z_table, cutoff=r_max)],
@@ -51,105 +57,164 @@ def create_dummy_batch(r_max=5.0):
      return next(iter(data_loader))

  # --- MCP TOOLS ---
  @mcp.tool()
- def init_real_mace_model(r_max: float = 5.0, max_ell: int = 2, hidden_dim: int = 16) -> str:
-     """Initialize a REAL MACE model."""
      ScaleShiftMACE, _, _, _, o3 = get_mace_setup()
      batch = create_dummy_batch(r_max)
      STATE["batch"] = batch

      model_config = dict(
          r_max=r_max, num_bessel=8, num_polynomial_cutoff=5, max_ell=max_ell,
          interaction_cls="RealAgnosticInteractionBlock", num_interactions=2, num_elements=2,
-         hidden_irreps=o3.Irreps(f"{hidden_dim}x0e"), atomic_energies=np.array([-13.6, -10.0]),
          avg_num_neighbors=2, atomic_numbers=[1, 8]
      )
      try:
          model = ScaleShiftMACE(**model_config)
          STATE["model"] = model
-         return f"✅ MACE Model Ready! L_max={max_ell}, r_max={r_max}Å"
      except Exception as e:
          return f"❌ Error: {str(e)}"

  @mcp.tool()
- def train_with_trackio(experiment_name: str, epochs: int = 10) -> str:
-     """Train with Trackio logging."""
-     try:
-         import trackio
-     except ImportError:
-         return "❌ Trackio not installed"
-
      if STATE["model"] is None:
-         return "⚠️ Run 'init_real_mace_model' first!"

      model = STATE["model"]
      batch = STATE["batch"]
-     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

-     try:
-         # Check if we are in a Space with OAuth
-         logger = trackio.Logger(project="Real_MACE_Training", name=experiment_name)
-     except Exception as e:
-         return f"❌ Trackio connection failed: {e}"
-
      model.train()
-     logs = []
-     start = time.time()

      for epoch in range(epochs):
          optimizer.zero_grad()
          out = model(batch.to_dict())
          loss = torch.mean((out["energy"] - batch.energy)**2) + 10.0 * torch.mean((out["forces"] - batch.forces)**2)
          loss.backward()
          optimizer.step()

-         metrics = {
-             "epoch": epoch,
-             "total_loss": loss.item(),
-             "wall_time": time.time() - start
-         }
-         logger.log(metrics)
-
-         if epoch % 5 == 0:
-             logs.append(f"Ep {epoch}: Loss={loss.item():.4f}")
-         time.sleep(0.05)
-
-     return "🚀 Training Done! Check Dashboard.\n" + "\n".join(logs)
-
- # --- 2. DASHBOARD UI (Separate Thread) ---
- def launch_dashboard():
-     """Launches a Gradio UI that serves as the Dashboard Viewer"""
-     with gr.Blocks(title="Equivariant Chem Scout") as demo:
-         gr.Markdown("# 🧪 Equivariant Chem Scout (Dashboard)")
-         gr.Markdown("To view training results, open the **Trackio** dashboard below.")
-
-         # Option A: If running locally, just show instructions
-         gr.Markdown("""
-         ### How to view graphs:
-         The Trackio dashboard runs separately.
-         If you are running locally, type: `trackio show` in your terminal.
-         If you are on Hugging Face Spaces, we need to launch the Trackio server.
-         """)

-         # Option B: Attempt to embed (Experimental)
-         # Note: Trackio doesn't have a verified embed widget yet, so we provide instructions.

-     demo.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)
  if __name__ == "__main__":
-     print("--- STARTING SERVICES ---")
-
-     # 1. Launch the UI (Dashboard) in a background thread on port 7860
-     print("1. Launching Gradio Dashboard on port 7860...")
-     launch_dashboard()
-
-     # 2. Run the MCP Server on the main thread (port 8000 or SSE)
-     print("2. Starting MCP Server (SSE Transport)...")
-     # Hugging Face Spaces expects the main process to listen on port 7860 usually,
-     # but for MCP we need to expose the SSE endpoint.
-     # TRICK: We let Gradio take 7860 (so the Space shows "Running"),
-     # and we run MCP on 8000. You connect to the Space URL via SSE proxying if configured,
-     # or you use this Space *only* as a dashboard and run the MCP logic locally connecting to it.
-
-     # However, since you want the Space to BE the MCP server:
-     mcp.run(transport="sse")
 
New file (added lines marked "+"; unchanged lines elided as "..."):

  import torch
  import numpy as np
  import time
  import uvicorn
+ from fastapi import FastAPI
  from fastmcp import FastMCP
+ import gradio as gr
  from ase import Atoms
  from ase.build import molecule

+ # --- 1. DEFINE MCP SERVER ---
+ mcp = FastMCP("Equivariant_Chem_Scout")

+ # Global State for Persisting Models across Tool Calls
  STATE = {
      "model": None,
      "config": None,
+     "batch": None,
+     "logs": []
  }

  # --- HELPER FUNCTIONS ---
  def get_mace_setup():
+     """Lazy load MACE to avoid startup crashes if deps are missing"""
      try:
          from mace.models import ScaleShiftMACE
          from mace.data import AtomicData, Configuration
  ...
          from e3nn import o3
          return ScaleShiftMACE, AtomicData, Configuration, torch_geometric, o3
      except ImportError:
+         raise ImportError("MACE not installed. Please check requirements.txt")

  def create_dummy_batch(r_max=5.0):
      _, AtomicData, Configuration, torch_geometric, _ = get_mace_setup()
+
+     # Create dummy water molecule
      mol = molecule("H2O")
      mol.info["energy"] = -14.0
      mol.arrays["forces"] = np.random.randn(3, 3) * 0.1
+
      config = Configuration(
          atomic_numbers=mol.get_atomic_numbers(),
          positions=mol.get_positions(),
  ...
          pbc=np.array([False, False, False]),
          cell=np.eye(3) * 10.0
      )
+
      z_table = {1: 0, 8: 1}
      data_loader = torch_geometric.DataLoader(
          dataset=[AtomicData.from_config(config, z_table=z_table, cutoff=r_max)],
  ...
      return next(iter(data_loader))

  # --- MCP TOOLS ---
+
  @mcp.tool()
+ def init_mace_model(r_max: float = 5.0, max_ell: int = 2) -> str:
+     """
+     Initialize a MACE model with specific symmetry settings.
+     Args:
+         r_max: Cutoff radius (Angstroms)
+         max_ell: 0 = Invariant only, 2 = Equivariant (Vectors)
+     """
      ScaleShiftMACE, _, _, _, o3 = get_mace_setup()
+
      batch = create_dummy_batch(r_max)
      STATE["batch"] = batch

+     # simplified MACE config
      model_config = dict(
          r_max=r_max, num_bessel=8, num_polynomial_cutoff=5, max_ell=max_ell,
          interaction_cls="RealAgnosticInteractionBlock", num_interactions=2, num_elements=2,
+         hidden_irreps=o3.Irreps("16x0e"), atomic_energies=np.array([-13.6, -10.0]),
          avg_num_neighbors=2, atomic_numbers=[1, 8]
      )
+
      try:
          model = ScaleShiftMACE(**model_config)
          STATE["model"] = model
+         STATE["config"] = model_config
+         STATE["logs"] = []  # Reset logs
+         return f"✅ MACE Model Initialized! (L_max={max_ell}, R_max={r_max})"
      except Exception as e:
          return f"❌ Error: {str(e)}"

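The `hidden_irreps=o3.Irreps("16x0e")` argument above uses e3nn's irreps notation (multiplicity times degree/parity). A minimal sketch of how such strings map to feature dimensions, assuming only that e3nn is installed; the "16x0e + 8x1o" variant is a hypothetical illustration, not what the commit configures:

# e3nn irreps notation: "16x0e" = 16 even scalar (l=0) channels.
from e3nn import o3

scalars = o3.Irreps("16x0e")          # what init_mace_model passes as hidden_irreps
mixed = o3.Irreps("16x0e + 8x1o")     # hypothetical: add 8 odd vector (l=1) channels

print(scalars.dim)  # 16 -> 16 * (2*0 + 1)
print(mixed.dim)    # 40 -> 16 * 1 + 8 * 3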
  @mcp.tool()
+ def train_model(epochs: int = 10, learning_rate: float = 0.01) -> str:
+     """Train the initialized model and log results."""
      if STATE["model"] is None:
+         return "⚠️ Run 'init_mace_model' first!"

      model = STATE["model"]
      batch = STATE["batch"]
+     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

      model.train()
+     run_logs = []

+     # Simple Training Loop
      for epoch in range(epochs):
          optimizer.zero_grad()
          out = model(batch.to_dict())
+
          loss = torch.mean((out["energy"] - batch.energy)**2) + 10.0 * torch.mean((out["forces"] - batch.forces)**2)
          loss.backward()
          optimizer.step()

+         # Log for Dashboard
+         log_entry = f"Epoch {epoch}: Loss={loss.item():.4f}"
+         run_logs.append(log_entry)
+         STATE["logs"].append(log_entry)  # Append to global state for UI viewing

+         time.sleep(0.05)  # Simulate work

+     return f"🚀 Training Complete!\nFinal Loss: {loss.item():.4f}\n" + "\n".join(run_logs[-5:])
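The loop above fits energies and forces jointly; written out, the objective it minimizes is

$$\mathcal{L} \;=\; \operatorname{mean}\big[(E_\text{pred} - E_\text{ref})^2\big] \;+\; \lambda\,\operatorname{mean}\big[(F_\text{pred} - F_\text{ref})^2\big], \qquad \lambda = 10,$$

where the force term averages over all atoms and Cartesian components, matching the hard-coded 10.0 weight in the loss line.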
+
+ @mcp.tool()
+ def check_equivariance(rotation_degrees: float = 90.0) -> str:
+     """Check if the model respects E(3) symmetry."""
+     if STATE["model"] is None: return "⚠️ No model found."
+
+     model = STATE["model"]
+     batch = STATE["batch"]
+     model.eval()
+
+     # 1. Original Pred
+     out_orig = model(batch.to_dict())
+     f_orig = out_orig["forces"]
+
+     # 2. Rotated Input
+     angle = np.radians(rotation_degrees)
+     R = torch.tensor([[np.cos(angle), -np.sin(angle), 0], [np.sin(angle), np.cos(angle), 0], [0,0,1]], dtype=torch.float32)
+
+     batch_rot = batch.clone()
+     batch_rot.positions = torch.matmul(batch.positions, R.T)
+
+     # 3. Rotated Pred
+     out_rot = model(batch_rot.to_dict())
+     f_rot = out_rot["forces"]
+
+     # 4. Compare: Rot(F_orig) vs F_rot
+     f_orig_rot = torch.matmul(f_orig, R.T)
+     error = torch.mean(torch.abs(f_rot - f_orig_rot)).item()
+
+     return f"🧪 Equivariance Error: {error:.2e} (Pass: {error < 1e-4})"
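The tool above tests the defining property of an E(3)-equivariant force model: rotating the input positions must rotate the predicted forces by the same matrix, F(X Rᵀ) = F(X) Rᵀ in the row-vector convention used here. A minimal self-contained sketch of that check on a toy analytic force field, assuming only numpy (it does not touch the MACE model or the MCP state):

# Toy equivariance check. F_i = -sum_j (x_i - x_j) is linear in positions,
# hence exactly equivariant, so the reported error should be at round-off level.
import numpy as np

def toy_forces(positions):
    # Pairwise harmonic attraction: F_i = -sum_j (x_i - x_j)
    diff = positions[:, None, :] - positions[None, :, :]  # (N, N, 3)
    return -diff.sum(axis=1)                              # (N, 3)

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 3))  # 3 atoms in 3D

angle = np.radians(90.0)
R = np.array([[np.cos(angle), -np.sin(angle), 0.0],
              [np.sin(angle),  np.cos(angle), 0.0],
              [0.0,            0.0,           1.0]])

f_orig = toy_forces(x)
f_rot = toy_forces(x @ R.T)                  # forces of the rotated geometry
error = np.abs(f_rot - f_orig @ R.T).mean()  # compare Rot(F_orig) vs F_rot
print(f"Equivariance error: {error:.2e}")    # expect ~1e-16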
+
+ # --- 2. DEFINE GRADIO DASHBOARD ---
+ def get_latest_logs():
+     """Refresh function for the UI"""
+     if not STATE["logs"]:
+         return "No training logs yet. Ask the Agent to train a model!"
+     return "\n".join(STATE["logs"])
+
+ with gr.Blocks(title="Equivariant Chem Scout") as demo:
+     gr.Markdown("# 🧪 Equivariant Chem Scout")
+     gr.Markdown("""
+     ### Status: 🟢 Online
+     **Connect via Claude Desktop:**
+     ```
+     {
+       "mcpServers": {
+         "chem_scout": {
+           "url": "https://YOUR-SPACE-URL.hf.space/sse"
+         }
+       }
+     }
+     ```
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("### 📊 Live Training Logs")
+             log_display = gr.TextArea(label="Training Output", interactive=False, lines=20)
+             refresh_btn = gr.Button("🔄 Refresh Logs")
+
+     refresh_btn.click(fn=get_latest_logs, outputs=log_display)
+
+     # Auto-refresh every 2 seconds
+     demo.load(fn=get_latest_logs, outputs=log_display, every=2.0)
+
+ # --- 3. ASSEMBLE FASTAPI APP ---
+ # This is the magic glue that makes it work on Spaces
+ app = FastAPI()
+
+ # Mount the MCP Server at /sse
+ # FastMCP provides a method to attach itself to an existing FastAPI app
+ from sse_starlette.sse import EventSourceResponse
+
+ @app.get("/sse")
+ async def handle_sse(request):
+     return EventSourceResponse(mcp.sse_handler(request))
+
+ @app.post("/messages")
+ async def handle_messages(request):
+     return await mcp.handle_post_message(request)
+
+ # IMPORTANT: `mcp.mount_to_fastapi` is not always available in older versions,
+ # so we can use the manual mounting above OR use the built-in if available.
+ # Let's try the safest built-in method if it exists, or fall back to manual.
+
+ try:
+     # Newer FastMCP versions
+     mcp.mount_to_fastapi(app, path="/sse")
+ except AttributeError:
+     # Fallback if mount_to_fastapi doesn't exist (older versions)
+     pass
+
+ # Mount Gradio at the root
+ app = gr.mount_gradio_app(app, demo, path="/")

+ # --- 4. ENTRY POINT ---
  if __name__ == "__main__":
+     # Hugging Face Spaces will run this with: uvicorn app:app --host 0.0.0.0 --port 7860
+     # But for local testing:
+     uvicorn.run(app, host="0.0.0.0", port=7860)
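For a quick local check of the assembled app, one might start it as in the entry point above and probe both mounts. A hypothetical smoke test, assuming a local run on port 7860 and that `requests` is available (neither is part of the commit):

# Hypothetical smoke test: confirm the Gradio root and the MCP SSE endpoint answer.
import requests

BASE = "http://localhost:7860"  # assumption: app started locally via uvicorn as above

print("dashboard:", requests.get(f"{BASE}/", timeout=5).status_code)  # expect 200

# SSE endpoints hold the connection open, so stream and only read the status code.
with requests.get(f"{BASE}/sse", stream=True, timeout=5) as resp:
    print("mcp sse:", resp.status_code)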