# receipt_scanner / receipt_gen_agent.py
# wolf1997's picture
# Update receipt_gen_agent.py
# 4212010 verified
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from dotenv import load_dotenv
import os
from typing import List
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.output_parsers import RetryOutputParser
from langgraph.graph import StateGraph, START, END
import base64
from IPython.display import Image as img, display
from langchain_core.runnables.graph import MermaidDrawMethod
from langgraph.checkpoint.memory import MemorySaver
import json
from pydantic import BaseModel, Field
from io import BytesIO
# Load variables from a local .env file into the environment before they are read.
load_dotenv()

# Model settings; the API key is looked up under the lowercase name 'google_api_key'.
GEMINI_MODEL = 'gemini-2.0-flash'
GEMINI_API_KEY = os.environ.get('google_api_key')

# Chat model shared by the LLM-backed graph nodes defined in this module.
llm = ChatGoogleGenerativeAI(
    model=GEMINI_MODEL,
    google_api_key=GEMINI_API_KEY,
    temperature=0.3,
)
from os import listdir
from os.path import isfile, join
class State(TypedDict):
    """Shared state passed between the nodes of the receipt-extraction graph."""

    # Instruction text sent to the model together with the receipt image.
    prompt: str
    # Index used to build the receipt file name (callers pass the stored-list length).
    image_number: int
    # Parsed receipt fields produced by the JSON output parser.
    # Was annotated with the ``json`` module, which is not a type: static
    # checkers reject it and some Python versions refuse it when the
    # TypedDict is built.
    image_data: dict
    # Base64-encoded JPEG bytes of the receipt image.
    image_byte: str
    # Evaluator verdict; read through its 'decision' and 'comment' keys.
    eval: dict
    # Number of correction passes performed so far.
    n_retries: int
    # File name assigned to the receipt.
    image_name: str
    # Previously stored receipt records.
    image_data_list: list
def generate_data_node(state: State):
    """Extract structured receipt data from the image held in ``state``.

    Sends the prompt text plus the base64 JPEG to the chat model, asking for
    JSON that follows the ``Form`` schema, and parses the reply.  When the
    reply cannot be parsed, a ``RetryOutputParser`` asks the model to produce
    a parseable answer.

    Args:
        state: Graph state; reads ``prompt`` and ``image_byte``.

    Returns:
        Partial state update ``{'image_data': <parsed JSON>}``.
    """

    class Items(BaseModel):
        name: str = Field(description='the name of the item')
        price: float = Field(description='the price of the item')
        quantity: int = Field(description='the quantity of the item')

    class Form(BaseModel):
        loc_name: str = Field(description='the name of the location if no name put empty str')
        table: str = Field(description='table name or number, if no table put an empty str')
        address: str = Field(description='the address of the location if no location put empty str')
        date: str = Field(description='the date if no date put empty str')
        time: str = Field(description='the time if no time put empty str')
        items: List[Items] = Field(description='list of the items if no items put empty list')
        subtotal: float = Field(description='the subtotal if no subtotal put 0')
        tax: float = Field(description='the tax, if no tax put 0')
        total: float = Field(description='the total amount if no total amount put 0')

    parser = JsonOutputParser(pydantic_object=Form)
    instruction = parser.get_format_instructions()
    user_prompt = f"{state.get('prompt')}"

    message = HumanMessage(
        content=[
            {"type": "text", "text": user_prompt + '\n\n' + instruction},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{state.get('image_byte')}"},
            },
        ],
    )
    response = llm.invoke([message])

    try:
        parsed = parser.parse(response.content)
    except Exception:
        # Previously a bare ``except:``, which also swallowed KeyboardInterrupt
        # and SystemExit.  Any ordinary parsing failure still triggers a retry.
        # NOTE: the retry prompt carries only the text, not the image.
        retry_prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": instruction},
        )
        retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
        parsed = retry_parser.parse_with_prompt(
            response.content,
            retry_prompt.format_prompt(query=user_prompt),
        )
    return {'image_data': parsed}
def evaluate_node(state: State):
    """Ask the model to judge the extracted receipt data.

    The current ``image_data`` is embedded in a text query; the model answers
    with JSON following the ``Decision`` schema.  When the reply cannot be
    parsed, a ``RetryOutputParser`` asks the model to produce a parseable
    answer.

    Args:
        state: Graph state; reads ``image_data``.

    Returns:
        Partial state update ``{'eval': <parsed JSON>}``.
    """

    class Decision(BaseModel):
        decision: str = Field(description='good or modify if changes have to be made')
        comment: str = Field(description='the changes to make')

    parser = JsonOutputParser(pydantic_object=Decision)
    prompt = PromptTemplate(
        template="Answer the user query.\n{format_instructions}\n{query}\n",
        input_variables=["query"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    data = state.get('image_data')
    query = f" is the {data} correct and makes sense tell the llm what to change, ignore missing data, don't make it up, no explanation or decription needed"

    response = (prompt | llm).invoke({'query': query})
    try:
        verdict = parser.parse(response.content)
    except Exception:
        # Previously a bare ``except:``, which also swallowed KeyboardInterrupt
        # and SystemExit.  Any ordinary parsing failure still triggers a retry.
        retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
        verdict = retry_parser.parse_with_prompt(
            response.content,
            prompt.format_prompt(query=query),
        )
    return {'eval': verdict}
def data_editor_node(state: State):
    """Ask the model to rewrite the receipt data following the evaluator's comment.

    The current ``image_data`` and ``eval['comment']`` are embedded in a text
    query; the model answers with JSON following the ``Form`` schema.  When the
    reply cannot be parsed, a ``RetryOutputParser`` asks the model to produce a
    parseable answer.

    Args:
        state: Graph state; reads ``image_data``, ``eval`` and ``n_retries``.

    Returns:
        Partial state update with the revised ``image_data`` and ``n_retries``
        incremented by one.
    """

    class Items(BaseModel):
        name: str = Field(description='the name of the item')
        price: float = Field(description='the price of the item')
        quantity: int = Field(description='the quantity of the item')

    class Form(BaseModel):
        loc_name: str = Field(description='the name of the location if no name put empty str')
        table: str = Field(description='table name or number, if no table put an empty str')
        address: str = Field(description='the address of the location if no location put empty str')
        date: str = Field(description='the date if no date put empty str')
        time: str = Field(description='the time if no time put empty str')
        items: List[Items] = Field(description='list of the items if no items put empty list')
        subtotal: float = Field(description='the subtotal if no subtotal put 0')
        tax: float = Field(description='the tax, if no tax put 0')
        total: float = Field(description='the total amount if no total amount put 0')

    parser = JsonOutputParser(pydantic_object=Form)
    prompt = PromptTemplate(
        template="Answer the user query.\n{format_instructions}\n{query}\n",
        input_variables=["query"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )

    data = state.get('image_data')
    comment = state.get('eval').get('comment')
    query = f"modify this dict: {data} based on these comments {comment}, return a json"

    response = (prompt | llm).invoke({'query': query})
    try:
        revised = parser.parse(response.content)
    except Exception:
        # Previously a bare ``except:``, which also swallowed KeyboardInterrupt
        # and SystemExit.  Any ordinary parsing failure still triggers a retry.
        retry_parser = RetryOutputParser.from_llm(parser=parser, llm=llm)
        revised = retry_parser.parse_with_prompt(
            response.content,
            prompt.format_prompt(query=query),
        )

    return {
        'image_data': revised,
        # A missing counter is treated as zero instead of raising TypeError.
        'n_retries': (state.get('n_retries') or 0) + 1,
    }
def should_continue(state: State) -> str:
    """Pick the edge to follow after the evaluation step.

    Args:
        state: Graph state; reads ``eval['decision']`` and ``n_retries``.

    Returns:
        ``'to_data_editor'`` when the evaluator asked for a modification and
        fewer than two correction passes have run; ``'to_add_data'`` in every
        other case (verdict is good, unrecognised or missing, or the retry
        budget is spent).
    """
    max_passes = 2

    verdict = (state.get('eval') or {}).get('decision')
    # The verdict is model output, so tolerate case differences and stray
    # whitespace ("Modify", " modify ") instead of silently accepting the data.
    decision = str(verdict).strip().lower()
    # A missing counter means no correction pass has happened yet.
    passes_done = state.get('n_retries') or 0

    if decision == 'modify' and passes_done < max_passes:
        return 'to_data_editor'
    return 'to_add_data'
def add_data_node(state: State):
    """Close a receipt run: reset the retry counter and name the image.

    Args:
        state: Graph state; only ``image_number`` is read.

    Returns:
        Partial state update with ``n_retries`` set to 0 and ``image_name``
        set to ``'<image_number>_new_receipt.jpg'``.
    """
    receipt_index = state.get('image_number')
    update = dict(n_retries=0)
    update['image_name'] = f'{receipt_index}_new_receipt.jpg'
    return update
class receipt_agent:
    """LangGraph workflow that extracts, evaluates and corrects receipt data.

    The compiled graph runs ``generate_data -> evaluate`` and then either loops
    through ``data_editor`` back to ``evaluate`` or finishes via ``add_data``.
    Every method addresses the same in-memory checkpoint thread.
    """

    # Checkpoint thread shared by all calls on this agent.
    _THREAD_ID = "1"
    # Folder scanned for the JSON file holding previously stored receipts.
    _DATA_DIR = 'new_receipt_data'

    def __init__(self):
        self.agent = self._setup()

    def _config(self):
        """Return a fresh run config addressing the shared checkpoint thread."""
        return {"configurable": {"thread_id": self._THREAD_ID}}

    def _setup(self):
        """Build and compile the state graph with an in-memory checkpointer."""
        agent_builder = StateGraph(State)
        agent_builder.add_node('generate_data', generate_data_node)
        agent_builder.add_node('evaluate', evaluate_node)
        agent_builder.add_node('add_data', add_data_node)
        agent_builder.add_node('data_editor', data_editor_node)

        agent_builder.add_edge(START, 'generate_data')
        agent_builder.add_edge('generate_data', 'evaluate')
        agent_builder.add_conditional_edges(
            'evaluate',
            should_continue,
            {'to_data_editor': 'data_editor', 'to_add_data': 'add_data'},
        )
        agent_builder.add_edge('data_editor', 'evaluate')
        agent_builder.add_edge('add_data', END)

        return agent_builder.compile(checkpointer=MemorySaver())

    def display_graph(self):
        """Render the graph as a Mermaid PNG (drawn through the Mermaid API) and display it."""
        png = self.agent.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
        return display(img(png))

    def get_state(self, state_val: str):
        """Return one value from the checkpointed state.

        Raises:
            KeyError: if ``state_val`` is not present in the state values.
        """
        return self.agent.get_state(self._config()).values[state_val]

    def _load_stored_receipts(self):
        """Return the JSON content of the first file in the data folder, or ``[]``."""
        try:
            names = [
                entry for entry in listdir(self._DATA_DIR)
                if isfile(join(self._DATA_DIR, entry))
            ]
        except FileNotFoundError:
            # No data folder yet simply means nothing has been stored so far.
            return []
        if not names:
            return []
        with open(join(self._DATA_DIR, names[0]), 'r') as openfile:
            return json.load(openfile)

    def receipt_gen(self, image):
        """Run the graph on ``image`` and return the extracted receipt data.

        Args:
            image: Object with a PIL-style ``save(fp, format=...)`` method
                (presumably a ``PIL.Image.Image`` - verify against callers).

        Returns:
            The ``image_data`` value of the final graph state.
        """
        buffered = BytesIO()
        try:
            image.save(buffered, format='JPEG')
        except OSError:
            # Pillow raises OSError for modes JPEG cannot store (e.g. RGBA, P).
            # Retry on an RGB copy with a clean buffer instead of failing.
            buffered = BytesIO()
            image.convert('RGB').save(buffered, format='JPEG')
        encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")

        stored = self._load_stored_receipts()
        response = self.agent.invoke(
            {
                'prompt': 'analyse this receipt and list the items, return a json',
                'n_retries': 0,
                'image_number': len(stored),
                'image_byte': encoded_image,
                'image_data_list': stored,
            },
            self._config(),
        )
        return response.get('image_data')

    def update_state(self, values: dict):
        """Write ``values`` into the checkpointed state of the shared thread."""
        return self.agent.update_state(self._config(), values=values)

    def confirm(self, image_data):
        """Append a confirmed receipt to the list kept in the graph state.

        Args:
            image_data: Receipt data approved by the caller.  When it is empty
                or ``None`` nothing is appended and the state is not updated.

        Returns:
            Tuple ``(data_list, image_name)``: the stored receipt records
            (including the new one, if any) and the receipt's file name.
            The tuple is returned in every case, so callers can always unpack
            it; previously an empty ``image_data`` did not yield the tuple.
        """
        # Read the snapshot once rather than once per field.
        values = self.agent.get_state(self._config()).values
        # Copy so the snapshot's own list is not mutated in place.
        data_list = list(values.get('image_data_list') or [])
        image_name = values.get('image_name')

        if image_data:
            img_number = values['image_number']
            data_list.append({
                'receipt_name': f'{img_number}_new_receipt.jpg',
                'receipt_data': image_data,
            })
            self.agent.update_state(
                self._config(),
                values={'image_data_list': data_list},
            )
        return data_list, image_name