adi-123 committed on
Commit
e4a6388
·
verified ·
1 Parent(s): 44d3012

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +49 -21
utils.py CHANGED
@@ -2,22 +2,22 @@ import os
2
  import tempfile
3
  import streamlit as st
4
  from typing import List, IO, Tuple
 
5
  from PyPDF2 import PdfReader
6
  from docx import Document
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
- from langchain_community.vectorstores import FAISS
9
- from langchain.prompts import PromptTemplate
10
  from langchain.schema import Document as LangchainDocument
11
- from langchain_together.embeddings import TogetherEmbeddings
12
  from langchain_together.chat_models import ChatTogether
13
- from dotenv import load_dotenv
 
14
 
15
- # Load from .env if available
16
  load_dotenv()
17
 
18
  def get_together_api_key() -> str:
19
  """
20
- Retrieves the Together API key from environment or Streamlit secrets.
21
  """
22
  key = os.getenv("TOGETHER_API_KEY")
23
  if not key:
@@ -26,10 +26,13 @@ def get_together_api_key() -> str:
26
  except Exception:
27
  pass
28
  if not key:
29
- raise EnvironmentError("TOGETHER_API_KEY not found. Set in env or Hugging Face secrets.")
30
  return key
31
 
32
  def get_pdf_text(pdf_docs: List[IO[bytes]]) -> str:
 
 
 
33
  text = ""
34
  for pdf in pdf_docs:
35
  pdf_reader = PdfReader(pdf)
@@ -40,15 +43,18 @@ def get_pdf_text(pdf_docs: List[IO[bytes]]) -> str:
40
  return text
41
 
42
  def get_docx_text(docx_docs: List[IO[bytes]]) -> str:
 
 
 
43
  text = ""
44
  for docx in docx_docs:
45
- with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as temp_file:
46
  try:
47
  temp_file.write(docx.getvalue())
48
  temp_file.flush()
49
  doc = Document(temp_file.name)
50
  doc_text = [p.text for p in doc.paragraphs]
51
- text += '\n'.join(doc_text) + "\n"
52
  except Exception as e:
53
  st.warning(f"Warning: Could not process document {docx.name}: {str(e)}")
54
  finally:
@@ -59,20 +65,34 @@ def get_docx_text(docx_docs: List[IO[bytes]]) -> str:
59
  return text
60
 
61
  def get_text_chunks(text: str) -> List[str]:
 
 
 
62
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
63
  return splitter.split_text(text)
64
 
65
  def get_vector_store(text_chunks: List[str]) -> None:
 
 
 
66
  api_key = get_together_api_key()
67
- embeddings = TogetherEmbeddings(
68
- model="togethercomputer/m2-bert-80M-8k-retrieval",
69
- api_key=api_key
 
70
  )
 
71
  documents = [LangchainDocument(page_content=chunk) for chunk in text_chunks]
72
- vector_store = FAISS.from_documents(documents, embedding=embeddings)
 
 
 
73
  vector_store.save_local("faiss_index")
74
 
75
  def get_conversational_chain() -> Tuple[ChatTogether, PromptTemplate]:
 
 
 
76
  api_key = get_together_api_key()
77
  llm = ChatTogether(
78
  model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
@@ -98,6 +118,9 @@ def get_conversational_chain() -> Tuple[ChatTogether, PromptTemplate]:
98
  return llm, prompt
99
 
100
  def self_assess(question: str) -> str:
 
 
 
101
  api_key = get_together_api_key()
102
  llm = ChatTogether(
103
  model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
@@ -124,15 +147,22 @@ def self_assess(question: str) -> str:
124
  return response.content.strip()
125
 
126
  def process_docs_for_query(docs: List[LangchainDocument], question: str) -> str:
 
 
 
127
  if not docs:
128
  return "I couldn't find relevant information in your uploaded documents to answer that question."
 
129
  context = "\n\n".join(doc.page_content for doc in docs)
130
  llm, prompt = get_conversational_chain()
131
- final_prompt = prompt.format(context=context, question=question)
132
- response = llm.invoke(final_prompt)
133
  return response.content
134
 
135
  def user_input(user_question: str) -> None:
 
 
 
136
  assessment = self_assess(user_question)
137
 
138
  if assessment.strip().upper() == "NEED_RETRIEVAL":
@@ -145,11 +175,8 @@ def user_input(user_question: str) -> None:
145
  try:
146
  if need_retrieval:
147
  api_key = get_together_api_key()
148
- embeddings = TogetherEmbeddings(
149
- model="togethercomputer/m2-bert-80M-8k-retrieval",
150
- api_key=api_key
151
- )
152
- vector_store = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
153
  docs = vector_store.similarity_search(user_question)
154
  response = process_docs_for_query(docs, user_question)
155
  else:
@@ -157,5 +184,6 @@ def user_input(user_question: str) -> None:
157
 
158
  st.markdown("### Answer")
159
  st.markdown(response)
 
160
  except Exception as e:
161
- st.error(f"⚠️ An error occurred: {e}")
 
2
  import tempfile
3
  import streamlit as st
4
  from typing import List, IO, Tuple
5
+ from dotenv import load_dotenv
6
  from PyPDF2 import PdfReader
7
  from docx import Document
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 
9
  from langchain.schema import Document as LangchainDocument
10
+ from langchain_community.vectorstores import FAISS
11
  from langchain_together.chat_models import ChatTogether
12
+ from langchain.prompts import PromptTemplate
13
+ from together import Together
14
 
15
+ # Load environment variables
16
  load_dotenv()
17
 
18
  def get_together_api_key() -> str:
19
  """
20
+ Retrieves the Together API key from environment variables or Streamlit secrets.
21
  """
22
  key = os.getenv("TOGETHER_API_KEY")
23
  if not key:
 
26
  except Exception:
27
  pass
28
  if not key:
29
+ raise EnvironmentError("TOGETHER_API_KEY not found in env or Hugging Face secrets.")
30
  return key
31
 
32
  def get_pdf_text(pdf_docs: List[IO[bytes]]) -> str:
33
+ """
34
+ Extract text content from a list of PDF files.
35
+ """
36
  text = ""
37
  for pdf in pdf_docs:
38
  pdf_reader = PdfReader(pdf)
 
43
  return text
44
 
45
  def get_docx_text(docx_docs: List[IO[bytes]]) -> str:
46
+ """
47
+ Extract text content from a list of Word documents.
48
+ """
49
  text = ""
50
  for docx in docx_docs:
51
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as temp_file:
52
  try:
53
  temp_file.write(docx.getvalue())
54
  temp_file.flush()
55
  doc = Document(temp_file.name)
56
  doc_text = [p.text for p in doc.paragraphs]
57
+ text += "\n".join(doc_text) + "\n"
58
  except Exception as e:
59
  st.warning(f"Warning: Could not process document {docx.name}: {str(e)}")
60
  finally:
 
65
  return text
66
 
67
def get_text_chunks(text: str) -> List[str]:
    """Break *text* into overlapping chunks suitable for embedding and retrieval."""
    chunker = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    return chunker.split_text(text)
73
 
74
def get_vector_store(text_chunks: List[str]) -> None:
    """
    Build a FAISS vector store from text chunks and persist it locally.

    Embeddings are computed via the Together AI embeddings endpoint and the
    resulting index is saved to the "faiss_index" directory.

    Args:
        text_chunks: Pre-split document chunks to index.

    Raises:
        EnvironmentError: If no Together API key is configured.
    """
    api_key = get_together_api_key()
    client = Together(api_key=api_key)

    class _TogetherEmbeddings:
        """Minimal LangChain-compatible Embeddings adapter over the Together client."""

        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            response = client.embeddings.create(
                model="BAAI/bge-base-en-v1.5",
                input=texts,
            )
            # The Together SDK returns response items with an `.embedding`
            # attribute, not dicts — `item["embedding"]` would raise TypeError.
            return [item.embedding for item in response.data]

        def embed_query(self, text: str) -> List[float]:
            return self.embed_documents([text])[0]

    # FAISS.from_texts expects an Embeddings-like object. The previous
    # `FAISS.from_documents(..., embedding_function=lambda _: embeddings.pop(0))`
    # call is broken: `from_documents` has no `embedding_function` keyword, and
    # a pop-based lambda is not a valid Embeddings implementation (it would also
    # leave the store unable to embed queries at search time).
    vector_store = FAISS.from_texts(text_chunks, embedding=_TogetherEmbeddings())
    vector_store.save_local("faiss_index")
91
 
92
  def get_conversational_chain() -> Tuple[ChatTogether, PromptTemplate]:
93
+ """
94
+ Initialize the LLM and prompt template for answering questions.
95
+ """
96
  api_key = get_together_api_key()
97
  llm = ChatTogether(
98
  model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
 
118
  return llm, prompt
119
 
120
  def self_assess(question: str) -> str:
121
+ """
122
+ Determine whether the AI can answer the question directly or needs document retrieval.
123
+ """
124
  api_key = get_together_api_key()
125
  llm = ChatTogether(
126
  model="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
 
147
  return response.content.strip()
148
 
149
def process_docs_for_query(docs: List[LangchainDocument], question: str) -> str:
    """
    Answer *question* using the retrieved *docs* and the conversational LLM.

    Returns a fixed fallback message when no documents were retrieved.
    """
    if not docs:
        return "I couldn't find relevant information in your uploaded documents to answer that question."

    model, template = get_conversational_chain()
    joined_context = "\n\n".join(d.page_content for d in docs)
    answer = model.invoke(template.format(context=joined_context, question=question))
    return answer.content
161
 
162
  def user_input(user_question: str) -> None:
163
+ """
164
+ Process the user's question, decide on retrieval or not, and display the answer.
165
+ """
166
  assessment = self_assess(user_question)
167
 
168
  if assessment.strip().upper() == "NEED_RETRIEVAL":
 
175
  try:
176
  if need_retrieval:
177
  api_key = get_together_api_key()
178
+ client = Together(api_key=api_key)
179
+ vector_store = FAISS.load_local("faiss_index", embedding_function=lambda x: [0.0]*768, allow_dangerous_deserialization=True)
 
 
 
180
  docs = vector_store.similarity_search(user_question)
181
  response = process_docs_for_query(docs, user_question)
182
  else:
 
184
 
185
  st.markdown("### Answer")
186
  st.markdown(response)
187
+
188
  except Exception as e:
189
+ st.error(f"⚠️ An error occurred while processing your question: {e}")