{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import dotenv\n", "dotenv.load_dotenv(dotenv.find_dotenv())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "from typing import Annotated, List\n", "from typing_extensions import TypedDict\n", "from langgraph.checkpoint.memory import MemorySaver\n", "from langchain_core.messages import ToolMessage\n", "from langgraph.graph import StateGraph, START, END\n", "from langgraph.graph.message import add_messages\n", "from langgraph.prebuilt import ToolNode, tools_condition\n", "\n", "\n", "# Define a State class, that each node in the graph will need\n", "class State(TypedDict):\n", " # Messages have the type \"list\". The `add_messages` function\n", " # in the annotation defines how this state key should be updated\n", " # (in this case, it appends messages to the list, rather than overwriting them)\n", " messages: Annotated[list, add_messages]\n", "\n", "# Initialize the graph as a stategraph:\n", "graph_builder = StateGraph(State)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def create_llm(use_model):\n", " # Create the language model\n", " if use_model == 'gpt-4o-mini':\n", " from langchain_openai import ChatOpenAI\n", " print(f'As llm, using OpenAI model: {use_model}')\n", " llm = ChatOpenAI(\n", " model_name=\"gpt-4o-mini\",\n", " temperature=0)\n", " elif use_model == 'zephyr-7b-alpha':\n", " from langchain_huggingface import HuggingFaceEndpoint\n", " print(f'As llm, using HF-Endpint: {use_model}')\n", " llm = HuggingFaceEndpoint(\n", " repo_id=f\"huggingfaceh4/{use_model}\",\n", " temperature=0.1,\n", " max_new_tokens=512\n", " )\n", " return llm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define tools to bind to llm\n", "def create_wiki_tool(verbose):\n", " print('Creating wiki tool')\n", " # Let's define a wikipedia-lookup tool \n", " from langchain_community.tools import WikipediaQueryRun\n", " from langchain_community.utilities import WikipediaAPIWrapper\n", "\n", " api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=5000)\n", " tool_wiki = WikipediaQueryRun(api_wrapper=api_wrapper)\n", " if verbose:\n", " test_search = \"quantum mechanics\"\n", " print(f\"Testing wiki tool, with search key: {test_search}\")\n", " response = tool_wiki.run({\"query\": test_search})\n", " print(f\"Response: {response}\")\n", " return tool_wiki\n", "\n", "def get_tools(verbose):\n", " print('Gathering tools')\n", " tool_wiki = create_wiki_tool(verbose=verbose) \n", " tools = [tool_wiki]\n", " return tools" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def chatbot(state: State):\n", " tools = get_tools(verbose=False)\n", " llm = create_llm(use_model='gpt-4o-mini')\n", " # llm = create_llm(use_model='zephyr-7b-alpha')\n", " llm_with_tools = llm.bind_tools(tools)\n", " return {\"messages\": [llm_with_tools.invoke(state[\"messages\"])]}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Simpler versions of ToolNode and tools_condition\n", "if False:\n", " class BasicToolNode:\n", " \"\"\"A node that runs the tools requested in the last AIMessage.\"\"\"\n", "\n", " def __init__(self, tools: list) -> None:\n", " self.tools_by_name = {tool.name: tool for tool in tools}\n", "\n", " def __call__(self, inputs: dict):\n", " if messages := inputs.get(\"messages\", []):\n", " message = messages[-1]\n", " else:\n", " raise ValueError(\"No message found in input\")\n", " outputs = []\n", " for tool_call in message.tool_calls:\n", " tool_result = self.tools_by_name[tool_call[\"name\"]].invoke(\n", " tool_call[\"args\"]\n", " )\n", " outputs.append(\n", " ToolMessage(\n", " content=json.dumps(tool_result),\n", " name=tool_call[\"name\"],\n", " tool_call_id=tool_call[\"id\"],\n", " )\n", " )\n", " return {\"messages\": outputs}\n", "\n", " def basic_tools_conditions(state: State):\n", " \"\"\"\n", " Use in the conditional_edge to route to the ToolNode if the last message\n", " has tool calls. Otherwise, route to the end.\n", " \"\"\"\n", " if isinstance(state, list):\n", " ai_message = state[-1]\n", " elif messages := state.get(\"messages\", []):\n", " ai_message = messages[-1]\n", " else:\n", " raise ValueError(f\"No messages found in input state to tool_edge: {state}\")\n", " print('route_tools: {ai_message}')\n", " if hasattr(ai_message, \"tool_calls\") and len(ai_message.tool_calls) > 0:\n", " routing_decision = \"tools\"\n", " else:\n", " routing_decision = END\n", " return routing_decision\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define nodes:\n", "graph_builder.add_node(\"chatbot\", chatbot)\n", "# tool_node = BasicToolNode(tools=get_tools(verbose=False))\n", "tool_node = ToolNode(tools=get_tools(verbose=False))\n", "graph_builder.add_node(\"tools\", tool_node)\n", "\n", "# Define edges:\n", "# Entry Point\n", "graph_builder.add_edge(START, \"chatbot\")\n", "# Conditional Edge between the chatbot and the tool node\n", "# The `route_tools` function returns \"tools\" if the chatbot asks to use a tool, and \"END\" if\n", "# it is fine directly responding. This conditional routing defines the main agent loop.\n", "graph_builder.add_conditional_edges(\n", " \"chatbot\",\n", " # basic_tools_condition,\n", " tools_condition\n", " # The following dictionary lets you tell the graph to interpret the condition's outputs as a specific node\n", " # It defaults to the identity function, but if you want to use a node named something else apart from \"tools\",\n", " # You can update the value of the dictionary to something else e.g., \"tools\": \"my_tools\"\n", " # {\"tools\": \"tools\", END: END},\n", ")\n", "# Edge between the tool node and the chatbot\n", "# Any time a tool is called, we return to the chatbot to decide the next step\n", "graph_builder.add_edge(\"tools\", \"chatbot\")\n", "\n", "memory = MemorySaver()\n", "graph = graph_builder.compile(checkpointer=memory)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import Image, display\n", "\n", "try:\n", " display(Image(graph.get_graph().draw_mermaid_png()))\n", "except Exception:\n", " # This requires some extra dependencies and is optional\n", " pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "config = {\"configurable\": {\"thread_id\": \"1\"}}\n", "\n", "def stream_graph_updates(user_input: str):\n", " events= graph.stream(\n", " {\"messages\": [(\"user\", user_input)]},\n", " config,\n", " stream_mode=\"values\"\n", " )\n", " for event in events:\n", " event[\"messages\"][-1].pretty_print()\n", " #for value in event.values():\n", " # print(\"Assistant:\", value[\"messages\"][-1].content)\n", "\n", "\n", "while True:\n", " try:\n", " user_input = input(\"User: \")\n", " if user_input.lower() in [\"quit\", \"exit\", \"q\"]:\n", " print(\"Goodbye!\")\n", " break\n", " snapshot = graph.get_state(config)\n", " print(f'Current state: {snapshot}')\n", " stream_graph_updates(user_input)\n", " except Exception as e:\n", " # fallback if input() is not available\n", " raise Exception(f'An error occured: {e}')" ] } ], "metadata": { "kernelspec": { "display_name": "langchain_311", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.1" } }, "nbformat": 4, "nbformat_minor": 2 }