Spaces:
Sleeping
Sleeping
| from langchain_core.output_parsers import JsonOutputParser | |
| from langchain_core.prompts import PromptTemplate | |
| from dotenv import load_dotenv | |
| import os | |
| from typing import List | |
| from typing_extensions import TypedDict | |
| from langchain_core.messages import HumanMessage | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain.output_parsers import RetryOutputParser | |
| from langgraph.graph import StateGraph, START, END | |
| import base64 | |
| from IPython.display import Image as img, display | |
| from langchain_core.runnables.graph import MermaidDrawMethod | |
| from langgraph.checkpoint.memory import MemorySaver | |
| import json | |
| from pydantic import BaseModel, Field | |
| from io import BytesIO | |
# Pull settings from a local .env file into the process environment before
# the API key is read below.
load_dotenv()

# Credentials and model selection for the Gemini chat model.
GEMINI_API_KEY = os.getenv('google_api_key')
GEMINI_MODEL = 'gemini-2.0-flash'

# Single chat-model instance shared by every graph node in this module.
llm = ChatGoogleGenerativeAI(
    model=GEMINI_MODEL,
    google_api_key=GEMINI_API_KEY,
    temperature=0.3,
)
| from os import listdir | |
| from os.path import isfile, join | |
class State(TypedDict):
    """Shared state passed between the nodes of the receipt graph."""

    # Instruction text sent to the model together with the receipt image.
    prompt: str
    # Index of the receipt; used to build the receipt file name.
    image_number: int
    # Parsed JSON extracted from the receipt (a dict, not the ``json`` module).
    image_data: dict
    # Base64-encoded image embedded in the request as a data URL.
    image_byte: str
    # Parsed evaluator reply; the nodes read its 'decision' and 'comment' keys.
    eval: dict
    # Number of correction passes already applied to ``image_data``.
    n_retries: int
    # File name derived from ``image_number``.
    image_name: str
    # Receipt entries collected so far.
    image_data_list: list
def generate_data_node(state: State):
    """Extract structured receipt data from the image held in the state.

    Sends ``state['prompt']`` followed by the parser's format instructions,
    plus the base64 image in ``state['image_byte']``, to the module-level
    ``llm`` and parses the reply as JSON. If parsing fails, the raw reply is
    handed to a ``RetryOutputParser`` for a second attempt.

    Args:
        state: Current graph state; ``prompt`` and ``image_byte`` are read.

    Returns:
        Partial state update ``{'image_data': <parsed JSON>}``.
    """

    class Items(BaseModel):
        name: str = Field(description='the name of the item')
        price: float = Field(description='the price of the item')
        quantity: int = Field(description='the quantity of the item')

    class Form(BaseModel):
        loc_name: str = Field(description='the name of the location if no name put empty str')
        table: str = Field(description='table name or number, if no table put an empty str')
        address: str = Field(description='the address of the location if no location put empty str')
        date: str = Field(description='the date if no date put empty str')
        time: str = Field(description='the time if no time put empty str')
        items: List[Items] = Field(description='list of the items if no items put empty list')
        subtotal: float = Field(description='the subtotal if no subtotal put 0')
        tax: float = Field(description='the tax, if no tax put 0')
        total: float = Field(description='the total amount if no total amount put 0')

    parser = JsonOutputParser(pydantic_object=Form)
    instruction = parser.get_format_instructions()

    message = HumanMessage(
        content=[
            {"type": "text", "text": f"{state.get('prompt')}" + '\n\n' + instruction},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{state.get('image_byte')}"},
            },
        ],
    )
    reply = llm.invoke([message])

    try:
        parsed = parser.parse(reply.content)
    except Exception:
        # The reply was not usable JSON: ask the model to repair its own
        # output. ``Exception`` (not a bare ``except``) keeps Ctrl-C and
        # interpreter shutdown working.
        prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": instruction},
        )
        retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
        prompt_value = prompt.format_prompt(query=f"{state.get('prompt')}")
        parsed = retry_parser.parse_with_prompt(reply.content, prompt_value)

    return {'image_data': parsed}
def evaluate_node(state: State):
    """Ask the model to review the extracted receipt data.

    Builds a query around ``state['image_data']``, sends it to the
    module-level ``llm`` and parses the reply as JSON with ``decision`` and
    ``comment`` fields. If parsing fails, the raw reply is handed to a
    ``RetryOutputParser`` for a second attempt.

    Args:
        state: Current graph state; ``image_data`` is read.

    Returns:
        Partial state update ``{'eval': <parsed JSON>}``.
    """

    class Decision(BaseModel):
        decision: str = Field(description='good or modify if changes have to be made')
        comment: str = Field(description='the changes to make')

    parser = JsonOutputParser(pydantic_object=Decision)
    prompt = PromptTemplate(
        template="Answer the user query.\n{format_instructions}\n{query}\n",
        input_variables=["query"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    data = state.get('image_data')
    query = f" is the {data} correct and makes sense tell the llm what to change, ignore missing data, don't make it up, no explanation or decription needed"

    chain = prompt | llm
    reply = chain.invoke({'query': query})

    try:
        verdict = parser.parse(reply.content)
    except Exception:
        # The reply was not usable JSON: let the model repair its own output.
        # ``Exception`` (not a bare ``except``) keeps Ctrl-C and interpreter
        # shutdown working.
        retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
        prompt_value = prompt.format_prompt(query=query)
        verdict = retry_parser.parse_with_prompt(reply.content, prompt_value)

    return {'eval': verdict}
def data_editor_node(state: State):
    """Apply the evaluator's comments to the extracted receipt data.

    Asks the module-level ``llm`` to modify ``state['image_data']`` according
    to ``state['eval']['comment']`` and parses the reply as JSON. If parsing
    fails, the raw reply is handed to a ``RetryOutputParser`` for a second
    attempt.

    Args:
        state: Current graph state; ``image_data``, ``eval`` and
            ``n_retries`` are read.

    Returns:
        Partial state update with the revised ``image_data`` and
        ``n_retries`` incremented by one (a missing counter counts as 0).
    """

    class Items(BaseModel):
        name: str = Field(description='the name of the item')
        price: float = Field(description='the price of the item')
        quantity: int = Field(description='the quantity of the item')

    class Form(BaseModel):
        loc_name: str = Field(description='the name of the location if no name put empty str')
        table: str = Field(description='table name or number, if no table put an empty str')
        address: str = Field(description='the address of the location if no location put empty str')
        date: str = Field(description='the date if no date put empty str')
        time: str = Field(description='the time if no time put empty str')
        items: List[Items] = Field(description='list of the items if no items put empty list')
        subtotal: float = Field(description='the subtotal if no subtotal put 0')
        tax: float = Field(description='the tax, if no tax put 0')
        total: float = Field(description='the total amount if no total amount put 0')

    parser = JsonOutputParser(pydantic_object=Form)
    prompt = PromptTemplate(
        template="Answer the user query.\n{format_instructions}\n{query}\n",
        input_variables=["query"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    data = state.get('image_data')
    query = f"modify this dict: {data} based on these comments {state.get('eval').get('comment')}, return a json"

    chain = prompt | llm
    reply = chain.invoke({'query': query})

    try:
        revised = parser.parse(reply.content)
    except Exception:
        # The reply was not usable JSON: let the model repair its own output.
        # ``Exception`` (not a bare ``except``) keeps Ctrl-C and interpreter
        # shutdown working.
        retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
        prompt_value = prompt.format_prompt(query=query)
        revised = retry_parser.parse_with_prompt(reply.content, prompt_value)

    # Treat an absent counter as zero instead of failing on ``None + 1``.
    passes_so_far = state.get('n_retries') or 0
    return {
        'image_data': revised,
        'n_retries': passes_so_far + 1,
    }
def should_continue(state: State) -> str:
    """Pick the route that follows the evaluation step.

    Args:
        state: Current graph state; ``eval['decision']`` and ``n_retries``
            are read. A missing ``eval`` counts as "no decision" and a
            missing ``n_retries`` as 0.

    Returns:
        ``'to_data_editor'`` when the decision is ``'modify'`` and fewer than
        two correction passes have run; ``'to_add_data'`` in every other
        case (decision ``'good'``, retries exhausted, or any other value).
    """
    decision = (state.get('eval') or {}).get('decision')
    passes_so_far = state.get('n_retries') or 0

    # Only an explicit 'modify' with retry budget left loops back to the
    # editor; everything else finishes the run.
    if decision == 'modify' and passes_so_far < 2:
        return 'to_data_editor'
    return 'to_add_data'
def add_data_node(state: State):
    """Reset the retry counter and name the receipt after its index.

    Args:
        state: Current graph state; ``image_number`` is read.

    Returns:
        Partial state update with ``n_retries`` set to 0 and ``image_name``
        set to ``'<image_number>_new_receipt.jpg'``.
    """
    receipt_index = state.get('image_number')
    update = {'n_retries': 0}
    update['image_name'] = f'{receipt_index}_new_receipt.jpg'
    return update
class receipt_agent:
    """Graph-based workflow that extracts, reviews and collects receipt data.

    The compiled graph runs ``generate_data`` -> ``evaluate``, then either
    loops through ``data_editor`` back to ``evaluate`` or finishes via
    ``add_data``. State is checkpointed in memory under a single thread id.
    """

    # Folder scanned by ``receipt_gen`` for previously stored receipt data.
    data_dir = 'new_receipt_data'
    # Checkpointer thread used for every graph call made by this class.
    thread_id = "1"

    def __init__(self):
        self.agent = self._setup()

    def _config(self) -> dict:
        """Return a fresh run config addressing this agent's checkpoint thread."""
        return {"configurable": {"thread_id": self.thread_id}}

    def _setup(self):
        """Build and compile the state graph with an in-memory checkpointer."""
        agent_builder = StateGraph(State)

        agent_builder.add_node('generate_data', generate_data_node)
        agent_builder.add_node('evaluate', evaluate_node)
        agent_builder.add_node('add_data', add_data_node)
        agent_builder.add_node('data_editor', data_editor_node)

        agent_builder.add_edge(START, 'generate_data')
        agent_builder.add_edge('generate_data', 'evaluate')
        # After evaluation, either correct the data again or finish.
        agent_builder.add_conditional_edges(
            'evaluate',
            should_continue,
            {'to_data_editor': 'data_editor', 'to_add_data': 'add_data'},
        )
        agent_builder.add_edge('data_editor', 'evaluate')
        agent_builder.add_edge('add_data', END)

        checkpointer = MemorySaver()
        return agent_builder.compile(checkpointer=checkpointer)

    def display_graph(self):
        """Render the graph as a Mermaid PNG and show it via IPython."""
        return display(
            img(
                self.agent.get_graph().draw_mermaid_png(
                    draw_method=MermaidDrawMethod.API,
                )
            )
        )

    def get_state(self, state_val: str):
        """Return one value from the checkpointed state.

        Args:
            state_val: Key of the state entry to read.

        Raises:
            KeyError: If the key is absent from the stored state values.
        """
        return self.agent.get_state(self._config()).values[state_val]

    def _load_stored_data(self) -> list:
        """Load the stored receipt list from ``data_dir``.

        Returns an empty list when the folder is missing or holds no files;
        otherwise the JSON content of the first file in the listing.
        """
        try:
            stored_files = [
                name for name in listdir(self.data_dir)
                if isfile(join(self.data_dir, name))
            ]
        except FileNotFoundError:
            # No data folder yet means nothing has been stored so far.
            return []
        if not stored_files:
            return []
        # NOTE(review): only the first listed file is read, and listing order
        # is not guaranteed; presumably the folder holds a single JSON file.
        with open(join(self.data_dir, stored_files[0]), 'r', encoding='utf-8') as openfile:
            return json.load(openfile)

    def receipt_gen(self, image):
        """Run the graph on one receipt image and return the extracted data.

        Args:
            image: Object exposing ``save(buffer, format='JPEG')``
                (presumably a PIL image; confirm against the caller).

        Returns:
            The ``image_data`` entry of the final graph state.
        """
        buffered = BytesIO()
        image.save(buffered, format='JPEG')
        image_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        data_list = self._load_stored_data()

        response = self.agent.invoke(
            {
                'prompt': 'analyse this receipt and list the items, return a json',
                'n_retries': 0,
                'image_number': len(data_list),
                'image_byte': image_b64,
                'image_data_list': data_list,
            },
            self._config(),
        )
        return response.get('image_data')

    def update_state(self, values: dict):
        """Write ``values`` into the checkpointed state."""
        return self.agent.update_state(self._config(), values=values)

    def confirm(self, image_data):
        """Append confirmed receipt data to the stored list in the state.

        Args:
            image_data: Receipt data to record; a falsy value records nothing.

        Returns:
            ``(data_list, image_name)`` after the entry was appended, or
            ``None`` when ``image_data`` is falsy.
        """
        # NOTE(review): the original layout was lost, so the intended result
        # for falsy input is unclear; returning None avoids touching the state.
        if not image_data:
            return None

        # Read the snapshot once instead of once per value.
        values = self.agent.get_state(self._config()).values
        data_list = values['image_data_list'] or []
        image_name = values['image_name']

        data_list.append({
            'receipt_name': f"{values['image_number']}_new_receipt.jpg",
            'receipt_data': image_data,
        })
        self.agent.update_state(self._config(), values={'image_data_list': data_list})
        return data_list, image_name