TutuAwad committed (verified)
Commit 4f22c4e · Parent(s): 0bde887

Update app.py

Files changed (1): app.py (+89 -20)
app.py CHANGED
@@ -18,6 +18,9 @@ from sentence_transformers import SentenceTransformer
 from huggingface_hub import InferenceClient
 import spotipy
 from spotipy.oauth2 import SpotifyClientCredentials
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
+
 
 # ---------- Paths to precomputed data ----------
 
@@ -49,16 +52,60 @@ print("Spotify secret present?", bool(SPOTIFY_CLIENT_SECRET))
 # Query encoder (same as notebook)
 query_embedder = SentenceTransformer("all-mpnet-base-v2")
 
-# LLaMA-2 for query expansion (remote HF Inference)
+# LLaMA-2 for query expansion
 LLAMA_MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"
 
-hf_client = None
+llama_pipe = None  # local quantized pipeline (preferred)
+hf_client = None   # hosted fallback
+
 if HF_TOKEN:
-    try:
-        hf_client = InferenceClient(token=HF_TOKEN)
-    except Exception as e:
-        print("⚠️ Could not initialize HF Inference client:", repr(e))
-        hf_client = None
+    # Try to load a 4-bit quantized LLaMA locally (for HF Space with GPU)
+    if torch.cuda.is_available():
+        try:
+            print("Loading LLaMA-2-7B in 4-bit NF4 with bitsandbytes...")
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_compute_dtype=torch.bfloat16,
+            )
+
+            llama_tokenizer = AutoTokenizer.from_pretrained(
+                LLAMA_MODEL_ID,
+                use_auth_token=HF_TOKEN,
+            )
+
+            llama_model = AutoModelForCausalLM.from_pretrained(
+                LLAMA_MODEL_ID,
+                quantization_config=bnb_config,  # 🔑 this actually activates 4-bit
+                device_map="auto",
+                torch_dtype=torch.bfloat16,
+                use_auth_token=HF_TOKEN,
+            )
+
+            llama_pipe = pipeline(
+                "text-generation",
+                model=llama_model,
+                tokenizer=llama_tokenizer,
+                max_new_tokens=96,
+                temperature=0.2,
+                top_p=0.9,
+                repetition_penalty=1.05,
+            )
+            print("Using local 4-bit quantized LLaMA backend.")
+        except Exception as e:
+            print("⚠️ Quantized LLaMA load failed, will try HF Inference fallback:", repr(e))
+
+    # If quantized local load failed (or no CUDA), fall back to HF hosted inference
+    if llama_pipe is None:
+        try:
+            hf_client = InferenceClient(model=LLAMA_MODEL_ID, token=HF_TOKEN)
+            print("✅ Using HF InferenceClient backend (hosted LLaMA).")
+        except Exception as e:
+            print("⚠️ Could not initialize any LLaMA backend:", repr(e))
+else:
+    print("⚠️ No HF_TOKEN found; LLaMA expansion will be disabled.")
+
 
 # Spotify client
 sp = None
@@ -82,12 +129,14 @@ def encode_query(text: str) -> np.ndarray:
 
 def expand_with_llama(query: str) -> str:
     """
-    Enrich the query using LLaMA via HF Inference.
+    Enrich the query using LLaMA.
 
-    If anything fails (no client, provider issues, rate limits, etc.),
-    we log and fall back to the raw query so the app keeps working.
+    Priority:
+      1) Use local 4-bit quantized LLaMA pipeline if available (HF Space with GPU).
+      2) Otherwise, fall back to HF InferenceClient (hosted model).
+      3) On any failure, return the raw query so the app keeps working.
     """
-    if hf_client is None or not HF_TOKEN:
+    if not HF_TOKEN:
         return query
 
     prompt = f"""You are helping someone search a lyrics catalog.
@@ -104,22 +153,42 @@ Input:
 Output (no explanation, just titles or keywords):"""
 
     try:
-        response = hf_client.text_generation(
-            prompt,
-            model=LLAMA_MODEL_ID,
-            max_new_tokens=96,
-            temperature=0.2,
-            repetition_penalty=1.05,
-        )
+        if llama_pipe is not None:
+            # Local 4-bit quantized model on HF Space
+            outputs = llama_pipe(
+                prompt,
+                do_sample=True,
+                num_return_sequences=1,
+            )
+            full_text = outputs[0]["generated_text"]
+            # Strip the prompt off the front if it's included
+            if full_text.startswith(prompt):
+                keywords = full_text[len(prompt):].strip()
+            else:
+                keywords = full_text.strip()
+        elif hf_client is not None:
+            # Hosted HF Inference fallback
+            response = hf_client.text_generation(
+                prompt,
+                max_new_tokens=96,
+                temperature=0.2,
+                repetition_penalty=1.05,
+            )
+            keywords = str(response).strip()
+        else:
+            # No backend at all
+            return query
+
     except Exception as e:
-        print("⚠️ LLaMA expansion failed on HF, using raw query:", repr(e))
+        print("⚠️ LLaMA expansion failed, using raw query:", repr(e))
         return query
 
-    keywords = str(response).strip().replace("\n", " ")
+    keywords = keywords.replace("\n", " ")
    expanded = query + " " + keywords
     return expanded
 
 
+
 def distances_to_similarity_pct(dists: np.ndarray) -> np.ndarray:
     if len(dists) == 0:
         return np.array([])
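Reviewer's note on the new 4-bit path: NF4 with double quantization stores the 7B weights at roughly 4 bits each, so the model fits in about 4 GB of GPU memory instead of the ~13.5 GB an fp16 load needs. A minimal sketch for sanity-checking the quantized load outside the Space follows; the model ID and BitsAndBytesConfig come from the diff, `token=` is the non-deprecated spelling of the `use_auth_token=` used above, and the memory numbers in the comments are approximate.

# Hypothetical smoke test for the 4-bit NF4 load (not part of this commit).
# Assumes a CUDA GPU, bitsandbytes installed, and an HF token with access
# to the gated meta-llama/Llama-2-7b-chat-hf repo.
import os
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NormalFloat4 weight format
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls run in bf16
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    quantization_config=bnb_config,
    device_map="auto",
    token=os.environ.get("HF_TOKEN"),
)

# Roughly 4 GB in NF4 versus ~13.5 GB in fp16 for a 7B model.
print(f"Memory footprint: {model.get_memory_footprint() / 1e9:.1f} GB")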
 
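The control flow in expand_with_llama now guarantees graceful degradation: local pipeline first, hosted client second, raw query on any failure. That contract can be exercised without loading a model; the sketch below uses hypothetical stand-ins (expand, FakePipe, and BrokenPipe are illustrative, not code from this commit).

# Illustrative sketch of the fallback contract in expand_with_llama:
# any backend failure must yield the original query unchanged.
def expand(query: str, pipe=None, client=None) -> str:
    try:
        if pipe is not None:
            full_text = pipe(query)[0]["generated_text"]
            # Strip the prompt off the front if it's included
            if full_text.startswith(query):
                keywords = full_text[len(query):].strip()
            else:
                keywords = full_text.strip()
        elif client is not None:
            keywords = str(client(query)).strip()
        else:
            return query
    except Exception:
        return query
    return query + " " + keywords.replace("\n", " ")

class FakePipe:
    def __call__(self, prompt, **kwargs):
        return [{"generated_text": prompt + "\nrainy day, melancholy ballads"}]

class BrokenPipe:
    def __call__(self, prompt, **kwargs):
        raise RuntimeError("CUDA out of memory")

assert expand("sad songs", pipe=FakePipe()) == "sad songs rainy day, melancholy ballads"
assert expand("sad songs", pipe=BrokenPipe()) == "sad songs"  # failure falls back to raw query
assert expand("sad songs") == "sad songs"                     # no backend at all

One simplification the committed code could adopt later: the transformers text-generation pipeline accepts return_full_text=False, which drops the prompt from generated_text and would make the manual startswith stripping unnecessary.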