Commit 00462bb · Parent(s): d57ffbb
final
app.py
CHANGED
@@ -3,6 +3,7 @@ import logging
 from typing import Optional, List
 import asyncio
 from contextlib import asynccontextmanager
+import json

 import pymongo
 from fastapi import FastAPI, HTTPException
@@ -17,9 +18,7 @@ from langchain_core.messages import (
 )
 from langchain_core.prompts import PromptTemplate
 from langchain_groq import ChatGroq
-from langchain_core.output_parsers import
-from langchain_core.chat_history import BaseChatMessageHistory
-from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_core.output_parsers import StrOutputParser

 # -------------------------------------------------------------------
 # CONFIGURATION
@@ -46,7 +45,7 @@ history_collection = None
 # -------------------------------------------------------------------
 # CHAT HISTORY HANDLER (MongoDB)
 # -------------------------------------------------------------------
-class MongoDBChatMessageHistory(BaseChatMessageHistory):
+class MongoDBChatMessageHistory:
     """Chat message history stored in MongoDB."""

     def __init__(self, collection, session_id: str):
@@ -55,19 +54,18 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
         self.collection.create_index("session_id")

     @property
-    def messages(self) -> List[
+    def messages(self) -> List[dict]:
         """Retrieve messages from MongoDB"""
         document = self.collection.find_one({"session_id": self.session_id})
         if document and "history" in document:
-            return
+            return document["history"]
         return []

-    def add_message(self, message:
+    def add_message(self, message: dict) -> None:
         """Append the message to MongoDB history"""
-        message_dict = message_to_dict(message)
         self.collection.update_one(
             {"session_id": self.session_id},
-            {"$push": {"history":
+            {"$push": {"history": message}},
             upsert=True,
         )

@@ -86,9 +84,9 @@ async def lifespan(app: FastAPI):
     # Initialize LLM
     llm = ChatGroq(
         groq_api_key=GROQ_API_KEY,
-        model_name="
-        temperature=0,
-        max_tokens=
+        model_name="mixtral-8x7b-32768",
+        temperature=0.3,
+        max_tokens=800
     )
     logger.info("ChatGroq LLM initialized successfully")

@@ -112,8 +110,8 @@ async def lifespan(app: FastAPI):
 # -------------------------------------------------------------------
 app = FastAPI(
     title="India Legal Consultation Service",
-    description="A bilingual
-    version="
+    description="A bilingual legal consultation API powered by Groq and MongoDB",
+    version="3.0.0",
     lifespan=lifespan
 )

@@ -129,19 +127,13 @@ app.add_middleware(
 # -------------------------------------------------------------------
 # PYDANTIC MODELS
 # -------------------------------------------------------------------
-class ConsultationResponse(BaseModel):
-    consultation: str
-    key_terms: Optional[str] = "general"
-
-consultation_parser = PydanticOutputParser(pydantic_object=ConsultationResponse)
-
 class QueryRequest(BaseModel):
     query: str
     user_id: str

-class
+class ConsultationResponse(BaseModel):
     response: str
-    history: List[
+    history: List[dict] = Field(default_factory=list)

 class HistoryResponse(BaseModel):
     history: List[dict] = Field(default_factory=list)
@@ -152,60 +144,67 @@ class DeleteResponse(BaseModel):


 # -------------------------------------------------------------------
-# LEGAL CONSULTATION PROMPT (FIXED - NO
+# LEGAL CONSULTATION PROMPT (FIXED - NO JSON PARSING ISSUES)
 # -------------------------------------------------------------------
 consultation_prompt = PromptTemplate(
     input_variables=["query", "previous_messages"],
-    template="""You are an expert Indian legal advisor
-    partial_variables={"format_instructions": consultation_parser.get_format_instructions()},
-)
-
-# FUNCTIONS TO BUILD CHAINS
-# -------------------------------------------------------------------
-def get_consultation_chain():
-    if llm is None:
-        raise HTTPException(status_code=503, detail="LLM not initialized yet")
-    return consultation_prompt | llm | StrOutputParser()
-
-
-
-
-
-
-
-
-
-)
+    template="""You are an expert Indian legal advisor with deep knowledge of Indian law. Analyze the following query and provide guidance based on Indian laws (IPC, CrPC, Indian Evidence Act, Constitution, Family Law, Contract Law, IT Act, etc.).
+
+User Query: {query}
+
+Previous Conversation: {previous_messages}
+
+Instructions:
+1. Provide a clear, practical explanation of the legal issue under Indian law
+2. Include relevant law sections or provisions where applicable
+3. Explain in both Hindi and English using short readable paragraphs
+4. Keep tone professional, supportive, and educational
+5. Provide guidance and awareness, not direct legal verdicts
+6. Identify the main legal area involved (e.g., murder, divorce, contract, cybercrime, tenant rights, property dispute)
+
+Response format: Start with Hindi explanation, then English explanation. End with the legal area identified."""
+)

 # -------------------------------------------------------------------
 # API ROUTES
 # -------------------------------------------------------------------
-@app.post("/legal-consultation", response_model=
+@app.post("/legal-consultation", response_model=ConsultationResponse)
 def legal_consultation(request: QueryRequest):
     if llm is None or history_collection is None:
         raise HTTPException(status_code=503, detail="Service not ready. Try again later.")

     try:
-        consultation_runnable = get_consultation_runnable()
-
-        result = consultation_runnable.invoke(
-            {"query": request.query},
-            {"configurable": {"session_id": request.user_id}}
-        )
-
         history_obj = MongoDBChatMessageHistory(history_collection, request.user_id)
-
-
-
+        messages = history_obj.messages
+
+        # Format previous messages
+        previous_context = ""
+        if messages:
+            previous_context = " ".join([f"{msg.get('role', 'user')}: {msg.get('content', '')}" for msg in messages[-5:]])
+
+        consultation_chain = consultation_prompt | llm | StrOutputParser()
+
+        result = consultation_chain.invoke({
+            "query": request.query,
+            "previous_messages": previous_context
+        })
+
+        user_msg = {"role": "user", "content": request.query}
+        ai_msg = {"role": "assistant", "content": result}
+
+        history_obj.add_message(user_msg)
+        history_obj.add_message(ai_msg)
+
+        # Return updated history
+        updated_messages = history_obj.messages
+
+        return ConsultationResponse(
             response=result,
-            history=
+            history=updated_messages
         )
     except Exception as e:
         logger.error(f"Consultation failed: {e}")
-        raise HTTPException(status_code=500, detail=f"Consultation failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Consultation failed: {str(e)}")


 @app.get("/get-session-history/{session_id}", response_model=HistoryResponse)
@@ -215,11 +214,10 @@ async def get_session_history(session_id: str):
     try:
         history_obj = MongoDBChatMessageHistory(history_collection, session_id)
         messages = history_obj.messages
-
-        return HistoryResponse(history=history_dicts)
+        return HistoryResponse(history=messages)
     except Exception as e:
         logger.error(f"Failed to get history for {session_id}: {e}")
-        raise HTTPException(status_code=500, detail=f"Failed to retrieve history: {e}")
+        raise HTTPException(status_code=500, detail=f"Failed to retrieve history: {str(e)}")


 @app.delete("/delete-session/{session_id}", response_model=DeleteResponse)
@@ -233,12 +231,16 @@ async def delete_session(session_id: str):
         return DeleteResponse(status="deleted", session_id=session_id)
     except Exception as e:
         logger.error(f"Failed to delete session {session_id}: {e}")
-        raise HTTPException(status_code=500, detail=f"Failed to delete session: {e}")
+        raise HTTPException(status_code=500, detail=f"Failed to delete session: {str(e)}")


 @app.get("/health")
 async def health_check():
-    return {
+    return {
+        "status": "healthy",
+        "service": "India Legal Consultation API",
+        "version": "3.0.0"
+    }


 # -------------------------------------------------------------------
@@ -246,7 +248,7 @@ async def health_check():
 # -------------------------------------------------------------------
 if __name__ == "__main__":
     uvicorn.run(
-        "
+        "legal_consultation_api:app",
         host="0.0.0.0",
         port=int(os.getenv("PORT", 7860)),
         reload=False
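For reference, a minimal client-side sketch of calling the updated /legal-consultation endpoint is shown below. It assumes the service is reachable on localhost at the default port 7860 taken from the uvicorn settings in this diff; the requests dependency, the example query, and the user_id value are illustrative and not part of the commit.

import requests  # assumed client-side dependency, not part of this commit

# Assumed base URL: uvicorn binds to 0.0.0.0 on PORT (default 7860) per the diff above.
BASE_URL = "http://localhost:7860"

payload = {
    "query": "What is the notice period for ending a rental agreement?",  # hypothetical example query
    "user_id": "demo-user-1",  # hypothetical session/user id
}

resp = requests.post(f"{BASE_URL}/legal-consultation", json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()

print(data["response"])      # bilingual consultation text produced by the LLM
print(len(data["history"]))  # running session history stored in MongoDB

# The same session can then be inspected or cleared via the other routes in the diff:
# requests.get(f"{BASE_URL}/get-session-history/demo-user-1")
# requests.delete(f"{BASE_URL}/delete-session/demo-user-1")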
test.py
ADDED
@@ -0,0 +1,17 @@
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
+import torch
+
+tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3.2-1b-bnb-4bit")
+
+base_model = AutoModelForCausalLM.from_pretrained(
+    "unsloth/llama-3.2-1b-bnb-4bit",
+    device_map="auto",
+    torch_dtype=torch.float16,
+)
+
+model = PeftModel.from_pretrained(base_model, "MeWan2808/SIT_legalTech_llama3.2")
+model = model.merge_and_unload()
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
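test.py ends after merging the PEFT adapter into the base model. A minimal generation sketch such as the following could be appended to check that the merged model responds; it reuses the tokenizer, model, device, and torch names defined in test.py, while the prompt text and generation settings are assumptions for illustration, not part of the commit.

# Hypothetical sanity check, continuing from the variables defined in test.py.
prompt = "Explain Section 420 of the Indian Penal Code in simple terms."  # assumed example prompt
inputs = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    # max_new_tokens is an assumed value; adjust to taste.
    output_ids = model.generate(**inputs, max_new_tokens=200)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))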