# -*- coding: utf-8 -*-
"""
HarmoniFind – Semantic Spotify Search
HF Spaces app.py
"""
import os
import random
from difflib import SequenceMatcher
import numpy as np
import pandas as pd
import faiss
import gradio as gr
import html as html_lib
from sentence_transformers import SentenceTransformer
from huggingface_hub import InferenceClient
import spotipy
from spotipy.oauth2 import SpotifyClientCredentials
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
# ---------- Paths to precomputed data ----------
CLEAN_CSV_PATH = "df_combined_clean.csv"
EMB_PATH = "df_embed.npz"
INDEX_PATH = "hnsw.index"
# ---------- Load data ----------
df_combined = pd.read_csv(CLEAN_CSV_PATH)
emb_data = np.load(EMB_PATH)
df_embeddings = emb_data["df_embeddings"].astype("float32")
index = faiss.read_index(INDEX_PATH)
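# The three artifacts above are built offline (in the companion notebook), not here.
# A rough, hedged sketch of how they could be reproduced; the text column name and the
# HNSW parameter below are assumptions, not taken from this repo:
#
#   texts = df_combined["combined_text"].tolist()            # assumed lyrics/text column
#   embs = SentenceTransformer("all-mpnet-base-v2").encode(texts, convert_to_numpy=True).astype("float32")
#   np.savez(EMB_PATH, df_embeddings=embs)
#   hnsw = faiss.IndexHNSWFlat(embs.shape[1], 32)            # 32 = HNSW links per node (assumed)
#   hnsw.add(embs)
#   faiss.write_index(hnsw, INDEX_PATH)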
# ---------- Secrets from env (HF Space secrets) ----------
HF_TOKEN = os.getenv("HF_TOKEN")
SPOTIFY_CLIENT_ID = os.getenv("SPOTIPY_CLIENT_ID")
SPOTIFY_CLIENT_SECRET = os.getenv("SPOTIPY_CLIENT_SECRET")
print("HF token present?", bool(HF_TOKEN))
print("Spotify ID present?", bool(SPOTIFY_CLIENT_ID))
print("Spotify secret present?", bool(SPOTIFY_CLIENT_SECRET))
# ---------- Models ----------
# Query encoder (same as notebook)
query_embedder = SentenceTransformer("all-mpnet-base-v2")
# LLaMA-2 for query expansion
LLAMA_MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"
llama_pipe = None # local quantized pipeline (preferred)
hf_client = None # hosted fallback
if HF_TOKEN:
    # Try to load a 4-bit quantized LLaMA locally (for HF Space with GPU)
    if torch.cuda.is_available():
        try:
            print("Loading LLaMA-2-7B in 4-bit NF4 with bitsandbytes...")
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            llama_tokenizer = AutoTokenizer.from_pretrained(
                LLAMA_MODEL_ID,
                use_auth_token=HF_TOKEN,
            )
            llama_model = AutoModelForCausalLM.from_pretrained(
                LLAMA_MODEL_ID,
                quantization_config=bnb_config,  # 🔑 this actually activates 4-bit
                device_map="auto",
                torch_dtype=torch.bfloat16,
                use_auth_token=HF_TOKEN,
            )
            llama_pipe = pipeline(
                "text-generation",
                model=llama_model,
                tokenizer=llama_tokenizer,
                max_new_tokens=96,
                temperature=0.2,
                top_p=0.9,
                repetition_penalty=1.05,
            )
            print("Using local 4-bit quantized LLaMA backend.")
        except Exception as e:
            print("⚠️ Quantized LLaMA load failed, will try HF Inference fallback:", repr(e))
    # If quantized local load failed (or no CUDA), fall back to HF hosted inference
    if llama_pipe is None:
        try:
            hf_client = InferenceClient(model=LLAMA_MODEL_ID, token=HF_TOKEN)
            print("✅ Using HF InferenceClient backend (hosted LLaMA).")
        except Exception as e:
            print("⚠️ Could not initialize any LLaMA backend:", repr(e))
else:
    print("⚠️ No HF_TOKEN found; LLaMA expansion will be disabled.")
# Spotify client
sp = None
if SPOTIFY_CLIENT_ID and SPOTIFY_CLIENT_SECRET:
    try:
        auth = SpotifyClientCredentials(
            client_id=SPOTIFY_CLIENT_ID,
            client_secret=SPOTIFY_CLIENT_SECRET,
        )
        sp = spotipy.Spotify(auth_manager=auth)
    except Exception as e:
        print("⚠️ Could not initialize Spotify client:", repr(e))
        sp = None
print("Spotify client created?", sp is not None)
# ---------- Core helpers ----------
def encode_query(text: str) -> np.ndarray:
    return query_embedder.encode([text], convert_to_numpy=True).astype("float32")
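# e.g. encode_query("songs about late-night drives") -> float32 array of shape (1, 768);
# all-mpnet-base-v2 produces 768-dimensional sentence embeddings, matching the FAISS index.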
def expand_with_llama(query: str) -> str:
    """
    Enrich the query using LLaMA.
    Priority:
    1) Use local 4-bit quantized LLaMA pipeline if available (HF Space with GPU).
    2) Otherwise, fall back to HF InferenceClient (hosted model).
    3) On any failure, return the raw query so the app keeps working.
    """
    if not HF_TOKEN:
        return query
    prompt = f"""You are helping someone search a lyrics catalog.
If the input looks like existing song lyrics or a singer name,
return artist and song titles that match.
Otherwise, return a short list of lyric-style keywords
that are closely related to the input sentence.
Input:
{query}
Output (no explanation, just titles or keywords):"""
    try:
        if llama_pipe is not None:
            # Local 4-bit quantized model on HF Space
            outputs = llama_pipe(
                prompt,
                do_sample=True,
                num_return_sequences=1,
            )
            full_text = outputs[0]["generated_text"]
            # Strip the prompt off the front if it's included
            if full_text.startswith(prompt):
                keywords = full_text[len(prompt):].strip()
            else:
                keywords = full_text.strip()
        elif hf_client is not None:
            # Hosted HF Inference fallback
            response = hf_client.text_generation(
                prompt,
                max_new_tokens=96,
                temperature=0.2,
                repetition_penalty=1.05,
            )
            keywords = str(response).strip()
        else:
            # No backend at all
            return query
    except Exception as e:
        print("⚠️ LLaMA expansion failed, using raw query:", repr(e))
        return query
    keywords = keywords.replace("\n", " ")
    expanded = query + " " + keywords
    return expanded
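# The expansion simply appends the generated keywords to the raw query before encoding,
# e.g. "missing someone far away" might become
# "missing someone far away longing, distance, heartache, empty room" (illustrative output only).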
def distances_to_similarity_pct(dists: np.ndarray) -> np.ndarray:
    if len(dists) == 0:
        return np.array([])
    dmin, dmax = dists.min(), dists.max()
    if dmax - dmin == 0:
        return np.ones_like(dists) * 100
    sims = 100 * (1 - (dists - dmin) / (dmax - dmin))
    return sims
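# Min-max rescaling of the raw FAISS distances: within a result set the closest hit always
# shows 100% and the farthest 0%, e.g. distances [0.2, 0.5, 0.8] -> [100.0, 50.0, 0.0].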
def label_vibes(sim: float) -> str:
    if sim >= 90:
        return "dead-on"
    elif sim >= 80:
        return "strong vibes"
    elif sim >= 70:
        return "adjacent"
    elif sim >= 60:
        return "stretch but related"
    else:
        return "pretty random"
def semantic_search(query: str, k: int = 10, random_extra: int = 0, use_llama: bool = True) -> pd.DataFrame:
    if not query or not query.strip():
        return pd.DataFrame(columns=["artist", "song", "similarity_pct", "vibes", "is_random"])
    q_text = expand_with_llama(query) if use_llama else query
    q_vec = encode_query(q_text)
    dists, idxs = index.search(q_vec, k)
    sem_df = df_combined.iloc[idxs[0]].copy()
    sem_df["similarity_pct"] = distances_to_similarity_pct(dists[0])
    sem_df["vibes"] = sem_df["similarity_pct"].apply(label_vibes)
    sem_df["is_random"] = False
    rand_df = pd.DataFrame()
    if random_extra > 0:
        chosen = np.random.choice(
            len(df_combined),
            size=min(random_extra, len(df_combined)),
            replace=False,
        )
        rand_df = df_combined.iloc[chosen].copy()
        rand_df["similarity_pct"] = np.nan
        rand_df["vibes"] = "pure random"
        rand_df["is_random"] = True
    results = pd.concat([sem_df, rand_df], ignore_index=True)
    return results
def lookup_spotify_track_smart(artist: str, song: str):
    if not sp:
        return None, None
    q = f"track:{song} artist:{artist}"
    try:
        results = sp.search(q, type="track", limit=3)
        items = results.get("tracks", {}).get("items", [])
        if not items:
            return None, None
        best = max(
            items,
            key=lambda t: SequenceMatcher(None, t["name"].lower(), song.lower()).ratio(),
        )
        url = best["external_urls"]["spotify"]
        images = best["album"]["images"]
        img_url = images[0]["url"] if images else None
        return url, img_url
    except Exception as e:
        print("⚠️ Spotify search failed:", repr(e))
        return None, None
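# Uses Spotify's field-filter search syntax, e.g. q = "track:Dreams artist:Fleetwood Mac".
# Of the (up to) 3 hits returned, the one whose title is closest to the requested song
# by SequenceMatcher ratio wins, which helps when covers or remixes rank first.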
def search_pipeline(query: str, k: int = 10, random_extra: int = 0, use_llama: bool = True) -> pd.DataFrame:
    res = semantic_search(query, k, random_extra, use_llama)
    if res.empty or sp is None:
        res["spotify_url"], res["album_image"] = None, None
        return res
    urls, imgs = [], []
    for _, r in res.iterrows():
        u, i = lookup_spotify_track_smart(str(r["artist"]), str(r["song"]))
        urls.append(u)
        imgs.append(i)
    res["spotify_url"], res["album_image"] = urls, imgs
    return res
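# The returned DataFrame carries everything the HTML renderer reads per row:
# artist, song, similarity_pct, vibes, is_random, spotify_url, album_image.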
def get_random_vibe() -> str:
    topics = ["late-night drives", "dog bloopers", "breakups", "sunset beaches", "college nostalgia"]
    perspectives = ["first-person", "third-person", "group", "inner monologue"]
    tones = ["dreamy", "chaotic", "romantic", "melancholic"]
    return f"Lyrics about {random.choice(topics)}, told in {random.choice(perspectives)}, {random.choice(tones)}."
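# Example surprise prompt: "Lyrics about sunset beaches, told in first-person, dreamy."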
# ---------- CSS ----------
app_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;800;900&display=swap');
/* Shell + base uses CSS variables so we can change bg from Python */
body, .gradio-container {
background: radial-gradient(
circle at 50% 0%,
var(--hf-bg-top, #1e293b),
var(--hf-bg-bottom, #020617) 80%
) !important;
font-family: 'Inter', system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif !important;
color: #e5e7eb;
}
.gradio-container .block {
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
/* Inputs */
.gradio-container input,
.gradio-container textarea {
background: rgba(15,23,42,0.8) !important;
border: 1px solid rgba(148,163,184,0.6) !important;
color: #f9fafb !important;
border-radius: 12px !important;
font-size: 0.95rem !important;
transition: all 0.18s ease;
}
.gradio-container input:focus,
.gradio-container textarea:focus {
border-color: #10b981 !important;
box-shadow: 0 0 0 2px rgba(16,185,129,0.3) !important;
}
/* Buttons */
button.primary-btn {
background: linear-gradient(135deg,#10b981,#059669) !important;
border: none !important;
color: #ecfdf5 !important;
font-weight: 700 !important;
border-radius: 999px !important;
padding-inline: 18px !important;
}
button.primary-btn:hover {
transform: translateY(-1px);
box-shadow: 0 10px 22px -8px rgba(16,185,129,0.6);
}
button.secondary-btn {
background: rgba(15,23,42,0.9) !important;
color: #cbd5f5 !important;
border-radius: 999px !important;
border: 1px solid rgba(148,163,184,0.8) !important;
padding-inline: 14px !important;
}
button.secondary-btn:hover {
background: rgba(30,64,175,0.9) !important;
}
/* Top shell + header */
#hf-shell {
max-width: 960px;
margin: 0 auto;
padding: 24px 12px 40px;
}
#lux-header {
text-align: left;
padding: 14px 4px 8px;
}
#lux-header h1 {
font-size: 2.4rem;
font-weight: 900;
background: linear-gradient(to right,#f9fafb,#9ca3af);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin: 0;
letter-spacing: -0.06em;
}
.lux-subline {
text-transform: uppercase;
letter-spacing: 0.20em;
font-size: 0.75rem;
color: #10b981;
margin-bottom: 6px;
font-weight: 600;
}
#lux-header p {
color: #9ca3af;
font-size: 0.9rem;
margin-top: 8px;
}
/* Meta row for tracks + copy-link */
.lux-meta {
display:flex;
flex-wrap:wrap;
gap:8px;
margin-top:8px;
align-items:center;
font-size:0.8rem;
color:#e5e7eb;
}
.lux-badge {
font-size: 0.75rem;
padding: 6px 12px;
border-radius: 999px;
border: 1px solid rgba(148,163,184,0.6);
text-transform: uppercase;
letter-spacing: 0.12em;
}
.lux-pill {
font-size: 0.75rem;
padding: 6px 12px;
border-radius: 999px;
border: 1px solid rgba(148,163,184,0.6);
background: rgba(255,255,255,0.04);
text-decoration:none;
color:#e5e7eb;
}
.lux-pill:hover {
background: rgba(255,255,255,0.08);
}
/* Playlist wrapper + cards */
#lux-wrapper {
max-width: 960px;
margin: 0 auto;
padding: 24px 12px 40px;
}
.lux-playlist-wrapper {
margin-top: 12px;
display: flex;
flex-direction: column;
gap: 10px;
}
.lux-card {
display: flex;
gap: 14px;
padding: 12px 14px;
border-radius: 18px;
background: rgba(15,23,42,0.94);
border: 1px solid rgba(148,163,184,0.22);
}
.lux-cover {
width: 72px;
height: 72px;
border-radius: 14px;
overflow: hidden;
background: #020617;
flex-shrink: 0;
display:flex;
align-items:center;
justify-content:center;
color:#6b7280;
font-size: 20px;
}
.lux-cover img {
width: 100%;
height: 100%;
object-fit: cover;
}
.lux-main {
flex: 1;
display: flex;
flex-direction: column;
gap: 4px;
min-width: 0;
}
.lux-title-row {
display: flex;
justify-content: space-between;
gap: 8px;
align-items: flex-start;
}
.lux-title {
font-size: 0.95rem;
font-weight: 600;
color: #e5e7eb;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.lux-artist {
font-size: 0.8rem;
color: #9ca3af;
}
.lux-score {
display: flex;
flex-direction: column;
align-items: flex-end;
gap: 4px;
}
.lux-score-badge {
font-size: 0.7rem;
padding: 3px 8px;
border-radius: 999px;
background: rgba(34,197,94,0.14);
color: #bbf7d0;
}
.lux-vibes {
font-size: 0.7rem;
color: #9ca3af;
}
.lux-bottom-row {
display: flex;
justify-content: space-between;
align-items: center;
gap: 8px;
margin-top: 2px;
}
.lux-play-btn {
display:inline-flex;
align-items:center;
gap:6px;
padding:7px 12px;
border-radius:999px;
background:#22c55e;
color:#022c22;
font-size:0.8rem;
font-weight:600;
text-decoration:none;
}
.lux-chip {
font-size:0.65rem;
border-radius:999px;
padding:3px 7px;
background:rgba(148,163,184,0.18);
color:#e5e7eb;
}
"""
# ---------- Background palette + helper ----------
BG_PALETTE = [
    ("#1e293b", "#020617"),
    ("#0f172a", "#020617"),
    ("#0b1120", "#020617"),
    ("#111827", "#020617"),
    ("#1f2937", "#020617"),
]
def make_bg_style_html() -> str:
    """Pick a gradient pair from the palette and emit a <style> that sets CSS vars."""
    top, bottom = random.choice(BG_PALETTE)
    return f"<style>:root {{ --hf-bg-top: {top}; --hf-bg-bottom: {bottom}; }}</style>"
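# Emits e.g. "<style>:root { --hf-bg-top: #0f172a; --hf-bg-bottom: #020617; }</style>",
# which the radial gradient in app_css picks up via var(--hf-bg-top) / var(--hf-bg-bottom).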
# ---------- Theming + helpers ----------
def infer_theme(query: str):
    q = (query or "").lower()
    if any(w in q for w in ["night", "drive", "highway", "city", "neon"]):
        return {"name": "Midnight Drive", "emoji": "🌃"}
    if any(w in q for w in ["party", "dance", "club", "crowd", "festival"]):
        return {"name": "Nightclub Neon", "emoji": "🎉"}
    if any(w in q for w in ["shower", "bathroom", "mirror", "getting ready"]):
        return {"name": "Mirror Concert", "emoji": "🚿"}
    if any(w in q for w in ["dog", "pet", "cat", "bloopers"]):
        return {"name": "Pet Bloopers", "emoji": "🐶"}
    # default
    return {"name": "", "emoji": "🎧"}
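# Note: only the "emoji" field is consumed by results_to_lux_html below; "name" is currently unused.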
# ---------- DataFrame -> HTML ----------
def results_to_lux_html(results: pd.DataFrame, query: str) -> str:
    if results is None or results.empty:
        return """
        <div id="lux-wrapper">
          <div id="lux-header">
            <div class="lux-subline">HarmoniFind • Semantic playlist</div>
            <h1>🎧 Describe a vibe to start</h1>
            <p style="font-size:0.9rem;color:rgba(156,163,175,0.95);margin-top:8px;">
              Type a brief above, or click <strong>🎲</strong> for a fun prompt.
            </p>
          </div>
        </div>
        """
    theme = infer_theme(query)
    query_safe = html_lib.escape(query or "")
    emoji = theme["emoji"]
    cards_html = ""
    tracks_plain = []
    for _, row in results.iterrows():
        raw_artist = str(row.get("artist", ""))
        raw_song = str(row.get("song", ""))
        artist = html_lib.escape(raw_artist)
        song = html_lib.escape(raw_song)
        # for the clipboard list
        tracks_plain.append(f"{raw_song} - {raw_artist}")
        is_random = bool(row.get("is_random", False))
        sim_pct = row.get("similarity_pct", None)
        if pd.isna(sim_pct) or is_random:
            sim_display = "—"
            score_bg = "rgba(148,163,184,0.2)"
            vibes = "pure random"
        else:
            sim_display = f"{float(sim_pct):.1f}%"
            score_bg = "rgba(34,197,94,0.14)"
            vibes = html_lib.escape(str(row.get("vibes", "")))
        url = row.get("spotify_url", None)
        img = row.get("album_image", None)
        if isinstance(img, str) and img:
            cover = f'<div class="lux-cover"><img src="{html_lib.escape(img)}"></div>'
        else:
            cover = '<div class="lux-cover">♪</div>'
        if isinstance(url, str) and url:
            play_btn = f'<a class="lux-play-btn" href="{html_lib.escape(url)}" target="_blank">▶︎ Play on Spotify</a>'
        else:
            play_btn = ""
        random_chip = ""
        if is_random:
            random_chip = '<span class="lux-chip">🎲 random pick</span>'
        cards_html += f"""
        <div class="lux-card">
          {cover}
          <div class="lux-main">
            <div class="lux-title-row">
              <div>
                <div class="lux-title">{song}</div>
                <div class="lux-artist">{artist}</div>
              </div>
              <div class="lux-score">
                <div class="lux-score-badge" style="background:{score_bg};">{sim_display}</div>
                <div class="lux-vibes">{vibes}</div>
              </div>
            </div>
            <div class="lux-bottom-row">
              {play_btn}
              {random_chip}
            </div>
          </div>
        </div>
        """
    # Build the track-list text for the clipboard
    header_line = f"HarmoniFind results for: {query or ''}".strip()
    if not header_line:
        header_line = "HarmoniFind results"
    list_text = header_line + "\n\n" + "\n".join(tracks_plain)
    # Escape for use inside a single-quoted JS string within a double-quoted HTML attribute
    js_text = (
        list_text
        .replace("\\", "\\\\")
        .replace("'", "\\'")
        .replace('"', "&quot;")
        .replace("\n", "\\n")
    )
    meta_html = f"""
    <p>Semantic matches first, plus optional 🎲 discovery if you enabled it.</p>
    <div class="lux-meta">
      <span class="lux-badge">Tracks: {len(results)}</span>
      <a class="lux-pill" href="javascript:void(0);" onclick="navigator.clipboard.writeText('{js_text}');">
        🔗 Copy Your HarmoniFinds
      </a>
    </div>
    """
    html = f"""
    <div id="lux-wrapper">
      <div id="lux-header">
        <div class="lux-subline">HarmoniFind • Semantic playlist</div>
        <h1>{emoji} {query_safe or "Untitled vibe"}</h1>
        {meta_html}
      </div>
      <div class="lux-playlist-wrapper">
        {cards_html}
      </div>
    </div>
    """
    return html
# ---------- Search + bg wrapper ----------
def core_search_html(query, k, random_extra):
    # LLaMA expansion always ON now
    results = search_pipeline(
        query=query or "",
        k=int(k),
        random_extra=int(random_extra),
        use_llama=True,
    )
    return results_to_lux_html(results, query or "")
def search_with_bg(query, k, random_extra):
    """Return playlist HTML + a new background style snippet."""
    playlist_html = core_search_html(query, k, random_extra)
    bg_style_html = make_bg_style_html()
    return playlist_html, bg_style_html
def surprise_brief():
    return get_random_vibe()
def clear_all():
    # reset query, results (empty state), and bg
    empty_html = results_to_lux_html(None, "")
    return "", empty_html, make_bg_style_html()
# ---------- Gradio UI ----------
with gr.Blocks(title="HarmoniFind") as demo:
    # Inject CSS manually (HF Gradio version may not support css=... kwarg)
    gr.HTML(f"<style>{app_css}</style>")
    # dynamic bg style holder (updated on each search)
    bg_style = gr.HTML(make_bg_style_html())
    # Header
    gr.HTML("""
    <div id="hf-shell">
      <div id="lux-header">
        <div class="lux-subline">HARMONIFIND • LYRICS-DRIVEN SEMANTIC SEARCH</div>
        <h1>Describe Your Song.</h1>
        <p>We search by what the lyrics <strong>mean</strong>, not just titles or genres.</p>
      </div>
    </div>
    """)
    with gr.Column():
        # Textbox + stacked buttons on the right
        with gr.Row(variant="compact"):
            input_box = gr.Textbox(
                placeholder="Lyrics about a carefree road trip with too many snack stops",
                show_label=False,
                lines=3,
                scale=5,
            )
            with gr.Column(scale=2, min_width=160):
                search_btn = gr.Button("Search", elem_classes=["primary-btn"])
                surprise_btn = gr.Button("🎲 Surprise me", elem_classes=["secondary-btn"])
                clear_btn = gr.Button("Clear", elem_classes=["secondary-btn"])
        # Sliders only (LLaMA is always on; no checkbox)
        with gr.Accordion("Search settings", open=False):
            with gr.Row():
                k_slider = gr.Slider(5, 50, value=10, step=1, label="# semantic matches")
                rand_slider = gr.Slider(0, 10, value=2, step=1, label="# extra random tracks")
        output_html = gr.HTML()
    # Search updates playlist + bg
    input_box.submit(
        search_with_bg,
        [input_box, k_slider, rand_slider],
        [output_html, bg_style],
    )
    search_btn.click(
        search_with_bg,
        [input_box, k_slider, rand_slider],
        [output_html, bg_style],
    )
    # Surprise: fill box, then search + bg
    surprise_btn.click(
        surprise_brief,
        outputs=input_box,
    ).then(
        search_with_bg,
        [input_box, k_slider, rand_slider],
        [output_html, bg_style],
    )
    # Clear
    clear_btn.click(
        clear_all,
        None,
        [input_box, output_html, bg_style],
    )
if __name__ == "__main__":
    port = int(os.getenv("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port)