# Provenance: uploaded by F555 -- commit "Update app.py" (46ef716, verified).
import torch
import numpy as np
import time
from fastmcp import FastMCP
import gradio as gr
from ase import Atoms
from ase.build import molecule
# --- 1. DEFINE MCP SERVER ---
# FastMCP instance; its ``tool()`` decorator is applied to the tool functions
# defined further down in this file.
mcp = FastMCP("Equivariant_Chem_Scout")

# Global State
# Module-level mutable state shared by the MCP tools and the dashboard.
STATE = dict(
    model=None,   # model object stored by init_mace_model
    config=None,  # never written anywhere in this file
    batch=None,   # batch returned by create_dummy_batch
    logs=[],      # one string per training epoch, read by the dashboard
)
# --- HELPER FUNCTIONS ---
def get_mace_setup():
    """Import the MACE / e3nn building blocks and return them as a tuple.

    The imports are performed inside the function rather than at module
    level, so a missing dependency only surfaces when a tool is called.

    Returns:
        The tuple ``(ScaleShiftMACE, AtomicData, Configuration,
        torch_geometric, o3)``.

    Raises:
        ImportError: If any of the imports fails. The underlying error is
            chained as ``__cause__`` so the missing package can be identified
            (the failure may come from ``e3nn`` as well as from ``mace``).
    """
    try:
        # Model classes are exported from ``mace.modules``.
        from mace.modules import ScaleShiftMACE
        from mace.data import AtomicData, Configuration
        from mace.tools import torch_geometric
        from e3nn import o3
    except ImportError as exc:
        raise ImportError("MACE not installed.") from exc
    return ScaleShiftMACE, AtomicData, Configuration, torch_geometric, o3
def create_dummy_batch(r_max=5.0):
    """Build a single-graph batch for one water molecule with dummy labels.

    The molecule is ASE's ``"H2O"``; its energy label is the constant -14.0
    and its force labels are random normal values scaled by 0.1, so the
    result differs between calls.

    Args:
        r_max: Cutoff passed to ``AtomicData.from_config``.

    Returns:
        The first (and only) batch produced by a ``DataLoader`` over the
        single configuration.

    Raises:
        ImportError: Propagated from ``get_mace_setup`` when MACE is missing.
    """
    _, AtomicData, Configuration, torch_geometric, _ = get_mace_setup()
    # Imported after get_mace_setup() so a missing MACE install is reported
    # through that helper's error message first.
    from mace.tools import AtomicNumberTable

    mol = molecule("H2O")
    mol.info["energy"] = -14.0
    # One random force vector per atom (shape follows the molecule size).
    mol.arrays["forces"] = np.random.randn(len(mol), 3) * 0.1

    # NOTE(review): these keyword names follow the Configuration signature this
    # file was written against -- confirm against the installed MACE version.
    config = Configuration(
        atomic_numbers=mol.get_atomic_numbers(),
        positions=mol.get_positions(),
        energy=mol.info["energy"],
        forces=mol.arrays["forces"],
        pbc=np.array([False, False, False]),
        cell=np.eye(3) * 10.0,
    )

    # from_config maps atomic numbers to indices via the table's
    # ``z_to_index`` method and sizes the one-hot encoding with ``len(table)``;
    # a plain ``{1: 0, 8: 1}`` dict does not provide ``z_to_index``.
    z_table = AtomicNumberTable([1, 8])

    data_loader = torch_geometric.DataLoader(
        dataset=[AtomicData.from_config(config, z_table=z_table, cutoff=r_max)],
        batch_size=1,
        shuffle=False,
    )
    return next(iter(data_loader))
# --- MCP TOOLS ---
@mcp.tool()
def init_mace_model(r_max: float = 5.0, max_ell: int = 2) -> str:
    """Initialize MACE model.

    Builds a dummy water batch and a small two-element (H, O)
    ``ScaleShiftMACE`` model, then stores both in ``STATE`` and clears the
    training log. ``STATE`` is only touched once everything succeeded, so a
    failed call leaves the previous model/batch pair intact.

    Args:
        r_max: Radial cutoff used for both the batch and the model.
        max_ell: Maximum spherical-harmonic degree passed to the model.

    Returns:
        A status string: a success message, or ``"❌ Error: ..."`` carrying
        the exception text if imports, batch creation or model construction
        fail.
    """
    try:
        ScaleShiftMACE, _, _, _, o3 = get_mace_setup()
        from mace import modules

        batch = create_dummy_batch(r_max)

        # The model constructor expects interaction block *classes*, which
        # MACE exposes through the ``interaction_classes`` name -> class map.
        interaction_cls = modules.interaction_classes["RealAgnosticInteractionBlock"]
        model_config = dict(
            r_max=r_max,
            num_bessel=8,
            num_polynomial_cutoff=5,
            max_ell=max_ell,
            interaction_cls=interaction_cls,
            interaction_cls_first=interaction_cls,
            num_interactions=2,
            num_elements=2,
            hidden_irreps=o3.Irreps("16x0e"),
            MLP_irreps=o3.Irreps("16x0e"),
            gate=torch.nn.functional.silu,
            atomic_energies=np.array([-13.6, -10.0]),
            avg_num_neighbors=2,
            atomic_numbers=[1, 8],
            correlation=3,
            # Identity scale/shift: ScaleShiftMACE requires both arguments.
            atomic_inter_scale=1.0,
            atomic_inter_shift=0.0,
        )
        model = ScaleShiftMACE(**model_config)
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising.
        return f"❌ Error: {str(e)}"

    STATE["model"] = model
    STATE["batch"] = batch
    STATE["logs"] = []
    return f"✅ MACE Model Initialized! (L_max={max_ell}, R_max={r_max}Å)"
@mcp.tool()
def train_model(epochs: int = 10, learning_rate: float = 0.01) -> str:
    """Train model.

    Runs ``epochs`` Adam steps on the single batch held in ``STATE``. The
    loss is the mean squared energy error plus 10 times the mean squared
    force error. One line per epoch is appended to ``STATE["logs"]``.

    Args:
        epochs: Number of optimisation steps; must be at least 1.
        learning_rate: Learning rate handed to ``torch.optim.Adam``.

    Returns:
        A status string with the final loss, or a warning string if no model
        has been initialised or ``epochs`` is smaller than 1.
    """
    if STATE["model"] is None:
        return "⚠️ Run 'init_mace_model' first!"
    if epochs < 1:
        # Without at least one step there is no loss value to report.
        return "⚠️ 'epochs' must be at least 1."

    model = STATE["model"]
    batch = STATE["batch"]
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.train()

    for epoch in range(epochs):
        optimizer.zero_grad()
        # training=True makes MACE keep/create the autograd graph while it
        # differentiates the energy for forces; with the default (False) the
        # graph is released and the backward pass below cannot run.
        out = model(batch.to_dict(), training=True)
        energy_loss = torch.mean((out["energy"] - batch.energy) ** 2)
        force_loss = torch.mean((out["forces"] - batch.forces) ** 2)
        loss = energy_loss + 10.0 * force_loss
        loss.backward()
        optimizer.step()
        STATE["logs"].append(f"Epoch {epoch}: Loss={loss.item():.4f}")
        # Short pause per epoch -- presumably so the polling dashboard can
        # show the log growing; confirm before removing.
        time.sleep(0.05)

    return f"🚀 Training Done! Final Loss: {loss.item():.4f}"
@mcp.tool()
def check_equivariance(rotation_degrees: float = 90.0) -> str:
    """Check equivariance.

    Evaluates the model on the stored batch and on a copy whose positions
    are rotated about the z axis, then compares the forces of the rotated
    input with the rotated forces of the original input.

    Args:
        rotation_degrees: Rotation angle about the z axis, in degrees.

    Returns:
        A status string with the mean absolute force discrepancy and whether
        it is below 1e-4, or a warning string if no model is initialised.
    """
    if STATE["model"] is None:
        return "⚠️ No model found."

    model = STATE["model"]
    batch = STATE["batch"]
    model.eval()

    forces_ref = model(batch.to_dict())["forces"].detach()

    # Detach so the rotated copy starts from plain values instead of a tensor
    # that the first forward pass may have tied into an autograd graph.
    positions = batch.positions.detach()
    angle = np.radians(rotation_degrees)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    # Match the positions' dtype/device; a fixed float32 matrix fails in the
    # matmul below when the batch was built in float64.
    R = torch.tensor(
        [[cos_a, -sin_a, 0], [sin_a, cos_a, 0], [0, 0, 1]],
        dtype=positions.dtype,
        device=positions.device,
    )

    batch_rot = batch.clone()
    # Row vectors: x' = x @ R^T applies R to every position.
    batch_rot.positions = torch.matmul(positions, R.T)
    forces_rot = model(batch_rot.to_dict())["forces"].detach()

    expected = torch.matmul(forces_ref, R.T)
    error = torch.mean(torch.abs(forces_rot - expected)).item()
    return f"🧪 Equivariance Error: {error:.2e} (Pass: {error < 1e-4})"
# --- 2. GRADIO DASHBOARD ---
def get_latest_logs():
    """Return the training log lines joined by newlines.

    Falls back to a placeholder message while ``STATE["logs"]`` is empty.
    """
    entries = STATE["logs"]
    if entries:
        return "\n".join(entries)
    return "Waiting for training to start..."
# Dashboard layout: a static "How to Use" tab and a "Live Logs" tab that shows
# the output of get_latest_logs(). The Markdown bodies are kept flush-left
# inside the triple-quoted strings so no leading whitespace enters the text.
with gr.Blocks(title="Equivariant Chem Scout") as demo:
    gr.Markdown("# 🧪 Equivariant Chem Scout")
    gr.Markdown("### An MCP Server for Computational Chemistry Education")
    with gr.Tabs():
        with gr.Tab("📖 How to Use"):
            gr.Markdown("""
### 👋 Welcome!
This Space hosts an **MCP Server** that lets AI Agents (like Claude) run real molecular dynamics code.
#### Step 1: Connect Your Agent
Add this configuration to your **Claude Desktop** config file:
```
{
"mcpServers": {
"chem_scout": {
"url": "https://F555-Equivariant-Chem-Scout.hf.space/gradio_api/mcp/sse"
}
}
}
```
*(Restart Claude after saving)*
#### Step 2: Chat with the Agent
Once connected, try these prompts:
- **"Initialize a MACE model with `max_ell=2` (equivariant) and `r_max=4.5`."**
- **"Train the model for 20 epochs and show me the loss."**
- **"Check if the model is equivariant by rotating the molecule 90 degrees."**
""")
            with gr.Accordion("🛠️ What is MACE?", open=False):
                gr.Markdown("""
**MACE (Multi-Atomic Cluster Expansion)** is a state-of-the-art machine learning potential for chemistry.
- **Equivariance (`max_ell`):** Ensures that if you rotate a molecule, the forces rotate with it.
- **Cutoff (`r_max`):** How far each atom "sees" its neighbors.
""")
        with gr.Tab("📊 Live Logs"):
            gr.Markdown("### 🖥️ Real-time Training Output")
            log_display = gr.TextArea(label="Server Logs", interactive=False, lines=20, max_lines=30)
            refresh_btn = gr.Button("🔄 Refresh Logs Now")
            refresh_btn.click(fn=get_latest_logs, outputs=log_display)
            # Poll the log every 2 seconds in addition to the manual button.
            timer = gr.Timer(value=2.0)
            timer.tick(fn=get_latest_logs, outputs=log_display)

# --- 3. LAUNCH ---
if __name__ == "__main__":
    # NOTE(review): Gradio documents ``mcp_server`` as a boolean flag. Passing
    # the FastMCP instance relies on its truthiness -- confirm that the
    # @mcp.tool() functions are actually exposed to MCP clients this way.
    demo.launch(mcp_server=mcp)