import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
# from keras.layers import ...
from huggingface_hub import hf_hub_download
import gradio as gr
import h5py
from dl_model_def import make_fs, TwoTowerDual, build_two_tower_model
# ============================================
# CONFIG
# ============================================
DATA_DIR = "./data/proc"

# Download the model weights from the project's HF repo
print("Downloading model weights from Hugging Face Hub...")
WEIGHTS_FILE = hf_hub_download(
    repo_id="GrimSqueaker/OTRec",
    filename="model.weights.h5"
)
print(f"Weights downloaded to: {WEIGHTS_FILE}")
# ============================================
# LOAD TRAINING DATA
# ============================================
df_learn = pd.read_parquet(f"{DATA_DIR}/df_learn_sub.parquet")
disease_df = pd.read_parquet(f"{DATA_DIR}/disease_df.parquet")
target_df = pd.read_parquet(f"{DATA_DIR}/target_df.parquet")

# Ensure column names match training
df_learn = df_learn.rename(columns={
    "disease_text_embed": "disease_text",
    "target_text_embed": "target_text"
}, errors="ignore")
disease_df.rename(columns={"disease_text_embed": "disease_text"}, errors="ignore", inplace=True)
target_df.rename(columns={"target_text_embed": "target_text"}, errors="ignore", inplace=True)
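
# Optional sanity check (not in the original pipeline): fail fast if the text
# columns used below are missing, e.g. after an upstream schema change.
for _name, _df, _col in [("disease_df", disease_df, "disease_text"),
                         ("target_df", target_df, "target_text")]:
    if _col not in _df.columns:
        raise KeyError(f"{_name} is missing expected column '{_col}'")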
# ============================================
# BUILD MODEL + LOAD WEIGHTS
# ============================================
print("Building TwoTowerDual...")

# 1. Reset Keras session to ensure layer names start at index 0 (matches clean training)
tf.keras.backend.clear_session()

# 2. Rebuild architecture
model = build_two_tower_model(df_learn)
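
# load_weights maps the saved tensors onto the rebuilt layers by name, so the
# architecture (and its layer names) must match the training run; a mismatch
# raises the ValueError handled by the fallback below.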
| print("Loading weights...") | |
| try: | |
| # Try standard load | |
| model.load_weights(WEIGHTS_FILE) | |
| except ValueError as e: | |
| print(f"Standard load failed ({e}). Attempting name-mismatch fix...") | |
| # FALLBACK: The training notebook likely generated layer names like 'dise_emb_1' | |
| # due to multiple runs. We inspect the .h5 file and map the names. | |
| with h5py.File(WEIGHTS_FILE, 'r') as f: | |
| h5_keys = list(f.keys()) | |
| print(f"Weights file contains layers: {h5_keys}") | |
| # Helper to find the matching key in h5 file for a given prefix | |
| def match_layer_name(target_attr, prefix): | |
| # Find key in h5 that starts with prefix (e.g. 'dise_emb') | |
| match = next((k for k in h5_keys if k.startswith(prefix)), None) | |
| if match and hasattr(model, target_attr): | |
| layer = getattr(model, target_attr) | |
| print(f"Renaming model layer '{layer.name}' to '{match}' to match file.") | |
| layer._name = match | |
| # Apply renames for known components | |
| match_layer_name('dise_emb', 'dise_emb') | |
| match_layer_name('q_tower', 'tower') # Attempt to catch tower/tower_1 | |
| # k_tower might share the name 'tower' prefix in H5, which is tricky in subclasses | |
| # usually save_weights on subclass saves attributes directly. | |
| # Retry load after renaming | |
| model.load_weights(WEIGHTS_FILE) | |
| print("Weights loaded successfully.") | |
# ============================================
# PRECOMPUTE CANDIDATE EMBEDDINGS
# ============================================
# Note: in TF 2.16+, ensure inputs are tf.constant or numpy-compatible.
# Single-shot version, replaced by the batched loop below to avoid OOM:
# cand_embs = model.encode_k(target_texts, target_ids)
# cand_embs = tf.nn.l2_normalize(cand_embs, axis=1).numpy()
# print("Candidate embeddings ready.")
| print("Precomputing candidate embeddings (batched)...") | |
| target_texts = target_df["target_text"].astype(str).to_numpy() | |
| target_ids = target_df["targetId"].astype(str).to_numpy() | |
| # FIX: Process in batches to avoid OOM | |
| BATCH_SIZE = 1024 # Conservative batch size for wide inputs | |
| cand_embs_list = [] | |
| total = len(target_texts) | |
| for i in range(0, total, BATCH_SIZE): | |
| # Slice the batch | |
| end = min(i + BATCH_SIZE, total) | |
| batch_txt = target_texts[i:end] | |
| batch_ids = target_ids[i:end] | |
| # Run inference on the batch (keeps memory usage low) | |
| # Using tf.device conversion is optional but good for safety if GPU is fragmented | |
| emb_batch = model.encode_k(batch_txt, batch_ids) | |
| cand_embs_list.append(emb_batch) | |
| if i % 5000 == 0: | |
| print(f" Processed {i}/{total} candidates...") | |
| # Concatenate all batches back into one tensor | |
| cand_embs = tf.concat(cand_embs_list, axis=0) | |
| # Normalize the final result | |
| cand_embs = tf.nn.l2_normalize(cand_embs, axis=1).numpy() | |
| print(f"Candidate embeddings ready. Shape: {cand_embs.shape}") | |
# ============================================
# RECOMMENDATION FUNCTION
# ============================================
def recommend_targets(disease_id, top_k=10):
    # 1. Validate input
    if not disease_id:
        return pd.DataFrame(), None
    row = disease_df.loc[disease_df["diseaseId"] == disease_id]
    if row.empty:
        return pd.DataFrame(), None

    # 2. Encode the query
    disease_text = row["disease_text"].iloc[0]
    q_emb = model.encode_q(
        tf.constant([disease_text]),
        tf.constant([disease_id])
    )
    q_emb = tf.nn.l2_normalize(q_emb, axis=1).numpy()[0]

    # 3. Calculate raw cosine similarity, shape: (N_targets,)
    raw_sim = cand_embs @ q_emb
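    # Both sides are L2-normalized above, so this dot product is cosine similarity
    # in [-1, 1]. For much larger candidate sets, an approximate-NN index
    # (e.g. FAISS) could replace the dense matmul.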
    # 4. Convert similarity to probability (fixes negative scores):
    # the model's trained 'cls_head' (sigmoid) maps similarity -> probability.
    # Reshape to (N, 1) because the Keras Dense layer expects a 2-D input.
    scores = model.cls_head(raw_sim.reshape(-1, 1)).numpy().flatten()

    # 5. Get the top K
    k = int(top_k)
    idx = np.argsort(scores)[::-1][:k]

    # 6. Build the result DataFrame
    results = target_df.iloc[idx].copy()
    # Force standard Python floats for clean rounding
    raw_scores = scores[idx]
    results["score"] = [round(float(x), 4) for x in raw_scores]

    # 7. Select columns (the description column name varies across data versions)
    desc_col = "functionDescription" if "functionDescription" in results.columns else "functionDescriptions"
    desired_cols = [
        "targetId",
        "approvedSymbol",
        "approvedName",
        desc_col,
        "score"
    ]
    final_cols = [c for c in desired_cols if c in results.columns]
    results = results[final_cols]

    # 8. Save to CSV for download
    csv_path = "recommendations.csv"
    results.to_csv(csv_path, index=False)
    return results, csv_path
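
# Example usage (ID taken from the examples below; any diseaseId in disease_df works):
# df, csv_path = recommend_targets("doid_0050890", top_k=5)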
# ============================================
# GRADIO APP
# ============================================
def search_diseases(query):
    if not query or len(query) < 2:
        return gr.update(choices=[], value=None)
    # regex=False: match the query literally so characters like '(' don't break the search
    mask = (
        disease_df["name"].str.contains(query, case=False, na=False, regex=False) |
        disease_df["diseaseId"].str.contains(query, case=False, na=False, regex=False)
    )
    matches = disease_df.loc[mask].head(30)
    choices = [
        (f"{row['name']} ({row['diseaseId']})", row['diseaseId'])
        for _, row in matches.iterrows()
    ]
    first_val = choices[0][1] if choices else None
    return gr.update(choices=choices, value=first_val)
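
# Note: Dropdown choices are (label, value) tuples, so the UI shows
# "name (diseaseId)" while callbacks receive just the diseaseId value.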
def launch():
    examples = ["synuclein", "diabetes", "doid_0050890"]

    with gr.Blocks() as demo:
        gr.Markdown("# Disease → Target Recommender")
        gr.Markdown("Search for a disease by **Name** or **ID** to get target recommendations.")

        with gr.Row():
            search_box = gr.Textbox(
                label="1. Search Disease",
                placeholder="Type name (e.g., 'Parkinson') or ID...",
                lines=1
            )
            did_dropdown = gr.Dropdown(
                label="2. Select Disease",
                choices=[],
                interactive=True
            )
        topk = gr.Slider(1, 400, value=10, step=5, label="Top K Targets")

        # Search logic (updates dropdown options and default value)
        search_box.change(fn=search_diseases, inputs=search_box, outputs=did_dropdown)

        # Output components (stacked vertically for full width)
        out_df = gr.Dataframe(
            label="Predictions",
            interactive=False,
            wrap=True,
            show_search="filter",
        )
        out_file = gr.File(label="Download CSV")

        # === TRIGGER LOGIC ===
        # 1. Manual trigger (keep the button just in case)
        btn = gr.Button("Recommend Targets", variant="primary")
        btn.click(
            fn=recommend_targets,
            inputs=[did_dropdown, topk],
            outputs=[out_df, out_file]
        )
        # 2. Auto-trigger on change.
        # This also covers the Examples: example -> search -> dropdown update -> trigger.
        did_dropdown.change(
            fn=recommend_targets,
            inputs=[did_dropdown, topk],
            outputs=[out_df, out_file]
        )
        # Also recompute when the slider moves
        topk.change(
            fn=recommend_targets,
            inputs=[did_dropdown, topk],
            outputs=[out_df, out_file]
        )

        gr.Examples(examples=examples, inputs=search_box)

    demo.launch()


if __name__ == "__main__":
    launch()