File size: 5,069 Bytes
7e6e9de
7131533
e4a6388
7fce9f4
 
 
 
e4a6388
6de2253
e4a6388
d1299a1
7fce9f4
 
9194df0
7e6e9de
6de2253
 
8307eaf
6de2253
8307eaf
7fce9f4
7e6e9de
7fce9f4
7e6e9de
8307eaf
7e6e9de
 
 
 
7fce9f4
 
7e6e9de
 
1b7ccbc
7e6e9de
1b7ccbc
 
7e6e9de
1b7ccbc
7e6e9de
 
7fce9f4
 
7e6e9de
7fce9f4
7e6e9de
7fce9f4
7e6e9de
 
 
7fce9f4
 
7e6e9de
6de2253
 
 
e7dd7fc
7e6e9de
 
 
 
 
e7dd7fc
7fce9f4
6de2253
7fce9f4
 
6de2253
7e6e9de
 
6de2253
7e6e9de
 
 
 
8307eaf
6de2253
 
 
 
 
 
 
 
 
7fce9f4
7131533
7fce9f4
7e6e9de
 
6de2253
 
 
 
 
 
 
 
 
 
 
 
7fce9f4
7e6e9de
7fce9f4
 
7e6e9de
 
7fce9f4
 
 
6de2253
 
 
 
7e6e9de
 
 
7fce9f4
7e6e9de
7fce9f4
7e6e9de
7131533
7e6e9de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os, tempfile, streamlit as st
from typing import List, IO, Tuple
from dotenv import load_dotenv
from PyPDF2 import PdfReader
from docx import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document as LangchainDocument
from langchain_community.vectorstores import FAISS
from cerebras.cloud.sdk import Cerebras # <-- NEW
from langchain.prompts import PromptTemplate
from langchain_together.embeddings import TogetherEmbeddings

load_dotenv()

# ---------- Helpers ---------------------------------------------------------
def get_cerebras_api_key() -> str:
    """Return the Cerebras API key.

    The ``CEREBRAS_API_KEY`` environment variable is consulted first; when it
    is unset or empty, the Streamlit secrets entry of the same name is used.

    Raises:
        EnvironmentError: If neither source yields a non-empty key.
    """
    env_key = os.environ.get("CEREBRAS_API_KEY")
    if env_key:
        return env_key

    # Streamlit secrets are only touched when the environment has no key.
    secret_key = st.secrets.get("CEREBRAS_API_KEY", None)
    if secret_key:
        return secret_key

    raise EnvironmentError("CEREBRAS_API_KEY not found in env or Streamlit secrets.")

# ---------- File-reading utilities -----------------------------------------
def get_pdf_text(pdf_docs: List[IO[bytes]]) -> str:
    """Concatenate the extracted text of every page of every given PDF.

    Args:
        pdf_docs: Objects that ``PdfReader`` can open, processed in order.

    Returns:
        The page texts in reading order, each followed by a newline. Pages
        for which ``extract_text()`` returns an empty/falsy value are skipped.
        An empty input list yields an empty string.
    """
    page_texts = []
    for pdf_file in pdf_docs:
        reader = PdfReader(pdf_file)
        for page in reader.pages:
            extracted = page.extract_text()
            if not extracted:
                continue
            page_texts.append(extracted + "\n")
    return "".join(page_texts)

def get_docx_text(docx_docs: List[IO[bytes]]) -> str:
    """Concatenate the paragraph text of every given .docx upload.

    Each document is parsed directly from its in-memory bytes. The previous
    implementation round-tripped through a ``delete=False`` temporary file
    whose write happened outside the ``try/finally``, so a failed write left
    the file behind on disk; parsing from memory removes that leak.

    Args:
        docx_docs: Uploaded files exposing ``getvalue()`` (the raw bytes of
            the document), processed in order.

    Returns:
        For each document, its paragraphs joined by newlines and followed by
        a trailing newline; the per-document strings are concatenated. An
        empty input list yields an empty string.
    """
    # Local import keeps this block self-contained (io is not imported at
    # the top of the file).
    from io import BytesIO

    doc_texts = []
    for upload in docx_docs:
        # python-docx accepts a file-like object as well as a path. A fresh
        # BytesIO is used so the upload's own stream position is irrelevant.
        doc = Document(BytesIO(upload.getvalue()))
        doc_texts.append("\n".join(p.text for p in doc.paragraphs) + "\n")
    return "".join(doc_texts)

def get_text_chunks(text: str) -> List[str]:
    """Split *text* into overlapping chunks.

    Uses ``RecursiveCharacterTextSplitter`` configured with
    ``chunk_size=1000`` and ``chunk_overlap=200`` and returns whatever its
    ``split_text`` produces for the input.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_text(text)
    return chunks

# ---------- Vector-store build & save --------------------------------------
def get_together_api_key() -> str:
    """Return the Together API key used for embeddings.

    The ``TOGETHER_API_KEY`` environment variable is consulted first, then
    the Streamlit secrets entry of the same name.

    Raises:
        EnvironmentError: If neither source yields a non-empty key.
    """
    key = os.environ.get("TOGETHER_API_KEY") or st.secrets.get("TOGETHER_API_KEY", None)
    if not key:
        raise EnvironmentError("TOGETHER_API_KEY not found in env or Streamlit secrets.")
    return key


def get_vector_store(text_chunks: List[str]) -> None:
    """Embed *text_chunks* and persist the resulting FAISS index to disk.

    The chunks are embedded with Together's ``BAAI/bge-base-en-v1.5`` model
    and the index is written to the local ``faiss_index`` directory.

    Args:
        text_chunks: The text pieces to embed and index.

    Raises:
        EnvironmentError: If no Together API key is configured.
    """
    # The original body called get_together_api_key() without it being
    # defined anywhere in this module (NameError at runtime); the helper
    # above restores it, mirroring get_cerebras_api_key().
    api_key = get_together_api_key()
    embeddings = TogetherEmbeddings(model="BAAI/bge-base-en-v1.5", api_key=api_key)
    vector_store = FAISS.from_texts(text_chunks, embedding=embeddings)
    vector_store.save_local("faiss_index")

# ---------- QA chain helpers ----------------------------------------------
def get_conversational_chain() -> Tuple[Cerebras, PromptTemplate]:
    """Build the pieces needed to answer a question from retrieved context.

    Returns:
        A ``(client, prompt)`` pair: a Cerebras client created with the
        configured API key, and a ``PromptTemplate`` expecting the
        ``context`` and ``question`` variables.
    """
    cerebras_client = Cerebras(api_key=get_cerebras_api_key())

    template_text = (
        "As a professional assistant, provide a detailed and formally written "
        "answer to the question using the provided context.\n\n"
        "Context:\n{context}\n\n"
        "Question:\n{question}\n\n"
        "Answer:"
    )
    qa_prompt = PromptTemplate(
        template=template_text,
        input_variables=["context", "question"],
    )
    return cerebras_client, qa_prompt

def self_assess(question: str) -> str:
    """Ask the model to answer *question* unaided or to request retrieval.

    The prompt instructs the model to answer from its own knowledge when it
    is confident, and otherwise to reply with ``'NEED_RETRIEVAL'``.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    cerebras_client = Cerebras(api_key=get_cerebras_api_key())

    system_message = {"role": "system", "content": "You are an expert assistant…"}
    user_content = (
        "If you can confidently answer the following question from your own "
        "knowledge, do so; otherwise reply with 'NEED_RETRIEVAL'.\n\n"
        f"Question: {question}"
    )
    user_message = {"role": "user", "content": user_content}

    completion = cerebras_client.chat.completions.create(
        messages=[system_message, user_message],
        model="llama-3.3-70b",
        max_completion_tokens=1024,
        temperature=0.2,
        top_p=1,
        stream=False,
    )
    reply = completion.choices[0].message.content
    return reply.strip()

def process_docs_for_query(docs: List[LangchainDocument], question: str) -> str:
    """Answer *question* using the text of the retrieved *docs* as context.

    Args:
        docs: Retrieved documents; their ``page_content`` values are joined
            with blank lines to form the context.
        question: The user's question.

    Returns:
        The model's answer, or a fixed apology string when *docs* is empty
        (no model call is made in that case).
    """
    if len(docs) == 0:
        return "Sorry, I couldn’t find relevant info in the documents."

    context_text = "\n\n".join(doc.page_content for doc in docs)
    cerebras_client, qa_prompt = get_conversational_chain()
    user_message = {
        "role": "user",
        "content": qa_prompt.format(context=context_text, question=question),
    }
    completion = cerebras_client.chat.completions.create(
        messages=[user_message],
        model="llama-3.3-70b",
        max_completion_tokens=1024,
        temperature=0.2,
        top_p=1,
        stream=False,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


# ---------- Main user-query orchestrator -----------------------------------
def user_input(user_question: str) -> None:
    """Answer *user_question* and render the result in the Streamlit page.

    The model is first asked (via ``self_assess``) whether it can answer on
    its own. If its reply contains the ``NEED_RETRIEVAL`` sentinel, the saved
    ``faiss_index`` is searched and the answer is generated from the matching
    documents; otherwise the model's own reply is shown.

    Any exception raised along the way, including a failure of the initial
    self-assessment call, is reported with ``st.error`` instead of
    propagating.

    Args:
        user_question: The question typed by the user.
    """
    try:
        assessment = self_assess(user_question)
        # Substring match: the prompt quotes the sentinel, so the model may
        # echo it with quotes or trailing punctuation, which an exact
        # comparison would misread as a real answer.
        need_retrieval = "NEED_RETRIEVAL" in assessment.upper()
        st.info("🔍 Searching documents…" if need_retrieval else "💡 Using model knowledge…")

        if need_retrieval:
            # The embeddings are served by Together, so they need the
            # Together key; the Cerebras key is not valid for that API.
            together_key = (
                os.environ.get("TOGETHER_API_KEY")
                or st.secrets.get("TOGETHER_API_KEY", None)
            )
            if not together_key:
                raise EnvironmentError(
                    "TOGETHER_API_KEY not found in env or Streamlit secrets."
                )
            embeddings = TogetherEmbeddings(
                model="BAAI/bge-base-en-v1.5", api_key=together_key
            )
            # NOTE(review): allow_dangerous_deserialization=True unpickles the
            # stored index. Presumably "faiss_index" is only ever written by
            # this app's own get_vector_store(); confirm it can never come
            # from an untrusted source.
            vs = FAISS.load_local(
                "faiss_index", embeddings, allow_dangerous_deserialization=True
            )
            docs = vs.similarity_search(user_question)
            answer = process_docs_for_query(docs, user_question)
        else:
            answer = assessment

        st.markdown("### Answer")
        st.markdown(answer)
    except Exception as e:  # top-level UI boundary: surface, don't crash
        st.error(f"⚠️ Error: {e}")