File size: 6,272 Bytes
1ab83c8
 
c7d6ecc
1ab83c8
4af55fb
1ab83c8
38536d9
1ab83c8
4af55fb
 
1ab83c8
7c59449
c7d6ecc
 
 
4af55fb
 
c7d6ecc
1ab83c8
4eb56e0
c7d6ecc
 
 
 
 
 
 
 
7c59449
1ab83c8
c7d6ecc
 
 
4eb56e0
 
c7d6ecc
 
 
 
 
4eb56e0
c7d6ecc
 
4eb56e0
38536d9
3a15d3b
c7d6ecc
3a15d3b
38536d9
c7d6ecc
1ab83c8
38536d9
1ab83c8
4af55fb
46ef716
c7d6ecc
 
 
 
38536d9
89a8302
 
4af55fb
89a8302
38536d9
c7d6ecc
 
 
7c59449
70e1d1b
c7d6ecc
4eb56e0
1ab83c8
 
4af55fb
46ef716
7c59449
c7d6ecc
 
4af55fb
38536d9
c7d6ecc
1ab83c8
 
c7d6ecc
89a8302
 
1ab83c8
 
7c59449
 
46ef716
89a8302
70e1d1b
4af55fb
 
 
46ef716
4af55fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c59449
4af55fb
70e1d1b
4af55fb
237ecb2
4af55fb
70e1d1b
4af55fb
 
46ef716
4af55fb
70e1d1b
4af55fb
70e1d1b
 
 
 
 
4af55fb
70e1d1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4af55fb
70e1d1b
 
 
 
 
 
 
 
 
 
 
 
 
 
7c59449
 
4af55fb
70e1d1b
1ab83c8
70e1d1b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import torch
import numpy as np
import time
from fastmcp import FastMCP
import gradio as gr
from ase import Atoms
from ase.build import molecule

# --- 1. DEFINE MCP SERVER ---
# The FastMCP server object; the functions decorated with @mcp.tool() below
# register themselves on it.
mcp = FastMCP("Equivariant_Chem_Scout")

# Global, process-wide state shared by the MCP tools and the dashboard.
STATE = dict(
    model=None,   # model object set by init_mace_model on success
    config=None,  # NOTE(review): not written by the original tools — confirm whether still needed
    batch=None,   # single-molecule batch produced by create_dummy_batch
    logs=[],      # per-epoch log lines appended by train_model, shown in the UI
)

# --- HELPER FUNCTIONS ---
def get_mace_setup():
    """Import and return the MACE / e3nn symbols this app needs.

    The imports are deferred to call time, so importing this module does not
    itself require MACE to be installed.

    Returns:
        tuple: ``(ScaleShiftMACE, AtomicData, Configuration, torch_geometric, o3)``.

    Raises:
        ImportError: if ``mace`` or ``e3nn`` cannot be imported. The original
            import error is chained as ``__cause__`` so the actual missing
            module shows up in tracebacks.
    """
    try:
        # ScaleShiftMACE is defined in ``mace.modules``; importing it from
        # ``mace.models`` fails even when MACE is installed.
        from mace.modules import ScaleShiftMACE
        from mace.data import AtomicData, Configuration
        from mace.tools import torch_geometric
        from e3nn import o3
    except ImportError as err:
        raise ImportError(f"MACE not installed (or incomplete): {err}") from err
    return ScaleShiftMACE, AtomicData, Configuration, torch_geometric, o3

def create_dummy_batch(r_max=5.0, seed=None):
    """Build a one-graph MACE batch for a water molecule with synthetic labels.

    The energy label is a fixed -14.0. The force labels are random (standard
    normal scaled by 0.1), so they are placeholders, not reference data.

    Args:
        r_max: neighbour cutoff used when building the graph, in the same
            length unit as the ASE positions (Å).
        seed: optional seed for the synthetic force labels. ``None`` (the
            default) draws fresh random labels on every call.

    Returns:
        The single batch yielded by MACE's ``torch_geometric.DataLoader``.

    Raises:
        ImportError: if MACE cannot be imported (raised by ``get_mace_setup``).
    """
    _, AtomicData, Configuration, torch_geometric, _ = get_mace_setup()
    # AtomicData.from_config looks elements up via the table's ``z_to_index``
    # method, which a plain dict does not provide.
    from mace.tools import AtomicNumberTable

    mol = molecule("H2O")
    rng = np.random.default_rng(seed)
    energy = -14.0
    # One random force vector per atom, sized from the molecule itself.
    forces = rng.standard_normal((len(mol), 3)) * 0.1
    config = Configuration(
        atomic_numbers=mol.get_atomic_numbers(),
        positions=mol.get_positions(),
        energy=energy,
        forces=forces,
        pbc=np.array([False, False, False]),
        cell=np.eye(3) * 10.0,
    )
    # Element order (H -> 0, O -> 1) must stay in sync with
    # ``atomic_numbers=[1, 8]`` used when the model is built.
    z_table = AtomicNumberTable([1, 8])
    data_loader = torch_geometric.DataLoader(
        dataset=[AtomicData.from_config(config, z_table=z_table, cutoff=r_max)],
        batch_size=1,
        shuffle=False,
    )
    return next(iter(data_loader))

# --- MCP TOOLS ---
@mcp.tool()
def init_mace_model(r_max: float = 5.0, max_ell: int = 2) -> str:
    """Initialize a small MACE model together with a water-molecule batch.

    Args:
        r_max: radial cutoff (Å), used both for building the graph and by the model.
        max_ell: maximum spherical-harmonic degree passed to the model.

    Returns:
        A status string. Failures are reported in the string (prefixed with
        "❌ Error:") rather than raised, so the calling agent always gets a reply.
    """
    try:
        ScaleShiftMACE, _, _, _, o3 = get_mace_setup()
        from mace.modules import interaction_classes

        batch = create_dummy_batch(r_max)
        # MACE expects the interaction block *class*, not its name.
        interaction_cls = interaction_classes["RealAgnosticInteractionBlock"]
        model_config = dict(
            r_max=r_max,
            num_bessel=8,
            num_polynomial_cutoff=5,
            max_ell=max_ell,
            interaction_cls=interaction_cls,
            interaction_cls_first=interaction_cls,
            num_interactions=2,
            num_elements=2,
            hidden_irreps=o3.Irreps("16x0e"),
            MLP_irreps=o3.Irreps("16x0e"),
            atomic_energies=np.array([-13.6, -10.0]),
            avg_num_neighbors=2,
            atomic_numbers=[1, 8],
            correlation=3,
            gate=torch.nn.functional.silu,
            # ScaleShiftMACE-specific: identity scale/shift of the interaction energy.
            atomic_inter_scale=1.0,
            atomic_inter_shift=0.0,
        )
        model = ScaleShiftMACE(**model_config)
    except Exception as e:
        return f"❌ Error: {str(e)}"

    # Commit to the shared state only after everything succeeded, so a failed
    # init never leaves a new batch paired with a stale model.
    STATE["batch"] = batch
    STATE["model"] = model
    STATE["config"] = model_config
    STATE["logs"] = []
    return f"✅ MACE Model Initialized! (L_max={max_ell}, R_max={r_max}Å)"

@mcp.tool()
def train_model(epochs: int = 10, learning_rate: float = 0.01) -> str:
    """Train the current model on the stored batch with Adam.

    The loss is energy MSE plus 10x force MSE. Each epoch appends one line to
    STATE["logs"], which the dashboard displays.

    Args:
        epochs: number of optimisation steps; must be at least 1.
        learning_rate: Adam learning rate.

    Returns:
        A status string with the final loss, or a warning if there is no model
        or ``epochs`` is not positive.
    """
    if STATE["model"] is None:
        return "⚠️ Run 'init_mace_model' first!"
    if epochs < 1:
        # Without at least one epoch there is no loss to report.
        return "⚠️ 'epochs' must be at least 1."

    model = STATE["model"]
    batch = STATE["batch"]
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.train()

    for epoch in range(epochs):
        optimizer.zero_grad()
        # training=True asks MACE to keep the autograd graph when it derives
        # forces from the energy, so the force term can be backpropagated.
        # NOTE(review): verify against the installed MACE version's forward().
        out = model(batch.to_dict(), training=True)
        energy_loss = torch.mean((out["energy"] - batch.energy) ** 2)
        force_loss = torch.mean((out["forces"] - batch.forces) ** 2)
        loss = energy_loss + 10.0 * force_loss
        loss.backward()
        optimizer.step()

        STATE["logs"].append(f"Epoch {epoch}: Loss={loss.item():.4f}")
        # Presumably here so progress is visible in the dashboard's periodic refresh.
        time.sleep(0.05)

    return f"🚀 Training Done! Final Loss: {loss.item():.4f}"

@mcp.tool()
def check_equivariance(rotation_degrees: float = 90.0) -> str:
    """Rotate the stored molecule about the z-axis and compare predicted forces.

    For a rotation-equivariant model, F(R·x) == R·F(x). This reports the mean
    absolute difference between the two sides.

    Args:
        rotation_degrees: rotation angle about the z-axis, in degrees.

    Returns:
        A status string with the error and whether it is below 1e-4.
    """
    if STATE["model"] is None:
        return "⚠️ No model found."
    model = STATE["model"]
    batch = STATE["batch"]
    model.eval()

    positions = batch.positions.detach()
    f_orig = model(batch.to_dict())["forces"].detach()

    angle = np.radians(rotation_degrees)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    # Build R in the positions' dtype/device: a float32 R against float64
    # (or GPU) positions would make the matmul fail.
    R = torch.tensor(
        [[cos_a, -sin_a, 0.0], [sin_a, cos_a, 0.0], [0.0, 0.0, 1.0]],
        dtype=positions.dtype,
        device=positions.device,
    )

    batch_rot = batch.clone()
    # Rotate detached coordinates so the rotated input starts as a fresh tensor.
    batch_rot.positions = positions @ R.T
    f_rot = model(batch_rot.to_dict())["forces"].detach()

    error = torch.mean(torch.abs(f_rot - f_orig @ R.T)).item()
    return f"🧪 Equivariance Error: {error:.2e} (Pass: {error < 1e-4})"

# --- 2. GRADIO DASHBOARD ---
def get_latest_logs():
    """Return every training log line joined by newlines, or a placeholder if none exist."""
    lines = STATE["logs"]
    return "\n".join(lines) if lines else "Waiting for training to start..."

with gr.Blocks(title="Equivariant Chem Scout") as demo:
    # Static usage instructions plus a live view of STATE["logs"], which is
    # refreshed on button click and by a 2-second timer.
    gr.Markdown("# 🧪 Equivariant Chem Scout")
    gr.Markdown("### An MCP Server for Computational Chemistry Education")

    with gr.Tabs():
        with gr.Tab("📖 How to Use"):
            gr.Markdown("""
            ### 👋 Welcome!
            This Space hosts an **MCP Server** that lets AI Agents (like Claude) build, train, and test a real machine-learning force field (MACE).

            #### Step 1: Connect Your Agent
            Add this configuration to your **Claude Desktop** config file (Claude Desktop reaches remote servers through `mcp-remote`, which requires Node.js):
            ```
            {
              "mcpServers": {
                "chem_scout": {
                  "command": "npx",
                  "args": ["mcp-remote", "https://F555-Equivariant-Chem-Scout.hf.space/gradio_api/mcp/sse"]
                }
              }
            }
            ```
            Clients that support SSE URLs directly can instead use `"url": "https://F555-Equivariant-Chem-Scout.hf.space/gradio_api/mcp/sse"`.
            *(Restart Claude after saving)*

            #### Step 2: Chat with the Agent
            Once connected, try these prompts:
            - **"Initialize a MACE model with `max_ell=2` (equivariant) and `r_max=4.5`."**
            - **"Train the model for 20 epochs and show me the loss."**
            - **"Check if the model is equivariant by rotating the molecule 90 degrees."**
            """)

            with gr.Accordion("🛠️ What is MACE?", open=False):
                gr.Markdown("""
                **MACE (Multi-Atomic Cluster Expansion)** is a state-of-the-art machine learning potential for chemistry.
                - **Equivariance:** MACE is built so that if you rotate a molecule, the predicted forces rotate with it. `max_ell` sets the highest angular order (spherical-harmonic degree) the model uses to describe directions between atoms.
                - **Cutoff (`r_max`):** How far each atom "sees" its neighbors.
                """)

        with gr.Tab("📊 Live Logs"):
            gr.Markdown("### 🖥️ Real-time Training Output")

            log_display = gr.TextArea(label="Server Logs", interactive=False, lines=20, max_lines=30)
            refresh_btn = gr.Button("🔄 Refresh Logs Now")

            refresh_btn.click(fn=get_latest_logs, outputs=log_display)
            # Poll the shared log buffer every 2 seconds while the page is open.
            timer = gr.Timer(value=2.0)
            timer.tick(fn=get_latest_logs, outputs=log_display)

# --- 3. LAUNCH ---
if __name__ == "__main__":
    # Gradio's ``mcp_server`` launch option is a boolean switch; passing the
    # FastMCP instance only worked because the object is truthy.
    # NOTE(review): Gradio's MCP endpoint exposes the Blocks' own API
    # functions. The @mcp.tool() functions registered on ``mcp`` are not wired
    # into ``demo`` here — confirm agents can actually reach them.
    demo.launch(mcp_server=True)