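"""OTRec: Disease → Target Recommender (Gradio app).

Builds the TwoTowerDual model from dl_model_def, loads its weights from the
Hugging Face Hub, precomputes embeddings for all candidate targets, and serves
a Gradio UI that returns the top-K recommended targets for a selected disease.
"""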
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
# from keras.layers import ...
from huggingface_hub import hf_hub_download
import gradio as gr
import h5py
from dl_model_def import make_fs, TwoTowerDual, build_two_tower_model
# ============================================
# CONFIG
# ============================================
DATA_DIR = "./data/proc"
# Download the model weights from your specific HF Repo
print("Downloading model weights from Hugging Face Hub...")
WEIGHTS_FILE = hf_hub_download(
    repo_id="GrimSqueaker/OTRec",
    filename="model.weights.h5"
)
print(f"Weights downloaded to: {WEIGHTS_FILE}")
# ============================================
# LOAD TRAINING DATA
# ============================================
df_learn = pd.read_parquet(f"{DATA_DIR}/df_learn_sub.parquet")
disease_df = pd.read_parquet(f"{DATA_DIR}/disease_df.parquet")
target_df = pd.read_parquet(f"{DATA_DIR}/target_df.parquet")
# Ensure column names match training
df_learn = df_learn.rename(columns={
    "disease_text_embed": "disease_text",
    "target_text_embed": "target_text"
}, errors="ignore")
disease_df.rename(columns={"disease_text_embed": "disease_text"}, errors="ignore", inplace=True)
target_df.rename(columns={"target_text_embed": "target_text"}, errors="ignore", inplace=True)
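# After these renames, the dataframes expose the 'disease_text' / 'target_text'
# columns that the encoders below expect.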
# ============================================
# BUILD MODEL + LOAD WEIGHTS
# ============================================
print("Building TwoTowerDual...")
# 1. Reset Keras Session to ensure layer names start at index 0 (matches clean training)
tf.keras.backend.clear_session()
# 2. Rebuild architecture
model = build_two_tower_model(df_learn)
print("Loading weights...")
try:
    # Try standard load
    model.load_weights(WEIGHTS_FILE)
except ValueError as e:
    print(f"Standard load failed ({e}). Attempting name-mismatch fix...")
    # FALLBACK: the training notebook likely generated layer names like 'dise_emb_1'
    # due to multiple runs. Inspect the .h5 file and map the names.
    with h5py.File(WEIGHTS_FILE, 'r') as f:
        h5_keys = list(f.keys())
    print(f"Weights file contains layers: {h5_keys}")

    # Helper to find the matching key in the h5 file for a given prefix
    def match_layer_name(target_attr, prefix):
        # Find a key in the h5 file that starts with the prefix (e.g. 'dise_emb')
        match = next((k for k in h5_keys if k.startswith(prefix)), None)
        if match and hasattr(model, target_attr):
            layer = getattr(model, target_attr)
            print(f"Renaming model layer '{layer.name}' to '{match}' to match file.")
            layer._name = match

    # Apply renames for known components
    match_layer_name('dise_emb', 'dise_emb')
    match_layer_name('q_tower', 'tower')  # Attempt to catch tower/tower_1
    # k_tower may share the 'tower' prefix in the H5 file, which is awkward for
    # subclassed models; save_weights on a subclass usually saves attributes directly.

    # Retry the load after renaming
    model.load_weights(WEIGHTS_FILE)
print("Weights loaded successfully.")
# ============================================
# PRECOMPUTE CANDIDATE EMBEDDINGS
# ============================================
# Note: in TF 2.16+, make sure the encoder inputs are tf.constant / NumPy compatible.
# A single call over all targets, e.g. `model.encode_k(target_texts, target_ids)`,
# can run out of memory on large target sets, so embeddings are computed in batches below.
print("Precomputing candidate embeddings (batched)...")
target_texts = target_df["target_text"].astype(str).to_numpy()
target_ids = target_df["targetId"].astype(str).to_numpy()
# FIX: Process in batches to avoid OOM
BATCH_SIZE = 1024 # Conservative batch size for wide inputs
cand_embs_list = []
total = len(target_texts)
for i in range(0, total, BATCH_SIZE):
    # Slice the batch
    end = min(i + BATCH_SIZE, total)
    batch_txt = target_texts[i:end]
    batch_ids = target_ids[i:end]
    # Run inference on the batch (keeps memory usage low)
    emb_batch = model.encode_k(batch_txt, batch_ids)
    cand_embs_list.append(emb_batch)
    # Progress log roughly every 5 batches (i is a multiple of BATCH_SIZE,
    # so the original `i % 5000 == 0` check would only fire at i == 0)
    if (i // BATCH_SIZE) % 5 == 0:
        print(f"  Processed {i}/{total} candidates...")
# Concatenate all batches back into one tensor
cand_embs = tf.concat(cand_embs_list, axis=0)
# Normalize the final result
cand_embs = tf.nn.l2_normalize(cand_embs, axis=1).numpy()
print(f"Candidate embeddings ready. Shape: {cand_embs.shape}")
# ============================================
# RECOMMENDATION FUNCTION
# ============================================
def recommend_targets(disease_id, top_k=10):
    # 1. Validate input
    if not disease_id:
        return pd.DataFrame(), None
    row = disease_df.loc[disease_df["diseaseId"] == disease_id]
    if row.empty:
        return pd.DataFrame(), None

    # 2. Encode query
    disease_text = row["disease_text"].iloc[0]
    q_emb = model.encode_q(
        tf.constant([disease_text]),
        tf.constant([disease_id])
    )
    q_emb = tf.nn.l2_normalize(q_emb, axis=1).numpy()[0]

    # 3. Calculate raw cosine similarity
    # Shape: (N_targets,)
    raw_sim = cand_embs @ q_emb

    # 4. Convert to probability (fixes negative scores)
    # The model has a trained 'cls_head' (sigmoid) that maps similarity -> probability.
    # Reshape to (N, 1) because the Keras Dense layer expects a matrix.
    scores = model.cls_head(raw_sim.reshape(-1, 1)).numpy().flatten()

    # 5. Get top K
    k = int(top_k)
    idx = np.argsort(scores)[::-1][:k]

    # 6. Build result DataFrame
    results = target_df.iloc[idx].copy()
    # Force standard Python float for clean rounding
    raw_scores = scores[idx]
    results["score"] = [round(float(x), 4) for x in raw_scores]

    # 7. Select columns
    desc_col = "functionDescription" if "functionDescription" in results.columns else "functionDescriptions"
    desired_cols = [
        "targetId",
        "approvedSymbol",
        "approvedName",
        desc_col,
        "score"
    ]
    final_cols = [c for c in desired_cols if c in results.columns]
    results = results[final_cols]

    # 8. Save to CSV for download
    csv_path = "recommendations.csv"
    results.to_csv(csv_path, index=False)
    return results, csv_path
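# Example usage (assuming the ID exists in disease_df; "doid_0050890" is one of
# the demo examples below):
#   results, csv_path = recommend_targets("doid_0050890", top_k=10)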
# ============================================
# GRADIO APP
# ============================================
def search_diseases(query):
    if not query or len(query) < 2:
        return gr.update(choices=[], value=None)
    mask = (
        disease_df["name"].str.contains(query, case=False, na=False) |
        disease_df["diseaseId"].str.contains(query, case=False, na=False)
    )
    matches = disease_df.loc[mask].head(30)
    choices = [
        (f"{row['name']} ({row['diseaseId']})", row['diseaseId'])
        for _, row in matches.iterrows()
    ]
    first_val = choices[0][1] if choices else None
    return gr.update(choices=choices, value=first_val)
def launch():
    examples = ["synuclein", "diabetes", "doid_0050890"]

    with gr.Blocks() as demo:
        gr.Markdown("# Disease → Target Recommender")
        gr.Markdown("Search for a disease by **Name** or **ID** to get target recommendations.")

        with gr.Row():
            search_box = gr.Textbox(
                label="1. Search Disease",
                placeholder="Type name (e.g., 'Parkinson') or ID...",
                lines=1
            )
            did_dropdown = gr.Dropdown(
                label="2. Select Disease",
                choices=[],
                interactive=True
            )
            topk = gr.Slider(1, 400, value=10, step=5, label="Top K Targets")

        # Search logic (updates dropdown options and default value)
        search_box.change(fn=search_diseases, inputs=search_box, outputs=did_dropdown)

        # Output components (stacked vertically for full width)
        out_df = gr.Dataframe(
            label="Predictions",
            interactive=False,
            wrap=True,
            show_search="filter",
        )
        out_file = gr.File(label="Download CSV")

        # === TRIGGER LOGIC ===
        # 1. Manual trigger (keep the button just in case)
        btn = gr.Button("Recommend Targets", variant="primary")
        btn.click(
            fn=recommend_targets,
            inputs=[did_dropdown, topk],
            outputs=[out_df, out_file]
        )

        # 2. Auto-trigger on change
        # This also handles the Examples: Example -> Search -> Dropdown update -> Trigger
        did_dropdown.change(
            fn=recommend_targets,
            inputs=[did_dropdown, topk],
            outputs=[out_df, out_file]
        )

        # Also update when the slider moves
        topk.change(
            fn=recommend_targets,
            inputs=[did_dropdown, topk],
            outputs=[out_df, out_file]
        )

        gr.Examples(examples=examples, inputs=search_box)

    demo.launch()
if __name__ == "__main__":
    launch()