F555 committed on
Commit
89a8302
·
verified ·
1 Parent(s): 968015f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -152
app.py CHANGED
@@ -1,25 +1,25 @@
1
  import torch
2
  import numpy as np
3
  import time
 
 
4
  from fastmcp import FastMCP
5
  from ase import Atoms
6
  from ase.build import molecule
7
  import gradio as gr
8
 
9
- # Initialize MCP Server
10
  mcp = FastMCP("RealMACE_Agent")
11
 
12
  # Global State
13
  STATE = {
14
  "model": None,
15
  "config": None,
16
- "batch": None,
17
- "training_logs": []
18
  }
19
 
20
  # --- HELPER FUNCTIONS ---
21
  def get_mace_setup():
22
- """Lazy load MACE imports."""
23
  try:
24
  from mace.models import ScaleShiftMACE
25
  from mace.data import AtomicData, Configuration
@@ -30,13 +30,10 @@ def get_mace_setup():
30
  raise ImportError("MACE not installed. Run: pip install mace-torch")
31
 
32
  def create_dummy_batch(r_max=5.0):
33
- """Creates a water molecule batch for training."""
34
  _, AtomicData, Configuration, torch_geometric, _ = get_mace_setup()
35
-
36
  mol = molecule("H2O")
37
  mol.info["energy"] = -14.0
38
  mol.arrays["forces"] = np.random.randn(3, 3) * 0.1
39
-
40
  config = Configuration(
41
  atomic_numbers=mol.get_atomic_numbers(),
42
  positions=mol.get_positions(),
@@ -45,7 +42,6 @@ def create_dummy_batch(r_max=5.0):
45
  pbc=np.array([False, False, False]),
46
  cell=np.eye(3) * 10.0
47
  )
48
-
49
  z_table = {1: 0, 8: 1}
50
  data_loader = torch_geometric.DataLoader(
51
  dataset=[AtomicData.from_config(config, z_table=z_table, cutoff=r_max)],
@@ -59,35 +55,25 @@ def create_dummy_batch(r_max=5.0):
59
  def init_real_mace_model(r_max: float = 5.0, max_ell: int = 2, hidden_dim: int = 16) -> str:
60
  """Initialize a REAL MACE model."""
61
  ScaleShiftMACE, _, _, _, o3 = get_mace_setup()
62
-
63
  batch = create_dummy_batch(r_max)
64
  STATE["batch"] = batch
65
 
66
  model_config = dict(
67
- r_max=r_max,
68
- num_bessel=8,
69
- num_polynomial_cutoff=5,
70
- max_ell=max_ell,
71
- interaction_cls="RealAgnosticInteractionBlock",
72
- num_interactions=2,
73
- num_elements=2,
74
- hidden_irreps=o3.Irreps(f"{hidden_dim}x0e"),
75
- atomic_energies=np.array([-13.6, -10.0]),
76
- avg_num_neighbors=2,
77
- atomic_numbers=[1, 8]
78
  )
79
-
80
  try:
81
  model = ScaleShiftMACE(**model_config)
82
  STATE["model"] = model
83
- STATE["config"] = model_config
84
  return f"✅ MACE Model Ready! L_max={max_ell}, r_max={r_max}Å"
85
  except Exception as e:
86
  return f"❌ Error: {str(e)}"
87
 
88
  @mcp.tool()
89
- def train_with_trackio(experiment_name: str, epochs: int = 10, learning_rate: float = 0.01) -> str:
90
- """Train the MACE model with Trackio logging."""
91
  try:
92
  import trackio
93
  except ImportError:
@@ -98,152 +84,72 @@ def train_with_trackio(experiment_name: str, epochs: int = 10, learning_rate: fl
98
 
99
  model = STATE["model"]
100
  batch = STATE["batch"]
101
- optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
102
 
103
  try:
 
104
  logger = trackio.Logger(project="Real_MACE_Training", name=experiment_name)
105
  except Exception as e:
106
- return f"❌ Trackio error: {e}"
107
 
108
  model.train()
109
- STATE["training_logs"] = []
 
110
 
111
  for epoch in range(epochs):
112
  optimizer.zero_grad()
113
  out = model(batch.to_dict())
114
-
115
- loss_e = torch.mean((out["energy"] - batch.energy)**2)
116
- loss_f = torch.mean((out["forces"] - batch.forces)**2)
117
- total_loss = loss_e + 10.0 * loss_f
118
-
119
- total_loss.backward()
120
  optimizer.step()
121
 
122
- force_mae = torch.mean(torch.abs(out["forces"] - batch.forces)).item()
123
-
124
- logger.log({
125
  "epoch": epoch,
126
- "total_loss": total_loss.item(),
127
- "force_mae_eV_A": force_mae,
128
- })
 
129
 
130
- STATE["training_logs"].append(f"Epoch {epoch}: Loss={total_loss.item():.5f}")
131
- time.sleep(0.05)
132
-
133
- return f"🚀 Training done! Final Loss: {total_loss.item():.6f}\n" + "\n".join(STATE["training_logs"][-5:])
134
 
135
- @mcp.tool()
136
- def check_equivariance(rotation_degrees: float = 45.0) -> str:
137
- """Test E(3)-equivariance."""
138
- if STATE["model"] is None:
139
- return "⚠️ No model found!"
140
-
141
- model = STATE["model"]
142
- batch = STATE["batch"]
143
-
144
- model.eval()
145
- with torch.no_grad():
146
- out_orig = model(batch.to_dict())
147
- forces_orig = out_orig["forces"].clone()
148
-
149
- angle = np.radians(rotation_degrees)
150
- rot_matrix = torch.tensor([
151
- [np.cos(angle), -np.sin(angle), 0],
152
- [np.sin(angle), np.cos(angle), 0],
153
- [0, 0, 1]
154
- ], dtype=torch.float32)
155
-
156
- batch_rot = batch.clone()
157
- batch_rot.positions = torch.matmul(batch.positions, rot_matrix.T)
158
-
159
- with torch.no_grad():
160
- out_rot = model(batch_rot.to_dict())
161
- forces_rot = out_rot["forces"]
162
-
163
- forces_orig_rotated = torch.matmul(forces_orig, rot_matrix.T)
164
- equivariance_error = torch.mean(torch.abs(forces_rot - forces_orig_rotated)).item()
165
-
166
- return f"🧪 Equivariance Error: {equivariance_error:.2e} eV/Å\n{'✅ PASS' if equivariance_error < 1e-4 else '⚠️ High error'}"
167
 
168
- # --- GRADIO UI ---
169
- def create_ui():
170
- """Create the Gradio interface."""
171
- with gr.Blocks(title="Equivariant Alchemist - MACE Training Lab") as demo:
172
- gr.Markdown("""
173
- # 🧪 Equivariant Alchemist - MACE Training Lab
174
 
175
- This app combines an **MCP Server** for AI agents with **Trackio** experiment tracking.
 
 
 
 
 
176
  """)
177
 
178
- with gr.Tabs():
179
- with gr.Tab("📊 View Trackio Dashboard"):
180
- gr.Markdown("""
181
- ### Live Training Metrics
182
-
183
- To view your training metrics, open the Trackio dashboard in a separate window:
184
-
185
- **Option 1: Command Line**
186
- ```
187
- trackio show --project "Real_MACE_Training"
188
- ```
189
-
190
- **Option 2: Python**
191
- ```
192
- import trackio
193
- trackio.show(project="Real_MACE_Training")
194
- ```
195
-
196
- The dashboard will automatically update as training runs complete.
197
- """)
198
-
199
- gr.HTML("""
200
- <iframe
201
- src="/trackio"
202
- width="100%"
203
- height="800px"
204
- frameborder="0"
205
- style="border-radius: 8px;"
206
- ></iframe>
207
- """)
208
-
209
- with gr.Tab("🔌 MCP Server Info"):
210
- gr.Markdown(f"""
211
- ### MCP Server Status: ✅ Running
212
-
213
- **Server URL:** Access at `/sse` endpoint
214
-
215
- **Available Tools:**
216
- 1. `init_real_mace_model(r_max, max_ell, hidden_dim)` - Initialize MACE architecture
217
- 2. `train_with_trackio(experiment_name, epochs, learning_rate)` - Train with live logging
218
- 3. `check_equivariance(rotation_degrees)` - Test rotation symmetry
219
-
220
- **Connect from Claude Desktop:**
221
- ```
222
- {{
223
- "mcpServers": {{
224
- "mace_trainer": {{
225
- "url": "YOUR_SPACE_URL/sse"
226
- }}
227
- }}
228
- }}
229
- ```
230
-
231
- **Example Prompts:**
232
- - *"Initialize a MACE model with max_ell=2 and r_max=5.0"*
233
- - *"Train for 20 epochs with learning rate 0.001"*
234
- - *"Check if the model is equivariant by rotating 90 degrees"*
235
- """)
236
-
237
- return demo
238
 
239
  if __name__ == "__main__":
240
- print("Starting MACE-MCP Server with Trackio Integration...")
241
-
242
- # Create and launch the Gradio UI with MCP server
243
- demo = create_ui()
244
- demo.launch(
245
- server_name="0.0.0.0",
246
- server_port=7860,
247
- share=False,
248
- mcp_server=mcp # This enables MCP on the /sse endpoint
249
- )
 
 
 
 
 
 
 
1
  import torch
2
  import numpy as np
3
  import time
4
+ import threading
5
+ import uvicorn
6
  from fastmcp import FastMCP
7
  from ase import Atoms
8
  from ase.build import molecule
9
  import gradio as gr
10
 
11
+ # --- 1. MCP SERVER SETUP ---
12
  mcp = FastMCP("RealMACE_Agent")
13
 
14
  # Global State
15
  STATE = {
16
  "model": None,
17
  "config": None,
18
+ "batch": None
 
19
  }
20
 
21
  # --- HELPER FUNCTIONS ---
22
  def get_mace_setup():
 
23
  try:
24
  from mace.models import ScaleShiftMACE
25
  from mace.data import AtomicData, Configuration
 
30
  raise ImportError("MACE not installed. Run: pip install mace-torch")
31
 
32
  def create_dummy_batch(r_max=5.0):
 
33
  _, AtomicData, Configuration, torch_geometric, _ = get_mace_setup()
 
34
  mol = molecule("H2O")
35
  mol.info["energy"] = -14.0
36
  mol.arrays["forces"] = np.random.randn(3, 3) * 0.1
 
37
  config = Configuration(
38
  atomic_numbers=mol.get_atomic_numbers(),
39
  positions=mol.get_positions(),
 
42
  pbc=np.array([False, False, False]),
43
  cell=np.eye(3) * 10.0
44
  )
 
45
  z_table = {1: 0, 8: 1}
46
  data_loader = torch_geometric.DataLoader(
47
  dataset=[AtomicData.from_config(config, z_table=z_table, cutoff=r_max)],
 
55
  def init_real_mace_model(r_max: float = 5.0, max_ell: int = 2, hidden_dim: int = 16) -> str:
56
  """Initialize a REAL MACE model."""
57
  ScaleShiftMACE, _, _, _, o3 = get_mace_setup()
 
58
  batch = create_dummy_batch(r_max)
59
  STATE["batch"] = batch
60
 
61
  model_config = dict(
62
+ r_max=r_max, num_bessel=8, num_polynomial_cutoff=5, max_ell=max_ell,
63
+ interaction_cls="RealAgnosticInteractionBlock", num_interactions=2, num_elements=2,
64
+ hidden_irreps=o3.Irreps(f"{hidden_dim}x0e"), atomic_energies=np.array([-13.6, -10.0]),
65
+ avg_num_neighbors=2, atomic_numbers=[1, 8]
 
 
 
 
 
 
 
66
  )
 
67
  try:
68
  model = ScaleShiftMACE(**model_config)
69
  STATE["model"] = model
 
70
  return f"✅ MACE Model Ready! L_max={max_ell}, r_max={r_max}Å"
71
  except Exception as e:
72
  return f"❌ Error: {str(e)}"
73
 
74
  @mcp.tool()
75
+ def train_with_trackio(experiment_name: str, epochs: int = 10) -> str:
76
+ """Train with Trackio logging."""
77
  try:
78
  import trackio
79
  except ImportError:
 
84
 
85
  model = STATE["model"]
86
  batch = STATE["batch"]
87
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
88
 
89
  try:
90
+ # Check if we are in a Space with OAuth
91
  logger = trackio.Logger(project="Real_MACE_Training", name=experiment_name)
92
  except Exception as e:
93
+ return f"❌ Trackio connection failed: {e}"
94
 
95
  model.train()
96
+ logs = []
97
+ start = time.time()
98
 
99
  for epoch in range(epochs):
100
  optimizer.zero_grad()
101
  out = model(batch.to_dict())
102
+ loss = torch.mean((out["energy"] - batch.energy)**2) + 10.0 * torch.mean((out["forces"] - batch.forces)**2)
103
+ loss.backward()
 
 
 
 
104
  optimizer.step()
105
 
106
+ metrics = {
 
 
107
  "epoch": epoch,
108
+ "total_loss": loss.item(),
109
+ "wall_time": time.time() - start
110
+ }
111
+ logger.log(metrics)
112
 
113
+ if epoch % 5 == 0:
114
+ logs.append(f"Ep {epoch}: Loss={loss.item():.4f}")
115
+ time.sleep(0.05)
 
116
 
117
+ return "🚀 Training Done! Check Dashboard.\n" + "\n".join(logs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ # --- 2. DASHBOARD UI (Separate Thread) ---
120
+ def launch_dashboard():
121
+ """Launches a Gradio UI that serves as the Dashboard Viewer"""
122
+ with gr.Blocks(title="Equivariant Chem Scout") as demo:
123
+ gr.Markdown("# 🧪 Equivariant Chem Scout (Dashboard)")
124
+ gr.Markdown("To view training results, open the **Trackio** dashboard below.")
125
 
126
+ # Option A: If running locally, just show instructions
127
+ gr.Markdown("""
128
+ ### How to view graphs:
129
+ The Trackio dashboard runs separately.
130
+ If you are running locally, type: `trackio show` in your terminal.
131
+ If you are on Hugging Face Spaces, we need to launch the Trackio server.
132
  """)
133
 
134
+ # Option B: Attempt to embed (Experimental)
135
+ # Note: Trackio doesn't have a verified embed widget yet, so we provide instructions.
136
+
137
+ demo.launch(server_name="0.0.0.0", server_port=7860, prevent_thread_lock=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  if __name__ == "__main__":
140
+ print("--- STARTING SERVICES ---")
141
+
142
+ # 1. Launch the UI (Dashboard) in a background thread on port 7860
143
+ print("1. Launching Gradio Dashboard on port 7860...")
144
+ launch_dashboard()
145
+
146
+ # 2. Run the MCP Server on the main thread (port 8000 or SSE)
147
+ print("2. Starting MCP Server (SSE Transport)...")
148
+ # Hugging Face Spaces expects the main process to listen on port 7860 usually,
149
+ # but for MCP we need to expose the SSE endpoint.
150
+ # TRICK: We let Gradio take 7860 (so the Space shows "Running"),
151
+ # and we run MCP on 8000. You connect to the Space URL via SSE proxying if configured,
152
+ # or you use this Space *only* as a dashboard and run the MCP logic locally connecting to it.
153
+
154
+ # However, since you want the Space to BE the MCP server:
155
+ mcp.run(transport="sse")