"""Self-contained inference module for the recommendation web app.

Contains trimmed copies of ``MLPMetric`` and ``MLPMetricFull`` (and their
dependencies) so HF Spaces deployments do not need to ship the full
``module/`` package. The class layout and parameter names match the trained
checkpoint exactly, so the original ``state_dict`` loads with
``strict=False`` and a clean diff.
"""

from __future__ import annotations

import hashlib
import math
import re
from types import SimpleNamespace
from typing import Optional

import torch
import torch.nn as nn

# Separators used to break a model name into sub-tokens; compiled once at
# import time instead of on every name.
_NAME_SPLIT_RE = re.compile(r"[\/_\-\s]+")


class ModelNameAvgEncoder(nn.Module):
    """Hashed-token average over a model name. Optionally adds an ID embedding.

    Each name is lower-cased and turned into a de-duplicated token list: the
    full name, the segment after the last ``/`` (if any), and the pieces
    between ``/``, ``_``, ``-`` and whitespace. Every token is hashed into one
    of ``hash_buckets`` rows of ``tok_emb`` and the rows are averaged.
    """

    def __init__(self, args, hash_buckets: int = 10000):
        super().__init__()
        self.hash_buckets = hash_buckets
        self.tok_emb = nn.Embedding(self.hash_buckets, args.token_dim)
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            # num_models + 1 rows: the extra last row is the unknown-model id.
            self.id_emb = nn.Embedding(args.num_models + 1, args.model_dim)
            self.unk_model_id = args.num_models

    @staticmethod
    def _split(name: str) -> list[str]:
        """Return the ordered, de-duplicated tokens for ``name`` ([] if blank)."""
        n = (name or "").strip().lower()
        if not n:
            return []
        toks = [n]
        if "/" in n:
            toks.append(n.split("/")[-1])
        toks.extend(t for t in _NAME_SPLIT_RE.split(n) if t)
        # dict preserves insertion order -> order-preserving de-duplication.
        return list(dict.fromkeys(toks))

    def _hash(self, tok: str) -> int:
        # md5 rather than built-in hash(): str hashing is salted per process,
        # which would scramble the bucket mapping between runs.
        return int(hashlib.md5(tok.encode()).hexdigest(), 16) % self.hash_buckets

    def forward(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor:
        """Encode ``model_names`` (and ``model_ids`` when ``use_id_emb``).

        Returns a ``[len(model_names), token_dim (+ model_dim)]`` tensor. Names
        that yield no tokens map to a zero vector; an empty list yields an
        empty ``[0, ...]`` tensor instead of failing in ``torch.stack``.
        """
        weight = self.tok_emb.weight
        device = weight.device
        emb_dim = self.tok_emb.embedding_dim
        if model_names:
            vecs = []
            for n in model_names:
                toks = self._split(n)
                if not toks:
                    vecs.append(torch.zeros(emb_dim, device=device, dtype=weight.dtype))
                    continue
                idxs = torch.tensor([self._hash(t) for t in toks], device=device, dtype=torch.long)
                vecs.append(self.tok_emb(idxs).mean(dim=0))
            h_name = torch.stack(vecs, dim=0)
        else:
            h_name = weight.new_zeros(0, emb_dim)
        feats = [h_name]
        if self.use_id_emb:
            feats.append(self.id_emb(model_ids.to(device)))
        return torch.cat(feats, dim=-1)


class MLPMetric(nn.Module):
    """MLP recommender that takes raw dataset description embeddings, plus
    task / metric / size / family side features, and ranks model candidates.

    Mirrors the checkpoint at
    ``checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id``.
    """

    def __init__(self, args):
        super().__init__()
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            self.model_embedding = nn.Embedding(args.num_models, args.model_dim)
        else:
            self.model_embedding = None
        self.task_embedding = nn.Embedding(args.num_tasks, args.task_dim)
        self.model_info_encoder = ModelNameAvgEncoder(args)
        self.size_embedding = nn.Embedding(args.num_size_buckets, args.size_dim)
        self.num_size_buckets = int(args.num_size_buckets)
        self.use_size_prior = bool(getattr(args, "use_size_prior", True))
        self.use_family_prior = bool(getattr(args, "use_family_prior", False))
        if self.use_family_prior:
            family_dim = int(getattr(args, "family_dim", args.size_dim))
            self.family_embedding = nn.Embedding(args.num_families, family_dim)
            self.family_dim = family_dim
        else:
            self.family_dim = 0

        # Disable Model-Spider fusion path entirely (not used by this checkpoint).
        self.use_ms_spider_repr = False
        self.ms_fusion_dim = 0

        model_info_dim = args.token_dim + (args.model_dim if self.use_id_emb else 0)
        dataset_info_dim = args.dataset_desp_dim + args.task_dim
        backbone_in_dim = (
            model_info_dim
            + dataset_info_dim
            + args.size_dim
            + self.family_dim
            + self.ms_fusion_dim
        )
        # Backbone is rebuilt by the metric branch below; the base layers are kept here
        # to match the parameter naming of the saved state dict.
        self.backbone = nn.Sequential(
            nn.Linear(backbone_in_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
        )
        self.pairwise_head = nn.Linear(args.hidden_dim, 1)
        self.pointwise_head = nn.Linear(args.hidden_dim, 1)
        prior_in_dim = args.size_dim + self.family_dim
        self.prior_head = nn.Sequential(
            nn.Linear(prior_in_dim, args.hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(args.hidden_dim // 2, 1),
        )
        self.temperature = nn.Parameter(torch.tensor(1.0))

        # ---- metric extension (matches the MLPMetric subclass) ----
        self.use_metric_embedding = bool(getattr(args, "use_metric_feature", True))
        self.num_metrics = int(getattr(args, "num_metrics", 1))
        self.metric_dim = int(getattr(args, "metric_dim", args.task_dim))
        self.unknown_metric_id = int(getattr(args, "unknown_metric_id", 0))
        if self.use_metric_embedding:
            self.metric_embedding = nn.Embedding(max(self.num_metrics, 1), self.metric_dim)
            # Widen the first layer by metric_dim; keep width and dropout.
            in_features = self.backbone[0].in_features + self.metric_dim
            hidden = self.backbone[0].out_features
            dropout = self.backbone[2].p
            self.backbone = nn.Sequential(
                nn.Linear(in_features, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
        else:
            self.metric_embedding = None

    @staticmethod
    def _check_chunk_size(chunk_size: int) -> None:
        """Reject chunk sizes that would stall or break the scoring loop.

        With ``chunk_size == 0`` the loop in ``score_matrix`` never advances
        ``start`` and spins forever; negative values produce invalid slices.
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    def encode_model(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor:
        """Return the per-model feature rows fed to the backbone."""
        return self.model_info_encoder(model_ids, model_names)

    def _embed_sizes(self, size_ids: torch.LongTensor) -> Optional[torch.Tensor]:
        """Size-bucket embeddings stored in the model cache (hook for subclasses)."""
        return self.size_embedding(size_ids)

    @torch.no_grad()
    def build_model_cache(
        self,
        all_model_names: list[str],
        all_model_size_ids: torch.LongTensor,
        all_model_family_ids: Optional[torch.LongTensor] = None,
        device=None,
    ):
        """Precompute the model-side features for every candidate model.

        Model ``i`` gets id ``i`` (ids are ``arange(len(all_model_names))``).
        Returns a dict with keys ``h_model``, ``h_size``, ``size_ids``,
        ``h_family`` and ``family_ids``. ``h_family``/``family_ids`` are None
        unless ``use_family_prior`` is set and family ids are supplied.

        Raises:
            ValueError: if the size (or family) id count differs from the
                number of model names.
        """
        if device is None:
            device = next(self.parameters()).device
        size_ids = all_model_size_ids.to(device=device, dtype=torch.long)
        M = len(all_model_names)
        if size_ids.shape[0] != M:
            raise ValueError(
                f"all_model_size_ids has {size_ids.shape[0]} entries but "
                f"{M} model names were given"
            )
        model_ids = torch.arange(M, device=device, dtype=torch.long)
        h_model = self.encode_model(model_ids, all_model_names)
        h_size = self._embed_sizes(size_ids)
        cache = {"h_model": h_model, "h_size": h_size, "size_ids": size_ids}
        if self.use_family_prior and all_model_family_ids is not None:
            family_ids = all_model_family_ids.to(device=device, dtype=torch.long)
            if family_ids.shape[0] != M:
                raise ValueError(
                    f"all_model_family_ids has {family_ids.shape[0]} entries but "
                    f"{M} model names were given"
                )
            cache["h_family"] = self.family_embedding(family_ids)
            cache["family_ids"] = family_ids
        else:
            cache["h_family"] = None
            cache["family_ids"] = None
        return cache

    def _metric_embed(
        self, metric_ids: Optional[torch.LongTensor], batch_size: int, device
    ) -> Optional[torch.Tensor]:
        """Metric embeddings for the batch, or None when the metric path is off.

        Missing ``metric_ids`` default to ``unknown_metric_id`` for every row.
        """
        if not self.use_metric_embedding or self.metric_embedding is None:
            return None
        if metric_ids is None:
            metric_ids = torch.full(
                (batch_size,), int(self.unknown_metric_id), dtype=torch.long, device=device
            )
        return self.metric_embedding(metric_ids.to(device))

    @torch.no_grad()
    def score_matrix(
        self,
        task_ids: torch.LongTensor,
        dataset_desp_batch: torch.Tensor,
        model_cache: dict,
        metric_ids: Optional[torch.LongTensor] = None,
        chunk_size: int = 8192,
    ) -> torch.Tensor:
        """Score every (dataset, model) pair; returns a ``[B, M]`` tensor.

        Models are processed ``chunk_size`` at a time to bound memory. Each
        score is ``(pairwise_head(backbone(x)) + prior) / temperature`` with
        the temperature clamped to at least 1e-3.

        Raises:
            ValueError: if ``chunk_size`` is not a positive integer.
        """
        self._check_chunk_size(chunk_size)
        device = dataset_desp_batch.device
        B = dataset_desp_batch.size(0)
        h_task = self.task_embedding(task_ids.to(device))
        h_data = dataset_desp_batch
        h_metric = self._metric_embed(metric_ids, B, device)
        h_model_all = model_cache["h_model"]
        h_size_all = model_cache["h_size"]
        h_family_all = model_cache.get("h_family")
        M = h_model_all.size(0)

        # prior_head was sized for size_dim + family_dim, so family features
        # must be present in the cache whenever family_dim > 0.
        if self.use_size_prior or self.use_family_prior:
            if h_family_all is not None:
                prior_inp_all = torch.cat([h_size_all, h_family_all], dim=-1)
            else:
                prior_inp_all = h_size_all
            prior_all = self.prior_head(prior_inp_all).squeeze(-1)
        else:
            prior_all = torch.zeros(M, device=device)

        out = torch.empty(B, M, device=device)
        T = torch.clamp(self.temperature, min=1e-3)
        start = 0
        while start < M:
            end = min(start + chunk_size, M)
            m = end - start
            h_model_exp = h_model_all[start:end].unsqueeze(0).expand(B, m, -1)
            h_size_exp = h_size_all[start:end].unsqueeze(0).expand(B, m, -1)
            h_data_exp = h_data.unsqueeze(1).expand(B, m, -1)
            h_task_exp = h_task.unsqueeze(1).expand(B, m, -1)
            # Concatenation order must match the training-time backbone input.
            parts = [h_model_exp, h_data_exp, h_size_exp]
            if h_family_all is not None:
                parts.append(h_family_all[start:end].unsqueeze(0).expand(B, m, -1))
            parts.append(h_task_exp)
            if h_metric is not None:
                parts.append(h_metric.unsqueeze(1).expand(B, m, -1))
            residual_inp = torch.cat(parts, dim=-1)
            h = self.backbone(residual_inp.reshape(B * m, -1))
            s_chunk = self.pairwise_head(h).reshape(B, m)
            prior_chunk = prior_all[start:end].unsqueeze(0)
            out[:, start:end] = (s_chunk + prior_chunk) / T
            start = end
        return out


class MLPMetricFull(MLPMetric):
    """Full-feature recommender.

    Uses model-id emb, model-name emb, model-desc emb, dataset-id emb, and
    dataset-desc emb. For inference on a *new user dataset* (no global
    dataset_id), we:

    - feed UNK as dataset_id (so dataset_id_embedding still contributes a
      learned [UNK] prior),
    - feed the user's OpenAI embedding directly as dataset_desc_emb,
      bypassing the training-time ``dataset_desc_matrix`` lookup.

    Parameter layout matches the training-time class so the state_dict loads
    via ``load_state_dict(strict=False)`` after stripping the buffers that are
    only useful for the train-set IDs.
    """

    def __init__(self, args):
        # ---- dim bookkeeping (read before the parent builds its layers) ----
        dataset_id_emb_dim = int(getattr(args, "dataset_id_emb_dim", 256))
        dataset_desp_emb_dim = int(getattr(args, "dataset_desp_emb_dim", 1536))
        model_desp_emb_dim = int(getattr(args, "model_desp_emb_dim", 1536))

        # The parent's __init__ builds task/size/family/metric embeddings,
        # prior_head, temperature, plus a placeholder backbone (rebuilt below).
        # Temporarily set dataset_desp_dim so the parent sizes its placeholder
        # correctly; restore it in ``finally`` so a failed construction never
        # leaves the caller's ``args`` mutated.
        orig_desp_dim = args.dataset_desp_dim
        args.dataset_desp_dim = dataset_id_emb_dim + dataset_desp_emb_dim
        try:
            super().__init__(args)
        finally:
            args.dataset_desp_dim = orig_desp_dim

        self.dataset_id_emb_dim = dataset_id_emb_dim
        self.dataset_desp_emb_dim = dataset_desp_emb_dim
        self.model_desp_emb_dim = model_desp_emb_dim

        # Information-source flags (kept for parity; defaults match training)
        self.use_model_id_emb = bool(getattr(args, "use_model_id_emb", True))
        self.use_model_name_emb = bool(getattr(args, "use_model_name_emb", True))
        self.use_model_desc_emb = bool(getattr(args, "use_model_desc_emb", True))
        self.use_dataset_id_emb = bool(getattr(args, "use_dataset_id_emb", True))
        self.use_dataset_desc_emb = bool(getattr(args, "use_dataset_desc_emb", True))
        self.use_size_feature = bool(getattr(args, "use_size_feature", True))

        # ==== Model-side components (own name encoder + own id emb) ====
        if self.use_model_name_emb:
            args_name_only = SimpleNamespace(**vars(args))
            args_name_only.use_id_emb = False
            self._name_encoder = ModelNameAvgEncoder(args_name_only)
        else:
            self._name_encoder = None

        if self.use_model_id_emb:
            self._id_emb = nn.Embedding(args.num_models + 1, args.model_dim)
            self.unk_model_id = args.num_models
        else:
            self._id_emb = None
            self.unk_model_id = 0

        # Model-description buffer: one row per known model (empty when unused).
        if self.use_model_desc_emb:
            self.register_buffer(
                "model_desc_matrix",
                torch.zeros(args.num_models, self.model_desp_emb_dim),
            )
        else:
            self.register_buffer(
                "model_desc_matrix",
                torch.zeros(0, self.model_desp_emb_dim),
            )

        # ==== Dataset-side components ====
        num_datasets = int(getattr(args, "num_datasets", 100000))
        if self.use_dataset_id_emb:
            # +2: one for [UNK], one for the upper slack (matches training)
            self.dataset_id_embedding = nn.Embedding(num_datasets + 2, self.dataset_id_emb_dim)
            self.unk_dataset_id = num_datasets + 1
        else:
            self.dataset_id_embedding = None
            self.unk_dataset_id = 0
        # ``dataset_desc_matrix`` is NOT registered at inference time — we use
        # the user's OpenAI embedding directly. The stripped checkpoint also
        # omits this buffer.

        # ==== Recompute backbone input dim and rebuild ====
        model_info_dim = (
            (args.token_dim if self.use_model_name_emb else 0)
            + (args.model_dim if self.use_model_id_emb else 0)
            + (self.model_desp_emb_dim if self.use_model_desc_emb else 0)
        )
        self.model_info_dim = model_info_dim
        dataset_emb_dim = (
            (self.dataset_id_emb_dim if self.use_dataset_id_emb else 0)
            + (self.dataset_desp_emb_dim if self.use_dataset_desc_emb else 0)
        )
        self.dataset_emb_dim = dataset_emb_dim
        dataset_info_dim = dataset_emb_dim + args.task_dim
        metric_dim = self.metric_dim if self.use_metric_embedding else 0
        size_emb_dim_eff = args.size_dim if self.use_size_feature else 0
        backbone_in = (
            model_info_dim + dataset_info_dim + size_emb_dim_eff + self.family_dim + metric_dim
        )
        self.backbone = nn.Sequential(
            nn.Linear(backbone_in, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
        )

        # Rebuild prior_head only when it has inputs; otherwise the parent's
        # head is kept (for state_dict parity) but never used in score_matrix.
        prior_in_actual = 0
        if self.use_size_prior and self.use_size_feature:
            prior_in_actual += args.size_dim
        if self.use_family_prior:
            prior_in_actual += self.family_dim
        if prior_in_actual > 0:
            self.prior_head = nn.Sequential(
                nn.Linear(prior_in_actual, args.hidden_dim // 2),
                nn.ReLU(),
                nn.Linear(args.hidden_dim // 2, 1),
            )

    # ------------------------------------------------------------------
    # Model-side encoding (used by build_model_cache)
    # ------------------------------------------------------------------
    def encode_model(
        self,
        model_ids: torch.LongTensor,
        model_names: list[str],
    ) -> torch.Tensor:
        """Concatenate [name emb | id emb | description row] per enabled source.

        Description rows are looked up with ids clamped into the buffer's
        range; an empty buffer yields zero rows.
        """
        B = model_ids.shape[0]
        device = model_ids.device
        parts = []
        if self.use_model_name_emb:
            parts.append(self._name_encoder(model_ids, model_names))
        if self.use_model_id_emb:
            parts.append(self._id_emb(model_ids))
        if self.use_model_desc_emb:
            if self.model_desc_matrix.shape[0] > 0:
                safe_ids = model_ids.clamp(0, self.model_desc_matrix.shape[0] - 1)
                parts.append(self.model_desc_matrix[safe_ids])
            else:
                parts.append(torch.zeros(B, self.model_desp_emb_dim, device=device))
        if not parts:
            return torch.zeros(B, 0, device=device)
        if len(parts) == 1:
            return parts[0]
        return torch.cat(parts, dim=-1)

    def _embed_sizes(self, size_ids: torch.LongTensor) -> Optional[torch.Tensor]:
        """Size embeddings for the cache; None when the size feature is disabled.

        ``build_model_cache`` is inherited from ``MLPMetric`` and calls this
        hook, so a disabled size feature stores ``h_size = None``.
        """
        if not self.use_size_feature:
            return None
        return self.size_embedding(size_ids)

    # ------------------------------------------------------------------
    # Dataset-side encoding for inference: user's OpenAI emb + UNK id
    # ------------------------------------------------------------------
    def _encode_dataset_at_inference(
        self,
        dataset_desp_emb: torch.Tensor,
    ) -> torch.Tensor:
        """``dataset_desp_emb`` is the user's OpenAI vector of shape
        ``[B, dataset_desp_emb_dim]``. We pair it with a learned [UNK]
        dataset-id embedding, matching the training-time concatenation order
        (id_emb || desc_emb).
        """
        B = dataset_desp_emb.shape[0]
        device = dataset_desp_emb.device
        parts = []
        if self.use_dataset_id_emb and self.dataset_id_embedding is not None:
            unk = torch.full((B,), int(self.unk_dataset_id), dtype=torch.long, device=device)
            parts.append(self.dataset_id_embedding(unk))
        if self.use_dataset_desc_emb:
            parts.append(dataset_desp_emb)
        if not parts:
            return torch.zeros(B, 0, device=device)
        if len(parts) == 1:
            return parts[0]
        return torch.cat(parts, dim=-1)

    # ------------------------------------------------------------------
    # score_matrix at inference time
    # ------------------------------------------------------------------
    @torch.no_grad()
    def score_matrix(
        self,
        task_ids: torch.LongTensor,
        dataset_desp_batch: torch.Tensor,
        model_cache: dict,
        metric_ids: Optional[torch.LongTensor] = None,
        chunk_size: int = 8192,
    ) -> torch.Tensor:
        """Score every (dataset, model) pair; returns a ``[B, M]`` tensor.

        ``dataset_desp_batch`` here is the OpenAI embedding
        ``[B, dataset_desp_emb_dim]`` (1536 by default); it is paired with the
        [UNK] dataset-id embedding before entering the backbone.

        Raises:
            ValueError: if ``chunk_size`` is not a positive integer.
        """
        self._check_chunk_size(chunk_size)
        device = dataset_desp_batch.device
        B = dataset_desp_batch.size(0)
        h_task = self.task_embedding(task_ids.to(device))
        h_data = self._encode_dataset_at_inference(dataset_desp_batch)
        h_metric = self._metric_embed(metric_ids, B, device)
        h_model_all = model_cache["h_model"]
        h_size_all = model_cache["h_size"] if self.use_size_feature else None
        h_family_all = model_cache.get("h_family")
        M = h_model_all.size(0)

        prior_parts_all = []
        if self.use_size_prior and h_size_all is not None:
            prior_parts_all.append(h_size_all)
        if self.use_family_prior and h_family_all is not None:
            prior_parts_all.append(h_family_all)
        if prior_parts_all:
            prior_inp_all = (
                torch.cat(prior_parts_all, dim=-1)
                if len(prior_parts_all) > 1
                else prior_parts_all[0]
            )
            prior_all = self.prior_head(prior_inp_all).squeeze(-1)
        else:
            prior_all = torch.zeros(M, device=device)

        out = torch.empty(B, M, device=device)
        T = torch.clamp(self.temperature, min=1e-3)
        start = 0
        while start < M:
            end = min(start + chunk_size, M)
            m = end - start
            h_model = h_model_all[start:end]
            # Zero-width sources are skipped entirely rather than concatenated.
            parts = []
            if h_model.shape[1] > 0:
                parts.append(h_model.unsqueeze(0).expand(B, m, -1))
            if h_data.shape[1] > 0:
                parts.append(h_data.unsqueeze(1).expand(B, m, -1))
            if h_size_all is not None:
                parts.append(h_size_all[start:end].unsqueeze(0).expand(B, m, -1))
            if h_family_all is not None:
                parts.append(h_family_all[start:end].unsqueeze(0).expand(B, m, -1))
            parts.append(h_task.unsqueeze(1).expand(B, m, -1))
            if h_metric is not None:
                parts.append(h_metric.unsqueeze(1).expand(B, m, -1))
            residual_inp = torch.cat(parts, dim=-1)
            h = self.backbone(residual_inp.reshape(B * m, -1))
            s_chunk = self.pairwise_head(h).reshape(B, m)
            prior_chunk = prior_all[start:end].unsqueeze(0)
            out[:, start:end] = (s_chunk + prior_chunk) / T
            start = end
        return out