# load & split data
from langchain.text_splitter import RecursiveCharacterTextSplitter
# embed data
from langchain_mistralai.embeddings import MistralAIEmbeddings
# vector store
from langchain_community.vectorstores import FAISS
# prompt
from langchain.prompts import PromptTemplate
# memory
from langchain.memory import ConversationBufferMemory
# llm
from langchain_mistralai.chat_models import ChatMistralAI
# chain modules
from langchain.chains import RetrievalQA
# stdlib & utils
# import PyPDF2
import os
import re
from collections import defaultdict
from dotenv import load_dotenv

load_dotenv()
api_key = os.environ.get("MISTRAL_API_KEY")
class RagModule:
    def __init__(self):
        self.mistral_api_key = api_key
        # embedding params
        self.model_name_embedding = "mistral-embed"
        self.embedding_model = MistralAIEmbeddings(
            model=self.model_name_embedding, mistral_api_key=self.mistral_api_key
        )
        # splitter params
        self.chunk_size = 1000
        self.chunk_overlap = 120
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )
        self.db_faiss_path = "data/vector_store"
        # llm params
        self.llm_model = "mistral-small"
        self.max_tokens = 512  # ChatMistralAI's parameter is max_tokens
        self.top_p = 0.5
        self.temperature = 0.1
    def split_text(self, text: str) -> list:
        """Split the text into chunks.

        Args:
            text (str): raw text to split

        Returns:
            list: list of text chunks
        """
        texts = self.text_splitter.split_text(text)
        return texts
    def get_metadata(self, texts: list) -> list:
        """Build a metadata dict for each chunk, indexed by its position.

        Args:
            texts (list): list of text chunks

        Returns:
            list: one {"source": ...} dict per chunk
        """
        metadatas = [{"source": f"Paragraph: {i}"} for i in range(len(texts))]
        return metadatas
    def get_faiss_db(self):
        """Load the local FAISS vector store containing all embeddings."""
        db = FAISS.load_local(self.db_faiss_path, self.embedding_model)
        return db
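    # Hypothetical companion to get_faiss_db, added here for illustration: the
    # module never shows how the store at db_faiss_path gets created. A minimal
    # sketch, assuming the index is built from raw text with FAISS.from_texts
    # and persisted with save_local (both part of the langchain_community FAISS
    # wrapper); the method name build_faiss_db is not from the original code.
    def build_faiss_db(self, text: str):
        """Split raw text, embed the chunks, and persist a FAISS index locally."""
        texts = self.split_text(text)
        metadatas = self.get_metadata(texts)
        db = FAISS.from_texts(texts, self.embedding_model, metadatas=metadatas)
        db.save_local(self.db_faiss_path)
        return db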
    def set_custom_prompt(self, prompt_template: str):
        """Instantiate the prompt template for Q&A retrieval for each vector store.

        Args:
            prompt_template (str): template string; its input variables are
                inferred from the {placeholders} it contains

        Returns:
            PromptTemplate: the instantiated prompt
        """
        prompt = PromptTemplate.from_template(template=prompt_template)
        return prompt
    def load_mistral(self):
        """Instantiate the Mistral chat LLM."""
        model_kwargs = {
            "mistral_api_key": self.mistral_api_key,
            "model": self.llm_model,
            # ChatMistralAI expects max_tokens, not max_new_tokens
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature,
        }
        llm = ChatMistralAI(**model_kwargs)
        return llm
    def retrieval_qa_memory_chain(self, db, prompt_template):
        """Build a RetrievalQA chain with conversation memory.

        The prompt template must expose the `history`, `context` and
        `question` variables so that both the memory and the retrieved
        documents can be injected.
        """
        llm = self.load_mistral()
        prompt = self.set_custom_prompt(prompt_template)
        memory = ConversationBufferMemory(
            memory_key="history",
            input_key="question",
        )
        chain_type_kwargs = {
            "prompt": prompt,
            "memory": memory,
        }
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=db.as_retriever(search_kwargs={"k": 5}),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=True,
        )
        return qa_chain
    def retrieval_qa_chain(self, db, prompt_template):
        """Build a RetrievalQA chain without memory.

        The prompt template must expose the `context` and `question` variables.
        """
        llm = self.load_mistral()
        prompt = self.set_custom_prompt(prompt_template)
        chain_type_kwargs = {
            "prompt": prompt,
        }
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=db.as_retriever(search_kwargs={"k": 3}),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=True,
        )
        return qa_chain
    def get_sources_document(self, source_documents: list) -> dict:
        """Build a dictionary mapping each source path to the list of pages cited from it.

        Args:
            source_documents (list): source documents returned with the RAG response

        Returns:
            dict: {
                "path/to/file1": [0, 1, 3],
                "path/to/file2": [5, 2],
            }
        """
        sources = defaultdict(list)
        for doc in source_documents:
            sources[doc.metadata["source"]].append(doc.metadata["page"])
        return sources
    def shape_answer_with_source(self, answer: str, sources: dict):
        """Append the source files and pages to the answer.

        Args:
            answer (str): answer returned by the chain
            sources (dict): {path: [pages]} mapping from get_sources_document
        """
        pattern = r"^(.+)\/([^\/]+)$"  # captures the directory and the file name
        source_msg = ""
        for path, pages in sources.items():
            file = re.findall(pattern, path)[0][1]
            source_msg += f"\nFile: {file} - Page: {pages}"
        answer += f"\n{source_msg}"
        return answer
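
# A minimal usage sketch, not part of the original module. The template text
# and the question are illustrative; the "query" input key and the "result" /
# "source_documents" output keys follow the standard RetrievalQA contract.
if __name__ == "__main__":
    rag = RagModule()
    db = rag.get_faiss_db()

    # the memory chain expects history, context and question placeholders
    template = (
        "Use the following context to answer the question.\n"
        "History: {history}\n"
        "Context: {context}\n"
        "Question: {question}\n"
        "Answer:"
    )
    chain = rag.retrieval_qa_memory_chain(db, template)
    output = chain.invoke({"query": "What is this corpus about?"})
    answer = output["result"]

    # shaping sources assumes documents carry "source" and "page" metadata,
    # e.g. when the index was built with a PDF loader rather than get_metadata
    sources = rag.get_sources_document(output["source_documents"])
    print(rag.shape_answer_with_source(answer, sources))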