Update app.py

app.py CHANGED
@@ -18,6 +18,9 @@ from sentence_transformers import SentenceTransformer
 from huggingface_hub import InferenceClient
 import spotipy
 from spotipy.oauth2 import SpotifyClientCredentials
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
+

 # ---------- Paths to precomputed data ----------


@@ -49,16 +52,60 @@ print("Spotify secret present?", bool(SPOTIFY_CLIENT_SECRET))
 # Query encoder (same as notebook)
 query_embedder = SentenceTransformer("all-mpnet-base-v2")

+# LLaMA-2 for query expansion
 LLAMA_MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"

+llama_pipe = None  # local quantized pipeline (preferred)
+hf_client = None   # hosted fallback
+
 if HF_TOKEN:
+    # Try to load a 4-bit quantized LLaMA locally (for HF Space with GPU)
+    if torch.cuda.is_available():
+        try:
+            print(" Loading LLaMA-2-7B in 4-bit NF4 with bitsandbytes...")
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
+
+            llama_tokenizer = AutoTokenizer.from_pretrained(
+                LLAMA_MODEL_ID,
+                use_auth_token=HF_TOKEN,
+            )
+
+            llama_model = AutoModelForCausalLM.from_pretrained(
+                LLAMA_MODEL_ID,
+                quantization_config=bnb_config,  # 🔑 this actually activates 4-bit
+                device_map="auto",
+                torch_dtype=torch.bfloat16,
+                use_auth_token=HF_TOKEN,
+            )
+
+            llama_pipe = pipeline(
+                "text-generation",
+                model=llama_model,
+                tokenizer=llama_tokenizer,
+                max_new_tokens=96,
+                temperature=0.2,
+                top_p=0.9,
+                repetition_penalty=1.05,
+            )
+            print(" Using local 4-bit quantized LLaMA backend.")
+        except Exception as e:
+            print("⚠️ Quantized LLaMA load failed, will try HF Inference fallback:", repr(e))
+
+    # If quantized local load failed (or no CUDA), fall back to HF hosted inference
+    if llama_pipe is None:
+        try:
+            hf_client = InferenceClient(model=LLAMA_MODEL_ID, token=HF_TOKEN)
+            print("✅ Using HF InferenceClient backend (hosted LLaMA).")
+        except Exception as e:
+            print("⚠️ Could not initialize any LLaMA backend:", repr(e))
+else:
+    print("⚠️ No HF_TOKEN found; LLaMA expansion will be disabled.")
+

 # Spotify client
 sp = None

@@ -82,12 +129,14 @@ def encode_query(text: str) -> np.ndarray:

 def expand_with_llama(query: str) -> str:
     """
+    Enrich the query using LLaMA.

+    Priority:
+    1) Use local 4-bit quantized LLaMA pipeline if available (HF Space with GPU).
+    2) Otherwise, fall back to HF InferenceClient (hosted model).
+    3) On any failure, return the raw query so the app keeps working.
     """
+    if not HF_TOKEN:
         return query

     prompt = f"""You are helping someone search a lyrics catalog.

@@ -104,22 +153,42 @@ Input:
 Output (no explanation, just titles or keywords):"""

     try:
+        if llama_pipe is not None:
+            # Local 4-bit quantized model on HF Space
+            outputs = llama_pipe(
+                prompt,
+                do_sample=True,
+                num_return_sequences=1,
+            )
+            full_text = outputs[0]["generated_text"]
+            # Strip the prompt off the front if it's included
+            if full_text.startswith(prompt):
+                keywords = full_text[len(prompt):].strip()
+            else:
+                keywords = full_text.strip()
+        elif hf_client is not None:
+            # Hosted HF Inference fallback
+            response = hf_client.text_generation(
+                prompt,
+                max_new_tokens=96,
+                temperature=0.2,
+                repetition_penalty=1.05,
+            )
+            keywords = str(response).strip()
+        else:
+            # No backend at all
+            return query
+
     except Exception as e:
+        print("⚠️ LLaMA expansion failed, using raw query:", repr(e))
         return query

+    keywords = keywords.replace("\n", " ")
     expanded = query + " " + keywords
     return expanded


+
 def distances_to_similarity_pct(dists: np.ndarray) -> np.ndarray:
     if len(dists) == 0:
         return np.array([])
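
A quick way to sanity-check the new fallback chain is a smoke test along these lines. This is an illustrative sketch, not part of the commit: it assumes the Space's app.py is importable as a module named app, and that HF_TOKEN (if any) is already set in the environment.

    # smoke_test.py (hypothetical helper, not in the commit).
    # Importing app runs the backend setup: local 4-bit pipeline,
    # hosted InferenceClient, or neither. Then exercise expansion.
    import app

    query = "songs about driving at night"
    expanded = app.expand_with_llama(query)

    # Every path in expand_with_llama() returns either the raw query
    # or query + " " + keywords, so the original text is always a prefix.
    assert expanded.startswith(query)

    if app.llama_pipe is not None:
        backend = "local 4-bit"
    elif app.hf_client is not None:
        backend = "hosted"
    else:
        backend = "disabled"
    print("backend:", backend)
    print("expanded:", expanded)

Two details worth keeping in mind when reusing this pattern: NF4 with double quantization brings the 7B weights down to roughly 4 GB, which is what makes the local path viable on a single 16 GB Space GPU, and recent transformers releases deprecate the use_auth_token= argument in favor of token=, so the from_pretrained calls above may emit a deprecation warning depending on the pinned version.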