OnurKerimoglu commited on
Commit
2f3c7e6
·
1 Parent(s): 22ca913

introduced nb/chatbot_agentic.ipynb

Browse files
Files changed (1) hide show
  1. notebooks/chatbot_agentic.ipynb +254 -0
notebooks/chatbot_agentic.ipynb ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import dotenv\n",
10
+ "dotenv.load_dotenv(dotenv.find_dotenv())"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "import json\n",
20
+ "from typing import Annotated, List\n",
21
+ "from typing_extensions import TypedDict\n",
22
+ "from langchain_core.messages import ToolMessage\n",
23
+ "from langgraph.graph import StateGraph, START, END\n",
24
+ "from langgraph.graph.message import add_messages\n",
25
+ "\n",
26
+ "\n",
27
+ "# Define a State class, that each node in the graph will need\n",
28
+ "class State(TypedDict):\n",
29
+ " # Messages have the type \"list\". The `add_messages` function\n",
30
+ " # in the annotation defines how this state key should be updated\n",
31
+ " # (in this case, it appends messages to the list, rather than overwriting them)\n",
32
+ " messages: Annotated[list, add_messages]\n",
33
+ "\n",
34
+ "# Initialize the graph as a stategraph:\n",
35
+ "graph_builder = StateGraph(State)"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "def create_llm(use_model):\n",
45
+ " # Create the language model\n",
46
+ " if use_model == 'gpt-4o-mini':\n",
47
+ " from langchain_openai import ChatOpenAI\n",
48
+ " print(f'As llm, using OpenAI model: {use_model}')\n",
49
+ " llm = ChatOpenAI(\n",
50
+ " model_name=\"gpt-4o-mini\",\n",
51
+ " temperature=0)\n",
52
+ " elif use_model == 'zephyr-7b-alpha':\n",
53
+ " from langchain_huggingface import HuggingFaceEndpoint\n",
54
+ " print(f'As llm, using HF-Endpoint: {use_model}')\n",
55
+ " llm = HuggingFaceEndpoint(\n",
56
+ " repo_id=f\"huggingfaceh4/{use_model}\",\n",
57
+ " temperature=0.1,\n",
58
+ " max_new_tokens=512\n",
59
+ " )\n",
60
+ " return llm"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "# Define tools to bind to llm\n",
70
+ "def create_wiki_tool(verbose):\n",
71
+ " print('Creating wiki tool')\n",
72
+ " # Let's define a wikipedia-lookup tool \n",
73
+ " from langchain_community.tools import WikipediaQueryRun\n",
74
+ " from langchain_community.utilities import WikipediaAPIWrapper\n",
75
+ "\n",
76
+ " api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=5000)\n",
77
+ " tool_wiki = WikipediaQueryRun(api_wrapper=api_wrapper)\n",
78
+ " if verbose:\n",
79
+ " test_search = \"quantum mechanics\"\n",
80
+ " print(f\"Testing wiki tool, with search key: {test_search}\")\n",
81
+ " response = tool_wiki.run({\"query\": test_search})\n",
82
+ " print(f\"Response: {response}\")\n",
83
+ " return tool_wiki\n",
84
+ "\n",
85
+ "def get_tools(verbose):\n",
86
+ " print('Gathering tools')\n",
87
+ " tool_wiki = create_wiki_tool(verbose=verbose) \n",
88
+ " tools = [tool_wiki]\n",
89
+ " return tools"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": null,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "def chatbot(state: State):\n",
99
+ " tools = get_tools(verbose=False)\n",
100
+ " llm = create_llm(use_model='gpt-4o-mini')\n",
101
+ " # llm = create_llm(use_model='zephyr-7b-alpha')\n",
102
+ " llm_with_tools = llm.bind_tools(tools)\n",
103
+ " return {\"messages\": [llm_with_tools.invoke(state[\"messages\"])]}"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "class BasicToolNode:\n",
113
+ " \"\"\"A node that runs the tools requested in the last AIMessage.\"\"\"\n",
114
+ "\n",
115
+ " def __init__(self, tools: list) -> None:\n",
116
+ " self.tools_by_name = {tool.name: tool for tool in tools}\n",
117
+ "\n",
118
+ " def __call__(self, inputs: dict):\n",
119
+ " if messages := inputs.get(\"messages\", []):\n",
120
+ " message = messages[-1]\n",
121
+ " else:\n",
122
+ " raise ValueError(\"No message found in input\")\n",
123
+ " outputs = []\n",
124
+ " for tool_call in message.tool_calls:\n",
125
+ " tool_result = self.tools_by_name[tool_call[\"name\"]].invoke(\n",
126
+ " tool_call[\"args\"]\n",
127
+ " )\n",
128
+ " outputs.append(\n",
129
+ " ToolMessage(\n",
130
+ " content=json.dumps(tool_result),\n",
131
+ " name=tool_call[\"name\"],\n",
132
+ " tool_call_id=tool_call[\"id\"],\n",
133
+ " )\n",
134
+ " )\n",
135
+ " return {\"messages\": outputs}\n",
136
+ "\n",
137
+ "def route_tools(state: State):\n",
138
+ " \"\"\"\n",
139
+ " Use in the conditional_edge to route to the ToolNode if the last message\n",
140
+ " has tool calls. Otherwise, route to the end.\n",
141
+ " \"\"\"\n",
142
+ " if isinstance(state, list):\n",
143
+ " ai_message = state[-1]\n",
144
+ " elif messages := state.get(\"messages\", []):\n",
145
+ " ai_message = messages[-1]\n",
146
+ " else:\n",
147
+ " raise ValueError(f\"No messages found in input state to tool_edge: {state}\")\n",
148
+ " if hasattr(ai_message, \"tool_calls\") and len(ai_message.tool_calls) > 0:\n",
149
+ " routing_decision = \"tools\"\n",
150
+ " else:\n",
151
+ " routing_decision = END\n",
152
+ " return routing_decision\n"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": null,
158
+ "metadata": {},
159
+ "outputs": [],
160
+ "source": [
161
+ "# Define nodes:\n",
162
+ "graph_builder.add_node(\"chatbot\", chatbot)\n",
163
+ "tool_node = BasicToolNode(tools=get_tools(verbose=False))\n",
164
+ "graph_builder.add_node(\"tools\", tool_node)\n",
165
+ "\n",
166
+ "# Define edges:\n",
167
+ "# Entry Point\n",
168
+ "graph_builder.add_edge(START, \"chatbot\")\n",
169
+ "# Conditional Edge between the chatbot and the tool node\n",
170
+ "# The `route_tools` function returns \"tools\" if the chatbot asks to use a tool, and \"END\" if\n",
171
+ "# it is fine directly responding. This conditional routing defines the main agent loop.\n",
172
+ "graph_builder.add_conditional_edges(\n",
173
+ " \"chatbot\",\n",
174
+ " route_tools,\n",
175
+ " # The following dictionary lets you tell the graph to interpret the condition's outputs as a specific node\n",
176
+ " # It defaults to the identity function, but if you\n",
177
+ " # want to use a node named something else apart from \"tools\",\n",
178
+ " # You can update the value of the dictionary to something else\n",
179
+ " # e.g., \"tools\": \"my_tools\"\n",
180
+ " {\"tools\": \"tools\", END: END},\n",
181
+ ")\n",
182
+ "# Edge between the tool node and the chatbot\n",
183
+ "# Any time a tool is called, we return to the chatbot to decide the next step\n",
184
+ "graph_builder.add_edge(\"tools\", \"chatbot\")\n",
185
+ "\n",
186
+ "graph = graph_builder.compile()"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "from IPython.display import Image, display\n",
196
+ "\n",
197
+ "try:\n",
198
+ " display(Image(graph.get_graph().draw_mermaid_png()))\n",
199
+ "except Exception:\n",
200
+ " # This requires some extra dependencies and is optional\n",
201
+ " pass"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "def stream_graph_updates(user_input: str):\n",
211
+ " for event in graph.stream({\"messages\": [(\"user\", user_input)]}):\n",
212
+ " for value in event.values():\n",
213
+ " print(\"Assistant:\", value[\"messages\"][-1].content)\n",
214
+ "\n",
215
+ "\n",
216
+ "while True:\n",
217
+ " try:\n",
218
+ " user_input = input(\"User: \")\n",
219
+ " if user_input.lower() in [\"quit\", \"exit\", \"q\"]:\n",
220
+ " print(\"Goodbye!\")\n",
221
+ " break\n",
222
+ "\n",
223
+ " stream_graph_updates(user_input)\n",
224
+ " except:\n",
225
+ " # fallback if input() is not available\n",
226
+ " user_input = \"What do you know about LangGraph?\"\n",
227
+ " print(\"User: \" + user_input)\n",
228
+ " stream_graph_updates(user_input)\n",
229
+ " break"
230
+ ]
231
+ }
232
+ ],
233
+ "metadata": {
234
+ "kernelspec": {
235
+ "display_name": "langchain_311",
236
+ "language": "python",
237
+ "name": "python3"
238
+ },
239
+ "language_info": {
240
+ "codemirror_mode": {
241
+ "name": "ipython",
242
+ "version": 3
243
+ },
244
+ "file_extension": ".py",
245
+ "mimetype": "text/x-python",
246
+ "name": "python",
247
+ "nbconvert_exporter": "python",
248
+ "pygments_lexer": "ipython3",
249
+ "version": "3.11.1"
250
+ }
251
+ },
252
+ "nbformat": 4,
253
+ "nbformat_minor": 2
254
+ }