F555 commited on
Commit
3a15d3b
·
verified ·
1 Parent(s): f08a9df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -33
app.py CHANGED
@@ -5,9 +5,8 @@ from fastmcp import FastMCP
5
  from ase import Atoms
6
  from ase.build import molecule
7
 
8
- # Initialize MCP Server
9
- # We do NOT import trackio or mace globally to prevent startup crashes
10
- mcp = FastMCP("RealMACE_Agent", dependencies=["mace-torch", "trackio", "ase", "e3nn"])
11
 
12
  # Global State to share data between tools
13
  STATE = {
@@ -35,7 +34,7 @@ def create_dummy_batch(r_max=5.0):
35
  # Create dummy water
36
  mol = molecule("H2O")
37
  mol.info["energy"] = -14.0 # Dummy target energy (eV)
38
- mol.arrays["forces"] = np.random.randn(3, 3) # Dummy target forces
39
 
40
  config = Configuration(
41
  atomic_numbers=mol.get_atomic_numbers(),
@@ -49,9 +48,9 @@ def create_dummy_batch(r_max=5.0):
49
  # Convert to batch
50
  z_table = {1: 0, 8: 1} # Map H->0, O->1 for simple one-hot
51
  data_loader = torch_geometric.DataLoader(
52
- dataset=[AtomicData.from_config(config, z_table=z_table, cutoff=r_max, model="MACE")],
53
  batch_size=1,
54
- shuffle=True
55
  )
56
  return next(iter(data_loader))
57
 
@@ -61,10 +60,14 @@ def create_dummy_batch(r_max=5.0):
61
  def init_real_mace_model(r_max: float = 5.0, max_ell: int = 2, hidden_dim: int = 16) -> str:
62
  """
63
  Initializes a REAL MACE model and stores it in memory.
 
64
  Args:
65
- r_max: Cutoff radius (Angstroms)
66
- max_ell: Max spherical harmonic degree (2=vectors, 0=invariant)
67
- hidden_dim: Size of the embedding vectors
 
 
 
68
  """
69
  ScaleShiftMACE, _, _, _, o3 = get_mace_setup()
70
 
@@ -83,7 +86,7 @@ def init_real_mace_model(r_max: float = 5.0, max_ell: int = 2, hidden_dim: int =
83
  num_interactions=2,
84
  num_elements=2, # H and O
85
  hidden_irreps=o3.Irreps(f"{hidden_dim}x0e"),
86
- atomic_energies=np.array([-13.6, -10.0]), # Dummy average energies
87
  avg_num_neighbors=2,
88
  atomic_numbers=[1, 8]
89
  )
@@ -93,21 +96,39 @@ def init_real_mace_model(r_max: float = 5.0, max_ell: int = 2, hidden_dim: int =
93
  model = ScaleShiftMACE(**model_config)
94
  STATE["model"] = model
95
  STATE["config"] = model_config
96
- return f"✅ MACE Model Initialized! (L={max_ell}, R={r_max}). Ready to train."
 
 
 
 
 
 
 
 
 
97
  except Exception as e:
98
- return f"Error initializing MACE: {str(e)}"
99
 
100
  @mcp.tool()
101
- def train_with_trackio(experiment_name: str, epochs: int = 10) -> str:
102
  """
103
- Trains the stored MACE model and logs to Trackio.
104
- MUST run init_real_mace_model first.
 
 
 
 
 
 
 
 
 
105
  """
106
  # 1. Lazy Import Trackio to prevent startup crash
107
  try:
108
  import trackio
109
  except ImportError:
110
- return "Trackio not installed. Run: pip install trackio"
111
 
112
  # 2. Check if model exists
113
  if STATE["model"] is None:
@@ -117,14 +138,13 @@ def train_with_trackio(experiment_name: str, epochs: int = 10) -> str:
117
  batch = STATE["batch"]
118
 
119
  # 3. Setup Optimizer
120
- optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
121
 
122
  # 4. Setup Trackio
123
  try:
124
- # Trackio might fail if OAuth isn't set up in Space, catch it gracefully
125
  logger = trackio.Logger(project="Real_MACE_Training", name=experiment_name)
126
  except Exception as e:
127
- return f"❌ Trackio Connection Failed: {e}. (Check 'hf_oauth: true' in README?)"
128
 
129
  # 5. Training Loop
130
  model.train()
@@ -138,36 +158,112 @@ def train_with_trackio(experiment_name: str, epochs: int = 10) -> str:
138
  # MACE Forward Pass
139
  out = model(batch.to_dict())
140
 
141
- # Loss Calc
142
  loss_e = torch.mean((out["energy"] - batch.energy)**2)
143
  loss_f = torch.mean((out["forces"] - batch.forces)**2)
144
- total_loss = loss_e + 10.0 * loss_f
145
 
146
  total_loss.backward()
147
  optimizer.step()
148
 
149
- # Log metrics
 
 
 
 
150
  metrics = {
151
  "epoch": epoch,
152
  "total_loss": total_loss.item(),
153
- "force_mae": torch.mean(torch.abs(out["forces"] - batch.forces)).item(),
154
- "wall_time": time.time() - start_time
 
155
  }
156
 
157
- # Push to Trackio
158
  logger.log(metrics)
159
 
160
- if epoch % 5 == 0:
161
- log_summary.append(f"Epoch {epoch}: Loss={total_loss.item():.4f}")
162
- time.sleep(0.05) # Yield slightly
 
 
 
163
 
164
  return (
165
- f"🚀 **Training Complete!**\n"
166
- f"Experiment: {experiment_name}\n"
167
- f"Final Loss: {total_loss.item():.5f}\n"
168
- f"Check the Trackio tab for the live graphs!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  )
170
 
171
  if __name__ == "__main__":
172
  print("Starting MACE-MCP Server...")
173
- mcp.run()
 
 
5
  from ase import Atoms
6
  from ase.build import molecule
7
 
8
+ # Initialize MCP Server (dependencies removed - use requirements.txt instead)
9
+ mcp = FastMCP("RealMACE_Agent")
 
10
 
11
  # Global State to share data between tools
12
  STATE = {
 
34
  # Create dummy water
35
  mol = molecule("H2O")
36
  mol.info["energy"] = -14.0 # Dummy target energy (eV)
37
+ mol.arrays["forces"] = np.random.randn(3, 3) * 0.1 # Dummy target forces
38
 
39
  config = Configuration(
40
  atomic_numbers=mol.get_atomic_numbers(),
 
48
  # Convert to batch
49
  z_table = {1: 0, 8: 1} # Map H->0, O->1 for simple one-hot
50
  data_loader = torch_geometric.DataLoader(
51
+ dataset=[AtomicData.from_config(config, z_table=z_table, cutoff=r_max)],
52
  batch_size=1,
53
+ shuffle=False
54
  )
55
  return next(iter(data_loader))
56
 
 
60
  def init_real_mace_model(r_max: float = 5.0, max_ell: int = 2, hidden_dim: int = 16) -> str:
61
  """
62
  Initializes a REAL MACE model and stores it in memory.
63
+
64
  Args:
65
+ r_max: Cutoff radius in Angstroms (default 5.0)
66
+ max_ell: Maximum spherical harmonic degree - 0=scalars only, 2=include vectors (default 2)
67
+ hidden_dim: Size of the hidden embedding vectors (default 16)
68
+
69
+ Returns:
70
+ Status message with model configuration
71
  """
72
  ScaleShiftMACE, _, _, _, o3 = get_mace_setup()
73
 
 
86
  num_interactions=2,
87
  num_elements=2, # H and O
88
  hidden_irreps=o3.Irreps(f"{hidden_dim}x0e"),
89
+ atomic_energies=np.array([-13.6, -10.0]), # Dummy average energies for H and O
90
  avg_num_neighbors=2,
91
  atomic_numbers=[1, 8]
92
  )
 
96
  model = ScaleShiftMACE(**model_config)
97
  STATE["model"] = model
98
  STATE["config"] = model_config
99
+ return (
100
+ f"✅ **MACE Model Initialized Successfully!**\n\n"
101
+ f"Configuration:\n"
102
+ f"- Cutoff Radius (r_max): {r_max} Å\n"
103
+ f"- Max Spherical Harmonic Degree (L_max): {max_ell}\n"
104
+ f"- Hidden Dimension: {hidden_dim}\n"
105
+ f"- Interaction Blocks: 2\n"
106
+ f"- Elements: H, O\n\n"
107
+ f"Model is ready for training. Use 'train_with_trackio' next."
108
+ )
109
  except Exception as e:
110
+ return f"Error initializing MACE: {str(e)}"
111
 
112
  @mcp.tool()
113
+ def train_with_trackio(experiment_name: str, epochs: int = 10, learning_rate: float = 0.01) -> str:
114
  """
115
+ Trains the stored MACE model and logs metrics to Trackio.
116
+
117
+ Args:
118
+ experiment_name: Name for this training run in Trackio
119
+ epochs: Number of training epochs (default 10)
120
+ learning_rate: Optimizer learning rate (default 0.01)
121
+
122
+ Returns:
123
+ Training summary with final loss metrics
124
+
125
+ Note: Must run 'init_real_mace_model' first to create a model.
126
  """
127
  # 1. Lazy Import Trackio to prevent startup crash
128
  try:
129
  import trackio
130
  except ImportError:
131
+ return "Trackio not installed. Run: pip install trackio"
132
 
133
  # 2. Check if model exists
134
  if STATE["model"] is None:
 
138
  batch = STATE["batch"]
139
 
140
  # 3. Setup Optimizer
141
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
142
 
143
  # 4. Setup Trackio
144
  try:
 
145
  logger = trackio.Logger(project="Real_MACE_Training", name=experiment_name)
146
  except Exception as e:
147
+ return f"❌ Trackio Connection Failed: {e}\n(Hint: Add 'hf_oauth: true' to README.md if running on HF Space)"
148
 
149
  # 5. Training Loop
150
  model.train()
 
158
  # MACE Forward Pass
159
  out = model(batch.to_dict())
160
 
161
+ # Loss Calculation (Energy MSE + Force MSE)
162
  loss_e = torch.mean((out["energy"] - batch.energy)**2)
163
  loss_f = torch.mean((out["forces"] - batch.forces)**2)
164
+ total_loss = loss_e + 10.0 * loss_f # Weight forces 10x more
165
 
166
  total_loss.backward()
167
  optimizer.step()
168
 
169
+ # Calculate MAE metrics for interpretability
170
+ force_mae = torch.mean(torch.abs(out["forces"] - batch.forces)).item()
171
+ energy_mae = torch.abs(out["energy"] - batch.energy).mean().item()
172
+
173
+ # Log metrics to Trackio
174
  metrics = {
175
  "epoch": epoch,
176
  "total_loss": total_loss.item(),
177
+ "energy_mae_eV": energy_mae,
178
+ "force_mae_eV_A": force_mae,
179
+ "wall_time_sec": time.time() - start_time
180
  }
181
 
 
182
  logger.log(metrics)
183
 
184
+ if epoch % 5 == 0 or epoch == epochs - 1:
185
+ log_summary.append(
186
+ f"Epoch {epoch:3d}: Loss={total_loss.item():.5f} | "
187
+ f"Force MAE={force_mae:.5f} eV/Å"
188
+ )
189
+ time.sleep(0.05) # Small delay for visualization
190
 
191
  return (
192
+ f"🚀 **Training Complete!**\n\n"
193
+ f"**Experiment:** {experiment_name}\n"
194
+ f"**Epochs:** {epochs}\n"
195
+ f"**Learning Rate:** {learning_rate}\n\n"
196
+ f"**Final Metrics:**\n"
197
+ f"- Total Loss: {total_loss.item():.6f}\n"
198
+ f"- Energy MAE: {energy_mae:.6f} eV\n"
199
+ f"- Force MAE: {force_mae:.6f} eV/Å\n\n"
200
+ f"📊 Check the **Trackio Dashboard** for live loss curves and training dynamics!\n\n"
201
+ f"**Recent Training Log:**\n" + "\n".join(log_summary)
202
+ )
203
+
204
+ @mcp.tool()
205
+ def check_equivariance(rotation_degrees: float = 45.0) -> str:
206
+ """
207
+ Educational tool: Tests if the model is truly E(3)-equivariant.
208
+ Rotates the input molecule and checks if predicted forces rotate exactly with it.
209
+
210
+ Args:
211
+ rotation_degrees: Angle to rotate the molecule around Z-axis (default 45.0)
212
+
213
+ Returns:
214
+ Explanation of equivariance test results
215
+ """
216
+ if STATE["model"] is None:
217
+ return "⚠️ No model found! Run 'init_real_mace_model' first."
218
+
219
+ model = STATE["model"]
220
+ batch = STATE["batch"]
221
+
222
+ # Get original prediction
223
+ model.eval()
224
+ with torch.no_grad():
225
+ out_orig = model(batch.to_dict())
226
+ forces_orig = out_orig["forces"].clone()
227
+
228
+ # Apply rotation to positions
229
+ angle = np.radians(rotation_degrees)
230
+ rot_matrix = torch.tensor([
231
+ [np.cos(angle), -np.sin(angle), 0],
232
+ [np.sin(angle), np.cos(angle), 0],
233
+ [0, 0, 1]
234
+ ], dtype=torch.float32)
235
+
236
+ # Create rotated batch
237
+ batch_rot = batch.clone()
238
+ batch_rot.positions = torch.matmul(batch.positions, rot_matrix.T)
239
+
240
+ # Get prediction on rotated input
241
+ with torch.no_grad():
242
+ out_rot = model(batch_rot.to_dict())
243
+ forces_rot = out_rot["forces"]
244
+
245
+ # Manually rotate the original forces
246
+ forces_orig_rotated = torch.matmul(forces_orig, rot_matrix.T)
247
+
248
+ # Calculate equivariance error
249
+ equivariance_error = torch.mean(torch.abs(forces_rot - forces_orig_rotated)).item()
250
+
251
+ return (
252
+ f"🧪 **E(3)-Equivariance Test Results**\n\n"
253
+ f"**Test Setup:**\n"
254
+ f"- Molecule: Water (H₂O)\n"
255
+ f"- Rotation: {rotation_degrees}° around Z-axis\n\n"
256
+ f"**Results:**\n"
257
+ f"- Equivariance Error: {equivariance_error:.2e} eV/Å\n"
258
+ f"- Expected for perfect equivariance: ~1e-6 or lower\n\n"
259
+ f"**Interpretation:**\n"
260
+ f"{'✅ PASS: Model is equivariant!' if equivariance_error < 1e-4 else '⚠️ WARNING: High error detected'}\n\n"
261
+ f"This confirms that when you rotate the molecule, the predicted force vectors "
262
+ f"rotate **exactly** with it. Standard MLPs cannot achieve this without extensive "
263
+ f"data augmentation!"
264
  )
265
 
266
  if __name__ == "__main__":
267
  print("Starting MACE-MCP Server...")
268
+ # Use SSE transport for Hugging Face Spaces deployment
269
+ mcp.run(transport="sse")