# RAG/modules/rag_query.py
# (initial commit 1b3e1f1 by Hanzo03)
import os
import chromadb
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
# 1. Initialize RAG Components
# Hugging Face model repo used for answer generation via the Inference API.
REPO_ID = "Qwen/Qwen2.5-7B-Instruct"
# Chroma collection name; expected to be populated beforehand (rag_indexer.py).
COLLECTION_NAME = 'video_analysis_data'
# On-disk path of the persistent Chroma database.
DB_PATH = "./chroma_db"
def run_query(user_query):
    """Answer a question about indexed video data using RAG.

    Retrieves the 5 most similar chunks from the persistent Chroma store,
    stuffs their text into a prompt, and asks the Qwen chat model hosted on
    the Hugging Face Inference API.

    Args:
        user_query: Natural-language question about the indexed video data.

    Returns:
        The model's answer as a string, or an error message string when the
        HUGGINGFACEHUB_API_TOKEN environment variable is not set.
    """
    api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not api_token:
        return "Error: Please set HUGGINGFACEHUB_API_TOKEN in your environment variables."

    # Vector store backed by the local persistent Chroma DB; the embedding
    # model must match the one used at indexing time.
    client = chromadb.PersistentClient(path=DB_PATH)
    embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vectorstore = Chroma(
        client=client,
        collection_name=COLLECTION_NAME,
        embedding_function=embedding_function
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

    # 2. LLM Setup (Qwen via Hugging Face Endpoint)
    llm_endpoint = HuggingFaceEndpoint(
        repo_id=REPO_ID,
        task="text-generation",
        max_new_tokens=512,
        repetition_penalty=1.1,
        huggingfacehub_api_token=api_token
    )
    # Wrap it in ChatHuggingFace so the model's chat template is applied
    # to the prompt messages correctly.
    llm = ChatHuggingFace(llm=llm_endpoint)

    template = """
You are an expert Video Content Analyst. Use the Context to answer the Question.
If you don't know the answer, say you don't know.
Infer activity based on detected objects (e.g., people + skateboards = skateboarding).
Context:
{context}
Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)

    def _format_docs(docs):
        # Join the retrieved chunks' text into one context string. Without
        # this step the prompt receives the repr of Document objects
        # ("[Document(metadata=..., page_content=...)]") instead of their
        # actual content, which degrades answer quality.
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
    )
    response = rag_chain.invoke(user_query)
    # ChatHuggingFace returns an AIMessage; the answer text is in .content.
    return response.content
# Example Usage:
if __name__ == '__main__':
    # Ensure you have indexed data by running rag_indexer.py first
    example_queries = [
        "What kind of objects were frequently detected in the video?",
        "What activity was detected around the 15-second mark in the video?",
    ]
    for number, question in enumerate(example_queries, start=1):
        answer = run_query(question)
        print(f"\n--- QUERY {number} ---")
        print(f"Question: {question}")
        print(f"Answer:\n{answer}")
        # Separator between consecutive query outputs (not after the last).
        if number < len(example_queries):
            print("\n" + "=" * 50 + "\n")