11"""Module implements an agent that uses OpenAI's APIs function enabled API."""
22import json
3- from dataclasses import dataclass
43from json import JSONDecodeError
54from typing import Any , List , Optional , Sequence , Tuple , Union
65
76from langchain .agents import BaseSingleActionAgent
7+ from langchain .agents .agent import AgentOutputParser
8+ from langchain .agents .format_scratchpad .openai_functions import (
9+ format_to_openai_functions ,
10+ )
811from langchain .callbacks .base import BaseCallbackManager
912from langchain .callbacks .manager import Callbacks # type: ignore
1013from langchain .chat_models .openai import ChatOpenAI
1821from langchain .schema import (
1922 AgentAction ,
2023 AgentFinish ,
21- BasePromptTemplate ,
22- OutputParserException ,
23- )
24- from langchain .schema .language_model import BaseLanguageModel
25- from langchain .schema .messages import (
2624 AIMessage ,
2725 BaseMessage ,
28- FunctionMessage ,
26+ BasePromptTemplate ,
27+ OutputParserException ,
2928 SystemMessage ,
3029)
31- from langchain .tools import BaseTool
32- from langchain .tools .convert_to_openai import format_tool_to_openai_function
33-
34-
35- @dataclass
36- class _FunctionsAgentAction (AgentAction ):
37- message_log : List [BaseMessage ]
30+ from langchain .schema .agent import AgentActionMessageLog
31+ from langchain .schema .language_model import BaseLanguageModel
32+ from langchain .schema .output import ChatGeneration , Generation
33+ from langchain .tools .base import BaseTool
34+ from langchain .tools .render import format_tool_to_openai_function
3835
3936
40- def _convert_agent_action_to_messages (
41- agent_action : AgentAction , observation : str
42- ) -> List [BaseMessage ]:
43- """Convert an agent action to a message.
37+ class OpenAIFunctionsAgentOutputParser (AgentOutputParser ):
38+ """Parses a message into agent action/finish.
4439
45- This code is used to reconstruct the original AI message from the agent action.
40+ Is meant to be used with OpenAI models, as it relies on the specific
41+ function_call parameter from OpenAI to convey what tools to use.
4642
47- Args:
48- agent_action: Agent action to convert .
43+ If a function_call parameter is passed, then that is used to get
44+ the tool and tool input .
4945
50- Returns:
51- AIMessage that corresponds to the original tool invocation.
52- """
53- if isinstance (agent_action , _FunctionsAgentAction ):
54- return agent_action .message_log + [
55- _create_function_message (agent_action , observation )
56- ]
57- else :
58- return [AIMessage (content = agent_action .log )]
59-
60-
61- def _create_function_message (
62- agent_action : AgentAction , observation : str
63- ) -> FunctionMessage :
64- """Convert agent action and observation into a function message.
65- Args:
66- agent_action: the tool invocation request from the agent
67- observation: the result of the tool invocation
68- Returns:
69- FunctionMessage that corresponds to the original tool invocation
70- """
71- if not isinstance (observation , str ):
72- try :
73- content = json .dumps (observation , ensure_ascii = False )
74- except Exception :
75- content = str (observation )
76- else :
77- content = observation
78- return FunctionMessage (
79- name = agent_action .tool ,
80- content = content ,
81- )
82-
83-
84- def _format_intermediate_steps (
85- intermediate_steps : List [Tuple [AgentAction , str ]],
86- ) -> List [BaseMessage ]:
87- """Format intermediate steps.
88- Args:
89- intermediate_steps: Steps the LLM has taken to date, along with observations
90- Returns:
91- list of messages to send to the LLM for the next prediction
46+ If one is not passed, then the AIMessage is assumed to be the final output.
9247 """
93- messages = []
9448
95- for intermediate_step in intermediate_steps :
96- agent_action , observation = intermediate_step
97- messages .extend (_convert_agent_action_to_messages (agent_action , observation ))
98-
99- return messages
100-
101-
102- def _parse_ai_message (message : BaseMessage ) -> Union [AgentAction , AgentFinish ]:
103- """Parse an AI message."""
104- if not isinstance (message , AIMessage ):
105- raise TypeError (f"Expected an AI message got { type (message )} " )
106-
107- function_call = message .additional_kwargs .get ("function_call" , {})
108-
109- if function_call :
110- function_name = function_call ["name" ]
111- try :
112- _tool_input = json .loads (function_call ["arguments" ])
113- except JSONDecodeError :
114- if function_name == "python" :
115- code = function_call ["arguments" ]
116- _tool_input = {
117- "code" : code ,
118- }
49+ @property
50+ def _type (self ) -> str :
51+ return "openai-functions-agent"
52+
53+ @staticmethod
54+ def _parse_ai_message (message : BaseMessage ) -> Union [AgentAction , AgentFinish ]:
55+ """Parse an AI message."""
56+ if not isinstance (message , AIMessage ):
57+ raise TypeError (f"Expected an AI message got { type (message )} " )
58+
59+ function_call = message .additional_kwargs .get ("function_call" , {})
60+
61+ if function_call :
62+ function_name = function_call ["name" ]
63+ try :
64+ _tool_input = json .loads (function_call ["arguments" ])
65+ except JSONDecodeError :
66+ if function_name == "python" :
67+ code = function_call ["arguments" ]
68+ _tool_input = {
69+ "code" : code ,
70+ }
71+ else :
72+ raise OutputParserException (
73+ f"Could not parse tool input: { function_call } because "
74+ f"the `arguments` is not valid JSON."
75+ )
76+
77+ # HACK HACK HACK:
78+ # The code that encodes tool input into Open AI uses a special variable
79+ # name called `__arg1` to handle old style tools that do not expose a
80+ # schema and expect a single string argument as an input.
81+ # We unpack the argument here if it exists.
82+ # Open AI does not support passing in a JSON array as an argument.
83+ if "__arg1" in _tool_input :
84+ tool_input = _tool_input ["__arg1" ]
11985 else :
120- raise OutputParserException (
121- f"Could not parse tool input: { function_call } because "
122- f"the `arguments` is not valid JSON."
123- )
124-
125- # HACK HACK HACK:
126- # The code that encodes tool input into Open AI uses a special variable
127- # name called `__arg1` to handle old style tools that do not expose a
128- # schema and expect a single string argument as an input.
129- # We unpack the argument here if it exists.
130- # Open AI does not support passing in a JSON array as an argument.
131- if "__arg1" in _tool_input :
132- tool_input = _tool_input ["__arg1" ]
133- else :
134- tool_input = _tool_input
135-
136- content_msg = "responded: {content}\n " if message .content else "\n "
86+ tool_input = _tool_input
87+
88+ content_msg = f"responded: { message .content } \n " if message .content else "\n "
89+ log = f"\n Invoking: `{ function_name } ` with `{ tool_input } `\n { content_msg } \n "
90+ return AgentActionMessageLog (
91+ tool = function_name ,
92+ tool_input = tool_input ,
93+ log = log ,
94+ message_log = [message ],
95+ )
13796
138- return _FunctionsAgentAction (
139- tool = function_name ,
140- tool_input = tool_input ,
141- log = f"\n Invoking: `{ function_name } ` with `{ tool_input } `\n { content_msg } \n " ,
142- message_log = [message ],
97+ return AgentFinish (
98+ return_values = {"output" : message .content }, log = message .content
14399 )
144100
145- return AgentFinish (return_values = {"output" : message .content }, log = message .content )
101+ def parse_result (self , result : List [Generation ]) -> Union [AgentAction , AgentFinish ]:
102+ if not isinstance (result [0 ], ChatGeneration ):
103+ raise ValueError ("This output parser only works on ChatGeneration output" )
104+ message = result [0 ].message
105+ return self ._parse_ai_message (message )
106+
107+ def parse (self , text : str ) -> Union [AgentAction , AgentFinish ]:
108+ raise ValueError ("Can only parse messages" )
146109
147110
148111class OpenAIFunctionsAgent (BaseSingleActionAgent ):
@@ -206,7 +169,7 @@ def plan(
206169 Returns:
207170 Action specifying what tool to use.
208171 """
209- agent_scratchpad = _format_intermediate_steps (intermediate_steps )
172+ agent_scratchpad = format_to_openai_functions (intermediate_steps )
210173 selected_inputs = {
211174 k : kwargs [k ] for k in self .prompt .input_variables if k != "agent_scratchpad"
212175 }
@@ -224,7 +187,9 @@ def plan(
224187 messages ,
225188 callbacks = callbacks ,
226189 )
227- agent_decision = _parse_ai_message (predicted_message )
190+ agent_decision = OpenAIFunctionsAgentOutputParser ._parse_ai_message (
191+ predicted_message
192+ )
228193 return agent_decision
229194
230195 async def aplan (
@@ -243,7 +208,7 @@ async def aplan(
243208 Returns:
244209 Action specifying what tool to use.
245210 """
246- agent_scratchpad = _format_intermediate_steps (intermediate_steps )
211+ agent_scratchpad = format_to_openai_functions (intermediate_steps )
247212 selected_inputs = {
248213 k : kwargs [k ] for k in self .prompt .input_variables if k != "agent_scratchpad"
249214 }
@@ -253,7 +218,9 @@ async def aplan(
253218 predicted_message = await self .llm .apredict_messages (
254219 messages , functions = self .functions , callbacks = callbacks
255220 )
256- agent_decision = _parse_ai_message (predicted_message )
221+ agent_decision = OpenAIFunctionsAgentOutputParser ._parse_ai_message (
222+ predicted_message
223+ )
257224 return agent_decision
258225
259226 def return_stopped_response (
@@ -339,7 +306,7 @@ def from_llm_and_tools(
339306 extra_prompt_messages = extra_prompt_messages ,
340307 system_message = system_message ,
341308 )
342- return cls ( # type: ignore
309+ return cls (
343310 llm = llm ,
344311 prompt = prompt ,
345312 tools = tools ,
0 commit comments