viswadarshan06 commited on
Commit
edcd131
·
verified ·
1 Parent(s): 243baad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -22
app.py CHANGED
@@ -7,16 +7,27 @@ import numpy as np
7
  from transformers import AutoTokenizer, AutoModel
8
  import torch
9
 
10
- # Set Hugging Face cache dir to avoid /.cache error
11
  os.environ["HF_HOME"] = "/app/hf_cache"
12
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
13
 
14
- # Initialize FastAPI app
 
 
 
15
  app = FastAPI()
16
 
17
- # Load multilingual model and tokenizer
18
- tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", cache_dir="/app/hf_cache")
19
- model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", cache_dir="/app/hf_cache")
 
 
 
 
 
 
 
 
20
 
21
  # Load Thirukkural data
22
  with open("thirukkural_data.pkl", "rb") as f:
@@ -26,41 +37,35 @@ with open("thirukkural_data.pkl", "rb") as f:
26
  english_index = faiss.read_index("thirukkural_english_index.faiss")
27
  tamil_index = faiss.read_index("thirukkural_tamil_index.faiss")
28
 
29
- # Input model for API
30
  class QueryRequest(BaseModel):
31
- query: str # user input
32
- lang: str # "en" or "ta"
33
- top_k: int = 3 # number of kurals to return
34
 
35
- # Mean pooling function
36
  def mean_pooling(model_output, attention_mask):
37
- token_embeddings = model_output[0] # first element: token embeddings
38
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
39
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
40
 
41
- # POST endpoint for retrieving relevant Thirukkural(s)
42
  @app.post("/search/")
43
  def search_kural(req: QueryRequest):
44
- # Tokenize input
45
  encoded_input = tokenizer(req.query, padding=True, truncation=True, return_tensors="pt")
46
-
47
- # Compute embeddings
48
  with torch.no_grad():
49
  model_output = model(**encoded_input)
50
  query_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
51
-
52
- # Convert to numpy
53
  query_embedding = query_embedding.detach().cpu().numpy()
54
-
55
- # Choose FAISS index
56
  index = tamil_index if req.lang == "ta" else english_index
57
  D, I = index.search(query_embedding.astype("float32"), req.top_k)
58
 
59
- # Return top-k matching kurals
60
  results = [kural_data[i] for i in I[0]]
61
  return {"results": results}
62
 
63
- # Health check endpoint
64
  @app.get("/")
65
  def root():
66
- return {"message": "Thirukkural FastAPI RAG is running."}
 
7
  from transformers import AutoTokenizer, AutoModel
8
  import torch
9
 
10
+ # Set cache to custom dir to avoid /.cache issues
11
  os.environ["HF_HOME"] = "/app/hf_cache"
12
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
13
 
14
+ # Get Hugging Face token from environment (must be set in HF Space secrets)
15
+ hf_token = os.getenv("HF_TOKEN")
16
+
17
+ # Initialize FastAPI
18
  app = FastAPI()
19
 
20
+ # Load your private model
21
+ tokenizer = AutoTokenizer.from_pretrained(
22
+ "viswadarshan06/paraphrase-multilingual-MiniLM-L12-v2-local",
23
+ cache_dir="/app/hf_cache",
24
+ token=hf_token
25
+ )
26
+ model = AutoModel.from_pretrained(
27
+ "viswadarshan06/paraphrase-multilingual-MiniLM-L12-v2-local",
28
+ cache_dir="/app/hf_cache",
29
+ token=hf_token
30
+ )
31
 
32
  # Load Thirukkural data
33
  with open("thirukkural_data.pkl", "rb") as f:
 
37
  english_index = faiss.read_index("thirukkural_english_index.faiss")
38
  tamil_index = faiss.read_index("thirukkural_tamil_index.faiss")
39
 
40
+ # Request schema
41
  class QueryRequest(BaseModel):
42
+ query: str
43
+ lang: str # "en" or "ta"
44
+ top_k: int = 3
45
 
46
+ # Mean pooling (same as SentenceTransformer style)
47
  def mean_pooling(model_output, attention_mask):
48
+ token_embeddings = model_output[0]
49
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
50
  return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
51
 
52
+ # POST endpoint
53
  @app.post("/search/")
54
  def search_kural(req: QueryRequest):
 
55
  encoded_input = tokenizer(req.query, padding=True, truncation=True, return_tensors="pt")
56
+
 
57
  with torch.no_grad():
58
  model_output = model(**encoded_input)
59
  query_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
60
+
 
61
  query_embedding = query_embedding.detach().cpu().numpy()
 
 
62
  index = tamil_index if req.lang == "ta" else english_index
63
  D, I = index.search(query_embedding.astype("float32"), req.top_k)
64
 
 
65
  results = [kural_data[i] for i in I[0]]
66
  return {"results": results}
67
 
68
+ # Health check
69
  @app.get("/")
70
  def root():
71
+ return {"message": "Thirukkural FastAPI RAG is running with private model."}