| | |
| | from __future__ import annotations |
| |
|
| | import csv, re, json |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Dict, Optional, Tuple, Any, List |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import joblib |
| | import xgboost as xgb |
| |
|
| | from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM |
| | from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
| |
|
| |
|
| | |
| | |
| | |
@dataclass(frozen=True)
class BestRow:
    """One parsed row of the best-model manifest CSV for a single property.

    `None` fields correspond to blank / dash / NA placeholder cells.
    """
    property_key: str            # normalized property name (see normalize_property_key)
    best_wt: Optional[str]       # best model label for AA-sequence ("wt") input, or None
    best_smiles: Optional[str]   # best model label for SMILES input, or None
    task_type: str               # e.g. "Classifier" or "Regression", as written in the CSV
    thr_wt: Optional[float]      # decision threshold for the wt classifier, if any
    thr_smiles: Optional[float]  # decision threshold for the smiles classifier, if any
| |
|
| |
|
def _clean(s: str) -> str:
    """Strip surrounding whitespace; None/empty input becomes ""."""
    if not s:
        return ""
    return s.strip()
| |
|
def _none_if_dash(s: str) -> Optional[str]:
    """Return the cleaned cell value, or None for blank/dash/NA placeholders."""
    cleaned = _clean(s)
    return None if cleaned in {"", "-", "—", "NA", "N/A"} else cleaned
| |
|
def _float_or_none(s: str) -> Optional[float]:
    """Parse the cell as a float; blank/dash/NA placeholders become None.

    Raises ValueError if the cell is non-empty but not a valid float.
    """
    cleaned = _clean(s)
    if cleaned in {"", "-", "—", "NA", "N/A"}:
        return None
    return float(cleaned)
| |
|
def normalize_property_key(name: str) -> str:
    """Canonicalize a manifest property name into a snake_case key.

    Parenthesized qualifiers are dropped, dashes/spaces become underscores,
    and a few known aliases are remapped.
    """
    key = re.sub(r"\s*\(.*?\)\s*", "", name.strip().lower())
    key = key.replace("-", "_").replace(" ", "_")
    # Generic permeability (not PAMPA/Caco-specific) gets one shared key.
    if "permeability" in key and "pampa" not in key and "caco" not in key:
        return "permeability_penetrance"
    aliases = {
        "binding_affinity": "binding_affinity",
        "halflife": "half_life",
        "non_fouling": "nf",
    }
    return aliases.get(key, key)
| |
|
def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
    """
    Parse the best-model manifest CSV into {normalized property key: BestRow}.

    Expected columns (first data row shown as an example):
      Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
      Hemolysis, SVM, SGB, Classifier, 0.2801, 0.2223,

    Blank rows are skipped; the first non-blank row is treated as the header.
    Later rows with the same property key overwrite earlier ones.
    """
    p = Path(path)
    out: Dict[str, BestRow] = {}

    with p.open("r", newline="") as f:
        reader = csv.reader(f)
        header = None
        for raw in reader:
            # Skip fully blank rows.
            if not raw or all(_clean(x) == "" for x in raw):
                continue
            # Drop trailing empty cells (rows often end with a dangling comma).
            while raw and _clean(raw[-1]) == "":
                raw = raw[:-1]

            if header is None:
                header = [h.strip() for h in raw]
                continue

            # Pad short rows so zip() covers every header column.
            if len(raw) < len(header):
                raw = raw + [""] * (len(header) - len(raw))
            rec = dict(zip(header, raw))

            prop_raw = _clean(rec.get("Properties", ""))
            if not prop_raw:
                continue
            prop_key = normalize_property_key(prop_raw)

            row = BestRow(
                property_key=prop_key,
                best_wt=_none_if_dash(rec.get("Best_Model_WT", "")),
                best_smiles=_none_if_dash(rec.get("Best_Model_SMILES", "")),
                task_type=_clean(rec.get("Type", "Classifier")),
                thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
                thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
            )
            out[prop_key] = row

    return out
| |
|
| |
|
# Manifest model label (upper-cased) -> internal model/folder name.
MODEL_ALIAS = {
    "SVM": "svm_gpu",
    "SVR": "svr",
    "ENET": "enet_gpu",
    "CNN": "cnn",
    "MLP": "mlp",
    "TRANSFORMER": "transformer",
    "XGB": "xgb",
    "XGB_REG": "xgb_reg",
    "POOLED": "pooled",
    "UNPOOLED": "unpooled"
}


def canon_model(label: Optional[str]) -> Optional[str]:
    """Map a manifest model label to its canonical internal name.

    Unknown labels fall back to their lower-cased form; None stays None.
    """
    if label is None:
        return None
    stripped = label.strip()
    return MODEL_ALIAS.get(stripped.upper(), stripped.lower())
| |
|
| |
|
| | |
| | |
| | |
def find_best_artifact(model_dir: Path) -> Path:
    """Return the preferred best_model artifact in model_dir.

    Preference order: best_model.json, best_model.pt, then any
    best_model*.joblib (first in sorted order).

    Raises FileNotFoundError when no artifact matches.
    """
    patterns = ("best_model.json", "best_model.pt", "best_model*.joblib")
    for pattern in patterns:
        matches = sorted(model_dir.glob(pattern))
        if matches:
            return matches[0]
    raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
| |
|
def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
    """Load the best artifact from model_dir.

    Returns (kind, object, path) where kind is one of:
      - "xgb":        xgboost.Booster loaded from best_model.json
      - "joblib":     pickled estimator loaded from a .joblib file
      - "torch_ckpt": raw checkpoint dict loaded from best_model.pt

    Raises:
        FileNotFoundError: no best_model artifact present.
        ValueError: artifact has an unrecognized extension.
    """
    art = find_best_artifact(model_dir)

    if art.suffix == ".json":
        booster = xgb.Booster()
        # (removed leftover debug print of the artifact path)
        booster.load_model(str(art))
        return "xgb", booster, art

    if art.suffix == ".joblib":
        obj = joblib.load(art)
        return "joblib", obj, art

    if art.suffix == ".pt":
        # weights_only=False: checkpoints carry params dicts, not just tensors.
        ckpt = torch.load(art, map_location=device, weights_only=False)
        return "torch_ckpt", ckpt, art

    raise ValueError(f"Unknown artifact type: {art}")
| |
|
| |
|
| | |
| | |
| | |
class MaskedMeanPool(nn.Module):
    """Mean over the sequence dim, counting only positions where M is true.

    The denominator is clamped to 1 so an all-false mask yields zeros
    instead of NaNs.
    """

    def forward(self, X, M):
        weights = M.unsqueeze(-1).float()
        total = (X * weights).sum(dim=1)
        count = weights.sum(dim=1).clamp(min=1.0)
        return total / count
| |
|
class MLPHead(nn.Module):
    """Masked mean-pool over token embeddings, then a two-layer MLP scorer.

    Produces one scalar per batch item (last dim squeezed away).
    """

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        stages = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, X, M):
        pooled = self.pool(X, M)
        scores = self.net(pooled)
        return scores.squeeze(-1)
| |
|
class CNNHead(nn.Module):
    """Stacked 1-D convolutions over tokens, masked mean-pool, linear scorer.

    Same-padding (k//2) keeps the sequence length unchanged through the stack.
    """

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        stages = []
        prev = in_ch
        for _ in range(layers):
            stages.append(nn.Conv1d(prev, c, kernel_size=k, padding=k // 2))
            stages.append(nn.GELU())
            stages.append(nn.Dropout(dropout))
            prev = c
        self.conv = nn.Sequential(*stages)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        # Conv1d expects (B, C, L); transpose in and back out.
        feats = self.conv(X.transpose(1, 2)).transpose(1, 2)
        mask = M.unsqueeze(-1).float()
        pooled = (feats * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
| |
|
class TransformerHead(nn.Module):
    """Linear projection, TransformerEncoder, masked mean-pool, linear scorer."""

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        # src_key_padding_mask marks positions to IGNORE, hence the inversion.
        encoded = self.enc(self.proj(X), src_key_padding_mask=~M)
        mask = M.unsqueeze(-1).float()
        pooled = (encoded * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
| |
|
def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
    """Recover the embedding width from a checkpoint's first-layer weight shape.

    Raises ValueError for an unrecognized model_name.
    """
    key_by_model = {
        "mlp": "net.0.weight",
        "cnn": "conv.0.weight",
        "transformer": "proj.weight",
    }
    if model_name not in key_by_model:
        raise ValueError(model_name)
    return int(sd[key_by_model[model_name]].shape[1])
| |
|
def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
    """Rebuild an NN head from a checkpoint dict and load its weights (eval mode).

    Args:
        model_name: one of "mlp", "cnn", "transformer".
        ckpt: dict holding "best_params", "state_dict", and optionally "in_dim".
        device: target device for the rebuilt model.

    Raises:
        ValueError: unrecognized model_name.
    """
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]
    # Fall back to inferring the input width from the first weight matrix.
    in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name)))
    dropout = float(params.get("dropout", 0.1))

    if model_name == "mlp":
        net = MLPHead(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        net = CNNHead(
            in_ch=in_dim,
            c=int(params["channels"]),
            k=int(params["kernel"]),
            layers=int(params["layers"]),
            dropout=dropout,
        )
    elif model_name == "transformer":
        net = TransformerHead(
            in_dim=in_dim,
            d_model=int(params["d_model"]),
            nhead=int(params["nhead"]),
            layers=int(params["layers"]),
            ff=int(params["ff"]),
            dropout=dropout,
        )
    else:
        raise ValueError(f"Unknown NN model_name={model_name}")

    net.load_state_dict(sd)
    net.to(device)
    net.eval()
    return net
| |
|
| |
|
| | |
| | |
| | |
def affinity_to_class(y: float) -> int:
    """Bucket an affinity score: 0 = high (>= 9), 1 = moderate [7, 9), 2 = low (< 7)."""
    if y >= 9.0:
        return 0
    return 2 if y < 7.0 else 1
| |
|
class CrossAttnPooled(nn.Module):
    """
    Bidirectional cross-attention over *pooled* (one vector each) target and
    binder embeddings, followed by a shared trunk with two heads:
    a scalar regression output and 3-way classification logits.
    """
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        # Project target (Ht) and binder (Hb) embeddings to a common width.
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                # attn_tb: target attends to binder; attn_bt: binder attends to target.
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)   # scalar affinity regression head
        self.cls = nn.Linear(hidden, 3)   # 3-class affinity bucket logits

    def forward(self, t_vec, b_vec):
        # Treat each pooled vector as a length-1 sequence (seq, batch, hidden)
        # to match batch_first=False attention.
        t = self.t_proj(t_vec).unsqueeze(0)
        b = self.b_proj(b_vec).unsqueeze(0)
        for L in self.layers:
            t_attn, _ = L["attn_tb"](t, b, b)
            # Post-norm residual blocks; the transpose pair swaps seq/batch
            # around the LayerNorm (which acts on the last dim) and back.
            t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
            t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)

            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
            b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)

        # Drop the length-1 sequence dim and fuse both streams.
        z = torch.cat([t[0], b[0]], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)
| |
|
class CrossAttnUnpooled(nn.Module):
    """
    Bidirectional cross-attention over *unpooled* token embeddings of target
    and binder sequences, masked-mean pooled at the end, with a scalar
    regression head and 3-way classification logits.
    """
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        # Project target (Ht) and binder (Hb) token embeddings to a common width.
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))

        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                # attn_tb: target attends to binder; attn_bt: binder attends to target.
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))

        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)   # scalar affinity regression head
        self.cls = nn.Linear(hidden, 3)   # 3-class affinity bucket logits

    def _masked_mean(self, X, M):
        # Mean over dim 1 counting only True positions; clamp avoids 0-division.
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom

    def forward(self, T, Mt, B, Mb):
        # T/B: token embeddings with Mt/Mb boolean valid-masks over dim 1
        # (assumed (batch, L, H) / (batch, L) — TODO confirm against callers).
        T = self.t_proj(T)
        Bx = self.b_proj(B)
        # key_padding_mask marks positions to IGNORE, hence the inversion.
        kp_t = ~Mt
        kp_b = ~Mb

        for L in self.layers:
            T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = L["n1t"](T + T_attn)
            T = L["n2t"](T + L["fft"](T))

            B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = L["n1b"](Bx + B_attn)
            Bx = L["n2b"](Bx + L["ffb"](Bx))

        # Masked-mean pool each stream, then fuse.
        t_pool = self._masked_mean(T, Mt)
        b_pool = self._masked_mean(Bx, Mb)
        z = torch.cat([t_pool, b_pool], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)
| |
|
def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
    """
    Rebuild a binding-affinity cross-attention model from a .pt checkpoint.

    Args:
        best_model_pt: path to the checkpoint (holds "best_params" and "state_dict").
        pooled_or_unpooled: "pooled" -> CrossAttnPooled, "unpooled" -> CrossAttnUnpooled.
        device: target device; the model is returned in eval mode.

    Raises:
        ValueError: pooled_or_unpooled is neither "pooled" nor "unpooled".
    """
    ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]

    # Input widths are recovered from the projection weights in the state dict.
    Ht = int(sd["t_proj.0.weight"].shape[1])
    Hb = int(sd["b_proj.0.weight"].shape[1])

    common = dict(
        Ht=Ht, Hb=Hb,
        hidden=int(params["hidden_dim"]),
        n_heads=int(params["n_heads"]),
        n_layers=int(params["n_layers"]),
        dropout=float(params["dropout"]),
    )

    if pooled_or_unpooled == "pooled":
        model = CrossAttnPooled(**common)
    elif pooled_or_unpooled == "unpooled":
        model = CrossAttnUnpooled(**common)
    else:
        raise ValueError(pooled_or_unpooled)

    model.load_state_dict(sd)
    model.to(device).eval()
    return model
| |
|
| |
|
| | |
| | |
| | |
def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
    """Elementwise membership test of `ids` in `test_ids`.

    Uses torch.isin when available; otherwise falls back to a broadcast
    comparison (the fallback reshape assumes `ids` is 2-D, e.g. (B, L)).
    """
    if hasattr(torch, "isin"):
        return torch.isin(ids, test_ids)
    # Compare every id against every test id, then reduce over the last axis.
    hits = ids.unsqueeze(-1) == test_ids.view(1, 1, -1)
    return hits.any(dim=-1)
| | |
class SMILESEmbedder:
    """
    PeptideCLM RoFormer embeddings for SMILES.
    - pooled(): mean over tokens where attention_mask==1 AND token_id not in SPECIAL_IDS
    - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
      plus a 1-mask of length Li (since already filtered).
    Results are memoized per SMILES string when use_cache is True.
    """
    def __init__(
        self,
        device: torch.device,
        vocab_path: str,
        splits_path: str,
        clm_name: str = "aaronfeller/PeptideCLM-23M-all",
        max_len: int = 512,
        use_cache: bool = True,
    ):
        self.device = device
        self.max_len = max_len      # truncation length for tokenization
        self.use_cache = use_cache  # memoize embeddings per input string

        self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
        # Keep only the RoFormer encoder; the MLM head is discarded.
        self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()

        # Special-token ids (pad/cls/sep/bos/eos/mask) to exclude from pooling.
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
                              if len(self.special_ids) else None)

        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        """Collect whatever special-token ids the tokenizer defines (deduped, sorted)."""
        cand = [
            getattr(tokenizer, "pad_token_id", None),
            getattr(tokenizer, "cls_token_id", None),
            getattr(tokenizer, "sep_token_id", None),
            getattr(tokenizer, "bos_token_id", None),
            getattr(tokenizer, "eos_token_id", None),
            getattr(tokenizer, "mask_token_id", None),
        ]
        return sorted({int(x) for x in cand if x is not None})

    def _tokenize(self, smiles_list: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize a batch of SMILES and move every tensor to the target device."""
        tok = self.tokenizer(
            smiles_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len,
        )
        for k in tok:
            tok[k] = tok[k].to(self.device)
        # Some tokenizers omit the attention mask; default to all-ones.
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    @torch.no_grad()
    def pooled(self, smiles: str) -> torch.Tensor:
        """Return a (1, H) mean-pooled embedding over non-special, attended tokens."""
        s = smiles.strip()
        if self.use_cache and s in self._cache_pooled:
            return self._cache_pooled[s]

        tok = self._tokenize([s])
        ids = tok["input_ids"]
        attn = tok["attention_mask"].bool()

        out = self.model(input_ids=ids, attention_mask=tok["attention_mask"])
        h = out.last_hidden_state

        # Valid = attended AND not a special token.
        valid = attn
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))

        # Masked mean; tiny clamp guards against an all-invalid sequence.
        vf = valid.unsqueeze(-1).float()
        summed = (h * vf).sum(dim=1)
        denom = vf.sum(dim=1).clamp(min=1e-9)
        pooled = summed / denom

        if self.use_cache:
            self._cache_pooled[s] = pooled
        return pooled

    @torch.no_grad()
    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            X: (1, Li, H) float32 on device
            M: (1, Li) bool on device
        where Li excludes padding + special tokens.
        """
        s = smiles.strip()
        if self.use_cache and s in self._cache_unpooled:
            return self._cache_unpooled[s]

        tok = self._tokenize([s])
        ids = tok["input_ids"]
        attn = tok["attention_mask"].bool()

        out = self.model(input_ids=ids, attention_mask=tok["attention_mask"])
        h = out.last_hidden_state

        valid = attn
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))

        # Keep only valid token positions; the mask is then trivially all-True.
        keep = valid[0]
        X = h[:, keep, :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)

        if self.use_cache:
            self._cache_unpooled[s] = (X, M)
        return X, M
| |
|
| |
|
class WTEmbedder:
    """
    ESM2 embeddings for AA sequences.
    - pooled(): mean over tokens where attention_mask==1 AND token_id not in {CLS, EOS, PAD,...}
    - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
      plus a 1-mask of length Li (since already filtered).
    Results are memoized per sequence when use_cache is True.
    """
    def __init__(
        self,
        device: torch.device,
        esm_name: str = "facebook/esm2_t33_650M_UR50D",
        max_len: int = 1022,
        use_cache: bool = True,
    ):
        self.device = device
        self.max_len = max_len      # truncation length for tokenization
        self.use_cache = use_cache  # memoize embeddings per input string

        self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
        # Pooling layer is unused; token embeddings are pooled manually below.
        self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()

        # Special-token ids (pad/cls/sep/bos/eos/mask) to exclude from pooling.
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
                              if len(self.special_ids) else None)

        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        """Collect whatever special-token ids the tokenizer defines (deduped, sorted)."""
        cand = [
            getattr(tokenizer, "pad_token_id", None),
            getattr(tokenizer, "cls_token_id", None),
            getattr(tokenizer, "sep_token_id", None),
            getattr(tokenizer, "bos_token_id", None),
            getattr(tokenizer, "eos_token_id", None),
            getattr(tokenizer, "mask_token_id", None),
        ]
        return sorted({int(x) for x in cand if x is not None})

    def _tokenize(self, seq_list: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenize a batch of AA sequences and move every tensor to the target device."""
        tok = self.tokenizer(
            seq_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len,
        )
        tok = {k: v.to(self.device) for k, v in tok.items()}
        # Defensive default in case the attention mask is absent.
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    @torch.no_grad()
    def pooled(self, seq: str) -> torch.Tensor:
        """Return a (1, H) mean-pooled embedding over non-special, attended tokens."""
        s = seq.strip()
        if self.use_cache and s in self._cache_pooled:
            return self._cache_pooled[s]

        tok = self._tokenize([s])
        ids = tok["input_ids"]
        attn = tok["attention_mask"].bool()

        out = self.model(**tok)
        h = out.last_hidden_state

        # Valid = attended AND not a special token.
        valid = attn
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))

        # Masked mean; tiny clamp guards against an all-invalid sequence.
        vf = valid.unsqueeze(-1).float()
        summed = (h * vf).sum(dim=1)
        denom = vf.sum(dim=1).clamp(min=1e-9)
        pooled = summed / denom

        if self.use_cache:
            self._cache_pooled[s] = pooled
        return pooled

    @torch.no_grad()
    def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            X: (1, Li, H) float32 on device
            M: (1, Li) bool on device
        where Li excludes padding + special tokens.
        """
        s = seq.strip()
        if self.use_cache and s in self._cache_unpooled:
            return self._cache_unpooled[s]

        tok = self._tokenize([s])
        ids = tok["input_ids"]
        attn = tok["attention_mask"].bool()

        out = self.model(**tok)
        h = out.last_hidden_state

        valid = attn
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))

        # Keep only valid token positions; the mask is then trivially all-True.
        keep = valid[0]
        X = h[:, keep, :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)

        if self.use_cache:
            self._cache_unpooled[s] = (X, M)
        return X, M
| |
|
| |
|
| |
|
| | |
| | |
| | |
class PeptiVersePredictor:
    """
    - loads best models from training_classifiers/
    - computes embeddings as needed (pooled/unpooled)
    - supports: xgb, joblib(ENET/SVM/SVR), NN(mlp/cnn/transformer), binding pooled/unpooled.

    Models are keyed by (property_key, mode) where mode is "wt" (AA sequence
    input) or "smiles" (SMILES input); per-model metadata (task type,
    threshold, artifact path) lives in self.meta under the same key.
    """
    def __init__(
        self,
        manifest_path: str | Path,
        classifier_weight_root: str | Path,
        esm_name="facebook/esm2_t33_650M_UR50D",
        clm_name="aaronfeller/PeptideCLM-23M-all",
        smiles_vocab="tokenizer/new_vocab.txt",
        smiles_splits="tokenizer/new_splits.txt",
        device: Optional[str] = None,
    ):
        self.root = Path(classifier_weight_root)
        self.training_root = self.root / "training_classifiers"
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        self.manifest = read_best_manifest_csv(manifest_path)

        # NOTE(review): esm_name is accepted but not forwarded to WTEmbedder,
        # which falls back to its own default — confirm whether intended.
        self.wt_embedder = WTEmbedder(self.device)
        self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name,
                                              vocab_path=str(self.root / smiles_vocab),
                                              splits_path=str(self.root / smiles_splits))

        # (property_key, mode) -> loaded model / metadata dict.
        self.models: Dict[Tuple[str, str], Any] = {}
        self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}

        self._load_all_best_models()

    def _resolve_dir(self, prop_key: str, model_name: str, mode: str) -> Path:
        """
        Usual layout: training_classifiers/<prop>/<model>_<mode>/
        Fallbacks:
          - training_classifiers/<prop>/<model>/
          - training_classifiers/<prop>/<model>_wt
        Raises FileNotFoundError when none of the candidate directories exist.
        """
        base = self.training_root / prop_key
        candidates = [
            base / f"{model_name}_{mode}",
            base / model_name,
        ]
        if mode == "wt":
            candidates += [base / f"{model_name}_wt"]
        if mode == "smiles":
            candidates += [base / f"{model_name}_smiles"]

        for d in candidates:
            if d.exists():
                return d
        raise FileNotFoundError(f"Cannot find model directory for {prop_key} {model_name} {mode}. Tried: {candidates}")

    def _load_all_best_models(self):
        """Load the best wt/smiles model of every manifest property into self.models."""
        for prop_key, row in self.manifest.items():
            for mode, label, thr in [
                ("wt", row.best_wt, row.thr_wt),
                ("smiles", row.best_smiles, row.thr_smiles),
            ]:
                m = canon_model(label)
                if m is None:
                    # No best model listed for this mode — skip it.
                    continue

                # Binding affinity uses its own folder scheme and .pt-only
                # cross-attention checkpoints (wt_<mode>_<pooled|unpooled>).
                if prop_key == "binding_affinity":

                    pooled_or_unpooled = m
                    folder = f"wt_{mode}_{pooled_or_unpooled}"
                    model_dir = self.training_root / "binding_affinity" / folder
                    art = find_best_artifact(model_dir)
                    if art.suffix != ".pt":
                        raise RuntimeError(f"Binding model expected best_model.pt, got {art}")
                    model = load_binding_model(art, pooled_or_unpooled=pooled_or_unpooled, device=self.device)
                    self.models[(prop_key, mode)] = model
                    self.meta[(prop_key, mode)] = {
                        "task_type": "Regression",
                        "threshold": None,
                        "artifact": str(art),
                        "model_name": pooled_or_unpooled,
                    }
                    continue

                model_dir = self._resolve_dir(prop_key, m, mode)
                kind, obj, art = load_artifact(model_dir, self.device)

                if kind in {"xgb", "joblib"}:
                    # Ready-to-use estimator objects.
                    self.models[(prop_key, mode)] = obj
                else:
                    # Raw torch checkpoint: rebuild the NN head around it.
                    self.models[(prop_key, mode)] = build_torch_model_from_ckpt(m, obj, self.device)

                self.meta[(prop_key, mode)] = {
                    "task_type": row.task_type,
                    "threshold": thr,
                    "artifact": str(art),
                    "model_name": m,
                    "kind": kind,
                }

    def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
        """
        Returns either:
          - pooled np array shape (1,H) for xgb/joblib
          - unpooled torch tensors (X,M) for NN
        """
        # NOTE(review): `model` is looked up but unused below.
        model = self.models[(prop_key, mode)]
        meta = self.meta[(prop_key, mode)]
        kind = meta.get("kind", None)
        model_name = meta.get("model_name", "")

        if prop_key == "binding_affinity":
            raise RuntimeError("Use predict_binding_affinity().")

        # NN heads consume unpooled token embeddings + mask.
        if kind == "torch_ckpt":
            if mode == "wt":
                X, M = self.wt_embedder.unpooled(input_str)
            else:
                X, M = self.smiles_embedder.unpooled(input_str)
            return X, M

        # Classical models consume a single pooled vector, sanitized for
        # NaNs and float32 overflow.
        if mode == "wt":
            v = self.wt_embedder.pooled(input_str)
        else:
            v = self.smiles_embedder.pooled(input_str)
        feats = v.detach().cpu().numpy().astype(np.float32)
        feats = np.nan_to_num(feats, nan=0.0)
        feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
        return feats

    def predict_property(self, prop_key: str, mode: str, input_str: str) -> Dict[str, Any]:
        """
        mode: "wt" for AA sequence input, "smiles" for SMILES input
        Returns dict with score + label if classifier threshold exists.
        Raises KeyError when no model is loaded for (prop_key, mode) and
        RuntimeError for binding_affinity (which has its own entry point).
        """
        if (prop_key, mode) not in self.models:
            raise KeyError(f"No model loaded for ({prop_key}, {mode}). Check manifest and folders.")

        meta = self.meta[(prop_key, mode)]
        model = self.models[(prop_key, mode)]
        task_type = meta["task_type"].lower()
        thr = meta.get("threshold", None)
        kind = meta.get("kind", None)

        if prop_key == "binding_affinity":
            raise RuntimeError("Use predict_binding_affinity().")

        # NN head: raw logit -> sigmoid probability for classifiers.
        if kind == "torch_ckpt":
            X, M = self._get_features_for_model(prop_key, mode, input_str)
            with torch.no_grad():
                y = model(X, M).squeeze().float().cpu().item()
            if task_type == "classifier":
                prob = float(1.0 / (1.0 + np.exp(-y)))
                out = {"property": prop_key, "mode": mode, "score": prob}
                if thr is not None:
                    out["label"] = int(prob >= float(thr))
                    out["threshold"] = float(thr)
                return out
            else:
                return {"property": prop_key, "mode": mode, "score": float(y)}

        # XGBoost booster: predict on a DMatrix built from pooled features.
        if kind == "xgb":
            feats = self._get_features_for_model(prop_key, mode, input_str)
            dmat = xgb.DMatrix(feats)
            pred = float(model.predict(dmat)[0])
            out = {"property": prop_key, "mode": mode, "score": pred}
            if task_type == "classifier" and thr is not None:
                out["label"] = int(pred >= float(thr))
                out["threshold"] = float(thr)
            return out

        # joblib estimator: prefer predict_proba, then a sigmoid of
        # decision_function, then raw predict.
        if kind == "joblib":
            feats = self._get_features_for_model(prop_key, mode, input_str)

            if task_type == "classifier":
                if hasattr(model, "predict_proba"):
                    pred = float(model.predict_proba(feats)[:, 1][0])
                else:
                    if hasattr(model, "decision_function"):
                        logit = float(model.decision_function(feats)[0])
                        pred = float(1.0 / (1.0 + np.exp(-logit)))
                    else:
                        pred = float(model.predict(feats)[0])
                out = {"property": prop_key, "mode": mode, "score": pred}
                if thr is not None:
                    out["label"] = int(pred >= float(thr))
                    out["threshold"] = float(thr)
                return out
            else:
                pred = float(model.predict(feats)[0])
                return {"property": prop_key, "mode": mode, "score": pred}

        raise RuntimeError(f"Unknown model kind={kind}")

    def predict_binding_affinity(self, mode: str, target_seq: str, binder_str: str) -> Dict[str, Any]:
        """
        mode: "wt" (binder is AA sequence)  -> wt_wt_(pooled|unpooled)
              "smiles" (binder is SMILES)   -> wt_smiles_(pooled|unpooled)
        Returns the regression affinity plus two class readouts: one from the
        model's classification logits and one from thresholding the affinity.
        """
        prop_key = "binding_affinity"
        if (prop_key, mode) not in self.models:
            raise KeyError(f"No binding model loaded for ({prop_key}, {mode}).")

        model = self.models[(prop_key, mode)]
        pooled_or_unpooled = self.meta[(prop_key, mode)]["model_name"]

        # Pooled variant consumes one vector per sequence; unpooled consumes
        # token embeddings + masks. Target is always an AA sequence.
        if pooled_or_unpooled == "pooled":
            t_vec = self.wt_embedder.pooled(target_seq)
            if mode == "wt":
                b_vec = self.wt_embedder.pooled(binder_str)
            else:
                b_vec = self.smiles_embedder.pooled(binder_str)
            with torch.no_grad():
                reg, logits = model(t_vec, b_vec)
            affinity = float(reg.squeeze().cpu().item())
            cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
            cls_thr = affinity_to_class(affinity)
        else:
            T, Mt = self.wt_embedder.unpooled(target_seq)
            if mode == "wt":
                B, Mb = self.wt_embedder.unpooled(binder_str)
            else:
                B, Mb = self.smiles_embedder.unpooled(binder_str)
            with torch.no_grad():
                reg, logits = model(T, Mt, B, Mb)
            affinity = float(reg.squeeze().cpu().item())
            cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
            cls_thr = affinity_to_class(affinity)

        names = {0: "High (≥9)", 1: "Moderate (7–9)", 2: "Low (<7)"}
        return {
            "property": "binding_affinity",
            "mode": mode,
            "affinity": affinity,
            "class_by_threshold": names[cls_thr],
            "class_by_logits": names[cls_logit],
            "binding_model": pooled_or_unpooled,
        }
| |
|
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    # Smoke test: load the manifest + weights, then run one property
    # prediction and one binding-affinity prediction.
    predictor = PeptiVersePredictor(
        manifest_path="best_models.txt",
        classifier_weight_root="/vast/projects/pranam/lab/yz927/projects/Classifier_Weight"
    )
    print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))
    print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))
| |
|
| | |
| | """ |
| | device = torch.device("cuda:0") |
| | |
| | wt = WTEmbedder(device) |
| | sm = SMILESEmbedder(device, |
| | vocab_path="/home/enol/PeptideGym/Data_split/tokenizer/new_vocab.txt", |
| | splits_path="/home/enol/PeptideGym/Data_split/tokenizer/new_splits.txt" |
| | ) |
| | |
| | p = wt.pooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,1280) |
| | X, M = wt.unpooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,Li,1280), (1,Li) |
| | |
| | p2 = sm.pooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,H_smiles) |
| | X2, M2 = sm.unpooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,Li,H_smiles), (1,Li) |
| | """ |
| |
|