Spaces:
Sleeping
Sleeping
File size: 6,272 Bytes
1ab83c8 c7d6ecc 1ab83c8 4af55fb 1ab83c8 38536d9 1ab83c8 4af55fb 1ab83c8 7c59449 c7d6ecc 4af55fb c7d6ecc 1ab83c8 4eb56e0 c7d6ecc 7c59449 1ab83c8 c7d6ecc 4eb56e0 c7d6ecc 4eb56e0 c7d6ecc 4eb56e0 38536d9 3a15d3b c7d6ecc 3a15d3b 38536d9 c7d6ecc 1ab83c8 38536d9 1ab83c8 4af55fb 46ef716 c7d6ecc 38536d9 89a8302 4af55fb 89a8302 38536d9 c7d6ecc 7c59449 70e1d1b c7d6ecc 4eb56e0 1ab83c8 4af55fb 46ef716 7c59449 c7d6ecc 4af55fb 38536d9 c7d6ecc 1ab83c8 c7d6ecc 89a8302 1ab83c8 7c59449 46ef716 89a8302 70e1d1b 4af55fb 46ef716 4af55fb 7c59449 4af55fb 70e1d1b 4af55fb 237ecb2 4af55fb 70e1d1b 4af55fb 46ef716 4af55fb 70e1d1b 4af55fb 70e1d1b 4af55fb 70e1d1b 4af55fb 70e1d1b 7c59449 4af55fb 70e1d1b 1ab83c8 70e1d1b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import torch
import numpy as np
import time
from fastmcp import FastMCP
import gradio as gr
from ase import Atoms
from ase.build import molecule
# --- 1. DEFINE MCP SERVER ---
# FastMCP server instance; the @mcp.tool() functions in this file register on it.
mcp = FastMCP("Equivariant_Chem_Scout")

# Module-level mutable state shared by every tool call and by the dashboard.
STATE = dict(
    model=None,   # ScaleShiftMACE instance, assigned by init_mace_model
    config=None,  # not read by the tools defined here
    batch=None,   # H2O batch from create_dummy_batch, assigned by init_mace_model
    logs=[],      # per-epoch loss lines appended by train_model
)
# --- HELPER FUNCTIONS ---
def get_mace_setup():
    """Import the optional MACE / e3nn dependencies on demand.

    Returns:
        A 5-tuple ``(ScaleShiftMACE, AtomicData, Configuration,
        torch_geometric, o3)``.

    Raises:
        ImportError: If any of the MACE or e3nn imports fail. The original
            import error is chained as ``__cause__`` so the name of the
            actually-missing module is not lost.
    """
    try:
        from mace.models import ScaleShiftMACE
        from mace.data import AtomicData, Configuration
        from mace.tools import torch_geometric
        from e3nn import o3
    except ImportError as err:
        raise ImportError("MACE not installed.") from err
    return ScaleShiftMACE, AtomicData, Configuration, torch_geometric, o3
def create_dummy_batch(r_max=5.0):
    """Build a one-graph MACE batch for a water molecule with synthetic labels.

    The energy label is fixed at -14.0 and the force labels are Gaussian noise
    (standard deviation 0.1) drawn from NumPy's global RNG, so the batch is a
    smoke-test target rather than real reference data.

    Args:
        r_max: Cutoff radius passed to ``AtomicData.from_config`` for the
            neighbor list.

    Returns:
        The single batch produced by MACE's ``torch_geometric.DataLoader``.

    Raises:
        ImportError: If MACE is not installed (raised by ``get_mace_setup``).
    """
    _, AtomicData, Configuration, torch_geometric, _ = get_mace_setup()
    # Imported after get_mace_setup so a missing MACE install still surfaces as
    # its "MACE not installed." ImportError first.
    from mace.tools import AtomicNumberTable

    mol = molecule("H2O")
    mol.info["energy"] = -14.0
    # One random 3-vector per atom, sized from the molecule rather than hard-coded.
    mol.arrays["forces"] = np.random.randn(len(mol), 3) * 0.1
    config = Configuration(
        atomic_numbers=mol.get_atomic_numbers(),
        positions=mol.get_positions(),
        energy=mol.info["energy"],
        forces=mol.arrays["forces"],
        pbc=np.array([False, False, False]),
        cell=np.eye(3) * 10.0,
    )
    # from_config maps atomic numbers through z_table.z_to_index, which a plain
    # dict does not provide. Element order (H, O) matches atomic_numbers=[1, 8]
    # used by init_mace_model.
    z_table = AtomicNumberTable([1, 8])
    data_loader = torch_geometric.DataLoader(
        dataset=[AtomicData.from_config(config, z_table=z_table, cutoff=r_max)],
        batch_size=1,
        shuffle=False,
    )
    return next(iter(data_loader))
# --- MCP TOOLS ---
@mcp.tool()
def init_mace_model(r_max: float = 5.0, max_ell: int = 2) -> str:
    """Build a fresh ScaleShiftMACE model and its H2O evaluation batch.

    On success the model, its constructor settings and the batch are stored in
    ``STATE`` together and the training log is cleared; on failure ``STATE``
    is left untouched.

    Args:
        r_max: Cutoff radius (Å) used both for the neighbor list of the dummy
            batch and for the model's radial basis.
        max_ell: Highest spherical-harmonic order used for edge features.

    Returns:
        A human-readable status string ("✅ ..." on success, "❌ Error: ..."
        if MACE is missing or the model cannot be constructed).
    """
    try:
        # Local import: MACE is optional and only loaded when this tool runs.
        from mace import modules

        ScaleShiftMACE, _, _, _, o3 = get_mace_setup()
        batch = create_dummy_batch(r_max)
        # ScaleShiftMACE takes interaction block classes; look the class up by name.
        interaction_cls = modules.interaction_classes["RealAgnosticInteractionBlock"]
        model_config = dict(
            r_max=r_max,
            num_bessel=8,
            num_polynomial_cutoff=5,
            max_ell=max_ell,
            interaction_cls=interaction_cls,
            interaction_cls_first=interaction_cls,
            num_interactions=2,
            num_elements=2,
            hidden_irreps=o3.Irreps("16x0e"),
            MLP_irreps=o3.Irreps("16x0e"),
            # Order matches atomic_numbers below: H, then O.
            atomic_energies=np.array([-13.6, -10.0]),
            avg_num_neighbors=2,
            atomic_numbers=[1, 8],
            correlation=3,
            gate=torch.nn.functional.silu,
            # Scale 1 / shift 0 leaves the interaction energy unchanged.
            atomic_inter_scale=1.0,
            atomic_inter_shift=0.0,
        )
        model = ScaleShiftMACE(**model_config)
    except Exception as e:  # tool boundary: report to the agent instead of raising
        return f"❌ Error: {e}"
    # Commit everything at once so model, config and batch always agree.
    STATE.update(model=model, config=model_config, batch=batch, logs=[])
    return f"✅ MACE Model Initialized! (L_max={max_ell}, R_max={r_max}Å)"
@mcp.tool()
def train_model(epochs: int = 10, learning_rate: float = 0.01) -> str:
    """Train the current MACE model on the stored single-molecule batch.

    Minimizes ``MSE(energy) + 10 * MSE(forces)`` with Adam for ``epochs``
    full-batch steps, appending one line per epoch to ``STATE["logs"]`` so the
    dashboard can show progress.

    Args:
        epochs: Number of optimization steps; must be at least 1.
        learning_rate: Adam learning rate.

    Returns:
        A status string with the final loss, or a warning if no model has been
        initialized or ``epochs`` is not positive.
    """
    if STATE["model"] is None:
        return "⚠️ Run 'init_mace_model' first!"
    if epochs < 1:
        # Without this guard the loop never runs and `loss` is unbound below.
        return "⚠️ 'epochs' must be at least 1."
    model = STATE["model"]
    batch = STATE["batch"]
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        # training=True asks MACE to keep the graph of its internal force
        # computation, so the force term is differentiable and the energy graph
        # is still available for loss.backward().
        out = model(batch.to_dict(), training=True)
        energy_loss = torch.mean((out["energy"] - batch.energy) ** 2)
        force_loss = torch.mean((out["forces"] - batch.forces) ** 2)
        loss = energy_loss + 10.0 * force_loss
        loss.backward()
        optimizer.step()
        STATE["logs"].append(f"Epoch {epoch}: Loss={loss.item():.4f}")
        # Short pause so the dashboard's 2 s log poll can show progress.
        time.sleep(0.05)
    return f"🎉 Training Done! Final Loss: {loss.item():.4f}"
@mcp.tool()
def check_equivariance(rotation_degrees: float = 90.0, tolerance: float = 1e-4) -> str:
    """Test whether predicted forces rotate together with the molecule.

    Runs the stored model on the stored batch, rotates the atomic positions by
    ``rotation_degrees`` about the z axis, runs the model again, and compares
    the new forces with the original forces rotated by the same matrix.

    Args:
        rotation_degrees: Rotation angle about the z axis, in degrees.
        tolerance: Largest mean absolute force deviation still reported as a
            pass (defaults to the previous fixed threshold, 1e-4).

    Returns:
        A status string with the mean absolute error and pass flag, or a
        warning if no model has been initialized.
    """
    if STATE["model"] is None:
        return "⚠️ No model found."
    model = STATE["model"]
    batch = STATE["batch"]
    model.eval()
    f_orig = model(batch.to_dict())["forces"]

    # Build R in the positions' dtype/device so the matmul also works under a
    # float64 default dtype or on GPU (a hard-coded float32 CPU tensor does not).
    angle = np.radians(rotation_degrees)
    cos_a, sin_a = float(np.cos(angle)), float(np.sin(angle))
    R = torch.tensor(
        [[cos_a, -sin_a, 0.0], [sin_a, cos_a, 0.0], [0.0, 0.0, 1.0]],
        dtype=batch.positions.dtype,
        device=batch.positions.device,
    )
    batch_rot = batch.clone()
    # detach(): the second pass differentiates w.r.t. a fresh leaf tensor,
    # independent of any autograd history on the original positions.
    batch_rot.positions = batch.positions.detach() @ R.T
    f_rot = model(batch_rot.to_dict())["forces"]

    # Equivariance means F(R·x) == R·F(x); compare against the rotated originals.
    error = torch.mean(torch.abs(f_rot - f_orig @ R.T)).item()
    return f"🧪 Equivariance Error: {error:.2e} (Pass: {error < tolerance})"
# --- 2. GRADIO DASHBOARD ---
def get_latest_logs():
    """Return the accumulated training log lines as one newline-joined string.

    While ``STATE["logs"]`` is empty, a placeholder message is returned instead.
    """
    lines = STATE["logs"]
    return "\n".join(lines) if lines else "Waiting for training to start..."
with gr.Blocks(title="Equivariant Chem Scout") as demo:
    gr.Markdown("# 🧪 Equivariant Chem Scout")
    gr.Markdown("### An MCP Server for Computational Chemistry Education")
    with gr.Tabs():
        with gr.Tab("📖 How to Use"):
            # NOTE(review): Claude Desktop may only launch stdio servers and need
            # a bridge such as `npx mcp-remote <url>` instead of a bare "url"
            # entry — verify against the current Claude Desktop docs.
            gr.Markdown("""
### 👋 Welcome!
This Space hosts an **MCP Server** that lets AI agents (like Claude) run real machine-learning chemistry code: they can build, train, and test a MACE interatomic potential.
#### Step 1: Connect Your Agent
Add this configuration to your **Claude Desktop** config file:
```json
{
  "mcpServers": {
    "chem_scout": {
      "url": "https://F555-Equivariant-Chem-Scout.hf.space/gradio_api/mcp/sse"
    }
  }
}
```
*(Restart Claude after saving.)*
#### Step 2: Chat with the Agent
Once connected, try these prompts:
- **"Initialize a MACE model with `max_ell=2` (equivariant) and `r_max=4.5`."**
- **"Train the model for 20 epochs and show me the loss."**
- **"Check if the model is equivariant by rotating the molecule 90 degrees."**
""")
            with gr.Accordion("🛠️ What is MACE?", open=False):
                gr.Markdown("""
**MACE (Multi-Atomic Cluster Expansion)** is a state-of-the-art machine-learning interatomic potential for chemistry.
- **Equivariance (`max_ell`):** MACE is built so that rotating a molecule rotates its predicted forces with it; `max_ell` sets the highest angular order of the spherical harmonics the model uses to describe directions between atoms.
- **Cutoff (`r_max`):** The radius (in Å) within which each atom "sees" its neighbors.
""")
        with gr.Tab("📊 Live Logs"):
            gr.Markdown("### 🖥️ Real-time Training Output")
            log_display = gr.TextArea(label="Server Logs", interactive=False, lines=20, max_lines=30)
            refresh_btn = gr.Button("🔄 Refresh Logs Now")
            refresh_btn.click(fn=get_latest_logs, outputs=log_display)
            # Poll every 2 s so the log view updates while train_model runs.
            timer = gr.Timer(value=2.0)
            timer.tick(fn=get_latest_logs, outputs=log_display)
# --- 3. LAUNCH ---
if __name__ == "__main__":
    # Gradio's `mcp_server` launch option is a boolean flag that enables
    # Gradio's own MCP endpoint (/gradio_api/mcp/sse) for this Blocks app.
    # Passing the FastMCP object presumably only had an effect because the
    # object is truthy; it does not hand the FastMCP tools to Gradio.
    # NOTE(review): the @mcp.tool() functions are registered on `mcp`, not on
    # `demo`; confirm they are reachable through Gradio's endpoint (or serve
    # `mcp` separately) before advertising them to agents.
    demo.launch(mcp_server=True)
|