|
|
import os |
|
|
import chromadb |
|
|
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_core.runnables import RunnablePassthrough |
|
|
from langchain_huggingface.embeddings import HuggingFaceEmbeddings |
|
|
from langchain_chroma import Chroma |
|
|
|
|
|
|
|
|
# Hugging Face model repo used for chat completion (instruction-tuned Qwen).
REPO_ID = "Qwen/Qwen2.5-7B-Instruct"


# Chroma collection holding the ingested video-analysis documents.
# NOTE(review): must match the collection name used at ingestion time — confirm.
COLLECTION_NAME = 'video_analysis_data'


# On-disk location of the persisted Chroma database, relative to the CWD.
DB_PATH = "./chroma_db"
|
|
|
|
|
def _build_retriever():
    """Open the persisted Chroma collection and return a top-5 retriever."""
    client = chromadb.PersistentClient(path=DB_PATH)
    # NOTE(review): the embedding model must match the one used when the
    # collection was ingested, or similarity search will be meaningless.
    embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vectorstore = Chroma(
        client=client,
        collection_name=COLLECTION_NAME,
        embedding_function=embedding_function,
    )
    return vectorstore.as_retriever(search_kwargs={"k": 5})


def _build_llm(api_token):
    """Create the chat model backed by the Hugging Face inference endpoint.

    Args:
        api_token: Hugging Face Hub API token used to authenticate requests.
    """
    llm_endpoint = HuggingFaceEndpoint(
        repo_id=REPO_ID,
        task="text-generation",
        max_new_tokens=512,
        repetition_penalty=1.1,
        huggingfacehub_api_token=api_token,
    )
    return ChatHuggingFace(llm=llm_endpoint)


def run_query(user_query):
    """Answer *user_query* with a RAG chain over the persisted Chroma store.

    Retrieves the 5 most similar documents from the video-analysis
    collection, injects them as context into a prompt, and asks the
    Qwen chat model hosted on Hugging Face for an answer.

    Args:
        user_query: Natural-language question about the analyzed video.

    Returns:
        The model's answer text, or an error-message string if the
        ``HUGGINGFACEHUB_API_TOKEN`` environment variable is not set.
    """
    # Read the token once; it gates both the early exit and endpoint auth.
    api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not api_token:
        return "Error: Please set HUGGINGFACEHUB_API_TOKEN in your environment variables."

    retriever = _build_retriever()
    llm = _build_llm(api_token)

    template = """
You are an expert Video Content Analyst. Use the Context to answer the Question.
If you don't know the answer, say you don't know.
Infer activity based on detected objects (e.g., people + skateboards = skateboarding).

Context:
{context}

Question: {question}
"""
    prompt = ChatPromptTemplate.from_template(template)

    # The retriever fills {context}; the raw query passes through to {question}.
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
    )

    response = rag_chain.invoke(user_query)
    return response.content
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Example questions exercising the RAG pipeline end to end.
    queries = [
        "What kind of objects were frequently detected in the video?",
        "What activity was detected around the 15-second mark in the video?",
    ]

    for idx, query in enumerate(queries, start=1):
        if idx > 1:
            # Visual divider between consecutive answers.
            print("\n" + "=" * 50 + "\n")
        answer = run_query(query)
        print(f"\n--- QUERY {idx} ---")
        print(f"Question: {query}")
        print(f"Answer:\n{answer}")
|