Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import time | |
| from fastmcp import FastMCP | |
| import gradio as gr | |
| from ase import Atoms | |
| from ase.build import molecule | |
# --- 1. DEFINE MCP SERVER ---
mcp = FastMCP("Equivariant_Chem_Scout")

# Process-wide mutable state shared by the tool functions and the dashboard.
STATE = dict(
    model=None,  # current model; None until init_mace_model succeeds
    config=None,
    batch=None,  # input batch the tools run the model on
    logs=[],  # one text line per training epoch
)
# --- HELPER FUNCTIONS ---
def get_mace_setup():
    """Import the MACE / e3nn pieces on demand and return them.

    Returns:
        Tuple ``(ScaleShiftMACE, AtomicData, Configuration, torch_geometric, o3)``.

    Raises:
        ImportError: if any of the MACE or e3nn imports fail. The original
            import error is chained as ``__cause__`` so the missing module
            (mace vs. e3nn) stays visible in the traceback.
    """
    try:
        from mace.models import ScaleShiftMACE
        from mace.data import AtomicData, Configuration
        from mace.tools import torch_geometric
        from e3nn import o3
    except ImportError as err:
        raise ImportError("MACE not installed.") from err
    return ScaleShiftMACE, AtomicData, Configuration, torch_geometric, o3
def create_dummy_batch(r_max=5.0):
    """Build a one-graph MACE batch for a water molecule with synthetic labels.

    Args:
        r_max: Neighbour-list cutoff passed to ``AtomicData.from_config``.

    Returns:
        The first (and only) batch yielded by a MACE
        ``torch_geometric.DataLoader`` over the single H2O configuration.

    Raises:
        ImportError: if MACE is not installed (raised by ``get_mace_setup``).
    """
    _, AtomicData, Configuration, torch_geometric, _ = get_mace_setup()
    # AtomicData.from_config maps atomic numbers through z_table.z_to_index,
    # so it needs MACE's AtomicNumberTable; a plain dict has no such method.
    from mace.tools import AtomicNumberTable

    mol = molecule("H2O")
    mol.info["energy"] = -14.0
    # Random placeholder force labels: one 3-vector per atom.
    mol.arrays["forces"] = np.random.randn(len(mol), 3) * 0.1
    # NOTE(review): recent mace-torch releases moved energy/forces into a
    # `properties` dict on Configuration; confirm these keywords match the
    # installed version.
    config = Configuration(
        atomic_numbers=mol.get_atomic_numbers(),
        positions=mol.get_positions(),
        energy=mol.info["energy"],
        forces=mol.arrays["forces"],
        pbc=np.array([False, False, False]),
        cell=np.eye(3) * 10.0,
    )
    # Index order H -> 0, O -> 1, matching atomic_numbers=[1, 8] of the model.
    z_table = AtomicNumberTable([1, 8])
    data_loader = torch_geometric.DataLoader(
        dataset=[AtomicData.from_config(config, z_table=z_table, cutoff=r_max)],
        batch_size=1,
        shuffle=False,
    )
    return next(iter(data_loader))
# --- MCP TOOLS ---
def init_mace_model(r_max: float = 5.0, max_ell: int = 2) -> str:
    """Initialize a small ScaleShiftMACE model and a dummy H2O training batch.

    Args:
        r_max: Radial cutoff in Å, used for both the model and the neighbour list.
        max_ell: Maximum spherical-harmonic degree (L_max) of the model.

    Returns:
        A status message. Failures (including a missing MACE install) are
        reported in the message rather than raised, so the MCP client always
        gets a reply.
    """
    try:
        ScaleShiftMACE, _, _, _, o3 = get_mace_setup()
        from mace import modules

        batch = create_dummy_batch(r_max)
        # The model expects the interaction block class itself, not its name.
        interaction_cls = modules.interaction_classes["RealAgnosticInteractionBlock"]
        model_config = dict(
            r_max=r_max,
            num_bessel=8,
            num_polynomial_cutoff=5,
            max_ell=max_ell,
            interaction_cls=interaction_cls,
            interaction_cls_first=interaction_cls,
            num_interactions=2,
            num_elements=2,
            hidden_irreps=o3.Irreps("16x0e"),
            MLP_irreps=o3.Irreps("16x0e"),
            gate=torch.nn.functional.silu,
            atomic_energies=np.array([-13.6, -10.0]),
            avg_num_neighbors=2,
            atomic_numbers=[1, 8],
            correlation=3,
            atomic_inter_scale=1.0,
            atomic_inter_shift=0.0,
        )
        model = ScaleShiftMACE(**model_config)
    except Exception as e:  # tool boundary: report instead of crashing the server
        return f"❌ Error: {e}"
    # Commit state only after everything succeeded, so a failed init never
    # leaves a batch that disagrees with the current model.
    STATE["model"] = model
    STATE["batch"] = batch
    STATE["config"] = model_config
    STATE["logs"] = []
    return f"✅ MACE Model Initialized! (L_max={max_ell}, R_max={r_max}Å)"
def train_model(epochs: int = 10, learning_rate: float = 0.01) -> str:
    """Train the current model on the stored batch with Adam.

    The loss is the mean squared energy error plus 10x the mean squared force
    error. One log line per epoch is appended to ``STATE["logs"]``.

    Args:
        epochs: Number of optimisation steps; must be positive.
        learning_rate: Adam learning rate.

    Returns:
        A status message with the loss of the last epoch, or a warning if no
        model exists or ``epochs`` is not positive.
    """
    if STATE["model"] is None:
        return "⚠️ Run 'init_mace_model' first!"
    if epochs <= 0:
        # Without at least one epoch there is no loss to report.
        return "⚠️ 'epochs' must be a positive integer."
    model = STATE["model"]
    batch = STATE["batch"]
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        # training=True makes MACE keep the autograd graph of its force
        # computation, so backward() can propagate through the force term.
        out = model(batch.to_dict(), training=True)
        energy_loss = torch.mean((out["energy"] - batch.energy) ** 2)
        force_loss = torch.mean((out["forces"] - batch.forces) ** 2)
        loss = energy_loss + 10.0 * force_loss
        loss.backward()
        optimizer.step()
        STATE["logs"].append(f"Epoch {epoch}: Loss={loss.item():.4f}")
        # Presumably a pause so the dashboard's log polling can show progress.
        time.sleep(0.05)
    return f"🏁 Training Done! Final Loss: {loss.item():.4f}"
def check_equivariance(rotation_degrees: float = 90.0) -> str:
    """Check that predicted forces rotate together with the molecule.

    Rotates the stored batch's positions about the z axis and compares the
    forces predicted for the rotated geometry with the original forces
    rotated by the same matrix.

    Args:
        rotation_degrees: Rotation angle about the z axis, in degrees.

    Returns:
        A message with the mean absolute force discrepancy. The check passes
        when it is below 1e-4. A warning is returned if no model exists.
    """
    if STATE["model"] is None:
        return "⚠️ No model found."
    model = STATE["model"]
    batch = STATE["batch"]
    model.eval()
    f_orig = model(batch.to_dict())["forces"]

    angle = np.radians(rotation_degrees)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    # Build R in the positions' dtype/device; a hard-coded float32 matrix
    # breaks matmul whenever the batch is not float32.
    R = torch.tensor(
        [[cos_a, -sin_a, 0.0], [sin_a, cos_a, 0.0], [0.0, 0.0, 1.0]],
        dtype=batch.positions.dtype,
        device=batch.positions.device,
    )
    batch_rot = batch.clone()
    batch_rot.positions = batch.positions @ R.T
    f_rot = model(batch_rot.to_dict())["forces"]

    # Equivariance: F(R x) should equal R F(x) (row vectors, hence @ R.T).
    error = torch.mean(torch.abs(f_rot - f_orig @ R.T)).item()
    return f"🧪 Equivariance Error: {error:.2e} (Pass: {error < 1e-4})"
# --- 2. GRADIO DASHBOARD ---
def get_latest_logs():
    """Return all training log lines joined by newlines, or a placeholder when empty."""
    lines = STATE["logs"]
    return "\n".join(lines) if lines else "Waiting for training to start..."
with gr.Blocks(title="Equivariant Chem Scout") as demo:
    gr.Markdown("# 🧪 Equivariant Chem Scout")
    gr.Markdown("### An MCP Server for Computational Chemistry Education")
    with gr.Tabs():
        with gr.Tab("📖 How to Use"):
            gr.Markdown(
                """
                ### 👋 Welcome!

                This Space hosts an **MCP Server** that lets AI Agents (like Claude) run real machine-learning chemistry code: building, training and testing a MACE interatomic potential.

                #### Step 1: Connect Your Agent

                Add this configuration to your **Claude Desktop** config file:

                ```
                {
                  "mcpServers": {
                    "chem_scout": {
                      "url": "https://F555-Equivariant-Chem-Scout.hf.space/gradio_api/mcp/sse"
                    }
                  }
                }
                ```

                *(Restart Claude after saving)*

                #### Step 2: Chat with the Agent

                Once connected, try these prompts:

                - **"Initialize a MACE model with `max_ell=2` (equivariant) and `r_max=4.5`."**
                - **"Train the model for 20 epochs and show me the loss."**
                - **"Check if the model is equivariant by rotating the molecule 90 degrees."**
                """
            )
            with gr.Accordion("🛠️ What is MACE?", open=False):
                gr.Markdown(
                    """
                    **MACE (Multi-Atomic Cluster Expansion)** is a state-of-the-art machine learning potential for chemistry.

                    - **Equivariance (`max_ell`):** Ensures that if you rotate a molecule, the forces rotate with it.
                    - **Cutoff (`r_max`):** How far each atom "sees" its neighbors.
                    """
                )
        with gr.Tab("📊 Live Logs"):
            gr.Markdown("### 🖥️ Real-time Training Output")
            log_display = gr.TextArea(label="Server Logs", interactive=False, lines=20, max_lines=30)
            refresh_btn = gr.Button("🔄 Refresh Logs Now")
            refresh_btn.click(fn=get_latest_logs, outputs=log_display)
            # Poll the log buffer every 2 s so training progress shows up
            # without clicking the refresh button.
            timer = gr.Timer(value=2.0)
            timer.tick(fn=get_latest_logs, outputs=log_display)

    # The tool functions are not attached to any component event above, so
    # Gradio would not publish them as API endpoints, and therefore not as
    # MCP tools. gr.api registers each one; Gradio takes the parameter types
    # from the type hints and the tool description from the docstring.
    gr.api(init_mace_model)
    gr.api(train_model)
    gr.api(check_equivariance)
# --- 3. LAUNCH ---
if __name__ == "__main__":
    # `mcp_server` is a boolean switch that turns on Gradio's built-in MCP
    # server (served under /gradio_api/mcp/). Passing the FastMCP object here
    # only worked because it is truthy.
    demo.launch(mcp_server=True)