viswadarshan06 committed on
Commit
1ba33db
·
verified ·
1 Parent(s): 6e1dfb9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -41
app.py CHANGED
@@ -4,36 +4,18 @@ from pydantic import BaseModel
4
  import faiss
5
  import pickle
6
  import numpy as np
7
- from transformers import AutoTokenizer, AutoModel
8
  import torch
9
 
10
- # Set cache to custom dir to avoid /.cache issues
11
  os.environ["HF_HOME"] = "/app/hf_cache"
12
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
13
 
14
- # Get Hugging Face token from environment (must be set in HF Space secrets)
15
- hf_token = os.getenv("HF_TOKEN")
16
-
17
- # Initialize FastAPI
18
  app = FastAPI()
19
 
20
- import shutil
21
-
22
- model_cache_path = "/app/hf_cache/models--viswadarshan06--paraphrase-multilingual-MiniLM-L12-v2-local"
23
- if os.path.exists(model_cache_path):
24
- shutil.rmtree(model_cache_path)
25
-
26
- # ✅ Load your private model
27
- tokenizer = AutoTokenizer.from_pretrained(
28
- "viswadarshan06/paraphrase-multilingual-MiniLM-L12-v2-local",
29
- cache_dir="/app/hf_cache",
30
- token=hf_token
31
- )
32
- model = AutoModel.from_pretrained(
33
- "viswadarshan06/paraphrase-multilingual-MiniLM-L12-v2-local",
34
- cache_dir="/app/hf_cache",
35
- token=hf_token
36
- )
37
 
38
  # Load Thirukkural data
39
  with open("thirukkural_data.pkl", "rb") as f:
@@ -43,35 +25,23 @@ with open("thirukkural_data.pkl", "rb") as f:
43
  english_index = faiss.read_index("thirukkural_english_index.faiss")
44
  tamil_index = faiss.read_index("thirukkural_tamil_index.faiss")
45
 
46
- # Request schema
47
  class QueryRequest(BaseModel):
48
  query: str
49
  lang: str # "en" or "ta"
50
  top_k: int = 3
51
 
52
- # Mean pooling (same as SentenceTransformer style)
53
- def mean_pooling(model_output, attention_mask):
54
- token_embeddings = model_output[0]
55
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
56
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
57
-
58
- # POST endpoint
59
  @app.post("/search/")
60
  def search_kural(req: QueryRequest):
61
- encoded_input = tokenizer(req.query, padding=True, truncation=True, return_tensors="pt")
62
-
63
- with torch.no_grad():
64
- model_output = model(**encoded_input)
65
- query_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
66
-
67
- query_embedding = query_embedding.detach().cpu().numpy()
68
  index = tamil_index if req.lang == "ta" else english_index
69
- D, I = index.search(query_embedding.astype("float32"), req.top_k)
70
 
71
  results = [kural_data[i] for i in I[0]]
72
  return {"results": results}
73
 
74
- # Health check
75
  @app.get("/")
76
  def root():
77
- return {"message": "Thirukkural FastAPI RAG is running with private model."}
 
4
  import faiss
5
  import pickle
6
  import numpy as np
7
+ from sentence_transformers import SentenceTransformer
8
  import torch
9
 
10
+ # Optional: set dummy cache paths if needed by other internal libs
11
  os.environ["HF_HOME"] = "/app/hf_cache"
12
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
13
 
14
+ # Initialize FastAPI app
 
 
 
15
  app = FastAPI()
16
 
17
+ # ✅ Load your locally uploaded SentenceTransformer model
18
+ model = SentenceTransformer("/app/model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # Load Thirukkural data
21
  with open("thirukkural_data.pkl", "rb") as f:
 
25
  english_index = faiss.read_index("thirukkural_english_index.faiss")
26
  tamil_index = faiss.read_index("thirukkural_tamil_index.faiss")
27
 
28
+ # Define request schema
29
  class QueryRequest(BaseModel):
30
  query: str
31
  lang: str # "en" or "ta"
32
  top_k: int = 3
33
 
34
+ # Search endpoint
 
 
 
 
 
 
35
  @app.post("/search/")
36
  def search_kural(req: QueryRequest):
37
+ query_embedding = model.encode([req.query])
 
 
 
 
 
 
38
  index = tamil_index if req.lang == "ta" else english_index
39
+ D, I = index.search(np.array(query_embedding).astype("float32"), req.top_k)
40
 
41
  results = [kural_data[i] for i in I[0]]
42
  return {"results": results}
43
 
44
+ # Health check endpoint
45
  @app.get("/")
46
  def root():
47
+ return {"message": "Thirukkural FastAPI RAG is running with local model."}