From 7082e63b115d72440ee2dfe3f545fa3dcba490d5 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Tue, 16 Apr 2024 10:38:49 +0200 Subject: [PATCH] refactor(agent): Refactor & improve `create_chat_completion` (#7082) * refactor(agent/core): Rearrange and split up `OpenAIProvider.create_chat_completion` - Rearrange to reduce complexity, improve separation/abstraction of concerns, and allow multiple points of failure during parsing - Move conversion from `ChatMessage` to `openai.types.ChatCompletionMessageParam` to `_get_chat_completion_args` - Move token usage and cost tracking boilerplate code to `_create_chat_completion` - Move tool call conversion/parsing to `_parse_assistant_tool_calls` (new) * fix(agent/core): Handle decoding of function call arguments in `create_chat_completion` - Amend `model_providers.schema`: change type of `arguments` from `str` to `dict[str, Any]` on `AssistantFunctionCall` and `AssistantFunctionCallDict` - Implement robust and transparent parsing in `OpenAIProvider._parse_assistant_tool_calls` - Remove now unnecessary `json_loads` calls throughout codebase * feat(agent/utils): Improve conditions and errors in `json_loads` - Include all decoding errors when raising a ValueError on decode failure - Use errors returned by `return_errors` instead of an error buffer - Fix check for decode failure --- .../agent_factory/profile_generator.py | 5 +- .../agents/prompt_strategies/one_shot.py | 4 +- .../prompt_strategies/initial_plan.py | 5 +- .../prompt_strategies/name_and_goals.py | 5 +- .../prompt_strategies/next_ability.py | 5 +- .../core/resource/model_providers/openai.py | 279 ++++++++++++------ .../core/resource/model_providers/schema.py | 5 +- .../autogpt/autogpt/core/utils/json_utils.py | 19 +- 8 files changed, 211 insertions(+), 116 deletions(-) diff --git a/autogpts/autogpt/autogpt/agent_factory/profile_generator.py b/autogpts/autogpt/autogpt/agent_factory/profile_generator.py index 889b7f2d4..78afbe51a 100644 --- a/autogpts/autogpt/autogpt/agent_factory/profile_generator.py +++ b/autogpts/autogpt/autogpt/agent_factory/profile_generator.py @@ -15,7 +15,6 @@ from autogpt.core.resource.model_providers.schema import ( CompletionModelFunction, ) from autogpt.core.utils.json_schema import JSONSchema -from autogpt.core.utils.json_utils import json_loads logger = logging.getLogger(__name__) @@ -203,9 +202,7 @@ class AgentProfileGenerator(PromptStrategy): f"LLM did not call {self._create_agent_function.name} function; " "agent profile creation failed" ) - arguments: object = json_loads( - response_content.tool_calls[0].function.arguments - ) + arguments: object = response_content.tool_calls[0].function.arguments ai_profile = AIProfile( ai_name=arguments.get("name"), ai_role=arguments.get("description"), diff --git a/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py b/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py index 0234c59a5..994df6181 100644 --- a/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py +++ b/autogpts/autogpt/autogpt/agents/prompt_strategies/one_shot.py @@ -26,7 +26,7 @@ from autogpt.core.resource.model_providers.schema import ( CompletionModelFunction, ) from autogpt.core.utils.json_schema import JSONSchema -from autogpt.core.utils.json_utils import extract_dict_from_json, json_loads +from autogpt.core.utils.json_utils import extract_dict_from_json from autogpt.prompts.utils import format_numbered_list, indent @@ -436,7 +436,7 @@ def extract_command( raise InvalidAgentResponseError("No 'tool_calls' in assistant reply") assistant_reply_json["command"] = { "name": assistant_reply.tool_calls[0].function.name, - "args": json_loads(assistant_reply.tool_calls[0].function.arguments), + "args": assistant_reply.tool_calls[0].function.arguments, } try: if not isinstance(assistant_reply_json, dict): diff --git a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py index d26d86fd6..ae137a985 100644 --- a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py +++ b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/initial_plan.py @@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import ( CompletionModelFunction, ) from autogpt.core.utils.json_schema import JSONSchema -from autogpt.core.utils.json_utils import json_loads logger = logging.getLogger(__name__) @@ -195,9 +194,7 @@ class InitialPlan(PromptStrategy): f"LLM did not call {self._create_plan_function.name} function; " "plan creation failed" ) - parsed_response: object = json_loads( - response_content.tool_calls[0].function.arguments - ) + parsed_response: object = response_content.tool_calls[0].function.arguments parsed_response["task_list"] = [ Task.parse_obj(task) for task in parsed_response["task_list"] ] diff --git a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py index d030c05e1..133b4590d 100644 --- a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py +++ b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/name_and_goals.py @@ -9,7 +9,6 @@ from autogpt.core.resource.model_providers import ( CompletionModelFunction, ) from autogpt.core.utils.json_schema import JSONSchema -from autogpt.core.utils.json_utils import json_loads logger = logging.getLogger(__name__) @@ -141,9 +140,7 @@ class NameAndGoals(PromptStrategy): f"LLM did not call {self._create_agent_function} function; " "agent profile creation failed" ) - parsed_response = json_loads( - response_content.tool_calls[0].function.arguments - ) + parsed_response = response_content.tool_calls[0].function.arguments except KeyError: logger.debug(f"Failed to parse this response content: {response_content}") raise diff --git a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py index dec67c295..0d6daad2e 100644 --- a/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py +++ b/autogpts/autogpt/autogpt/core/planning/prompt_strategies/next_ability.py @@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import ( CompletionModelFunction, ) from autogpt.core.utils.json_schema import JSONSchema -from autogpt.core.utils.json_utils import json_loads logger = logging.getLogger(__name__) @@ -188,9 +187,7 @@ class NextAbility(PromptStrategy): raise ValueError("LLM did not call any function") function_name = response_content.tool_calls[0].function.name - function_arguments = json_loads( - response_content.tool_calls[0].function.arguments - ) + function_arguments = response_content.tool_calls[0].function.arguments parsed_response = { "motivation": function_arguments.pop("motivation"), "self_criticism": function_arguments.pop("self_criticism"), diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py index d68254a9c..cd01b496a 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py @@ -3,7 +3,7 @@ import logging import math import os from pathlib import Path -from typing import Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar +from typing import Any, Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar import sentry_sdk import tenacity @@ -11,12 +11,17 @@ import tiktoken import yaml from openai._exceptions import APIStatusError, RateLimitError from openai.types import CreateEmbeddingResponse -from openai.types.chat import ChatCompletion +from openai.types.chat import ( + ChatCompletion, + ChatCompletionMessage, + ChatCompletionMessageParam, +) from pydantic import SecretStr from autogpt.core.configuration import Configurable, UserConfigurable from autogpt.core.resource.model_providers.schema import ( AssistantChatMessage, + AssistantFunctionCall, AssistantToolCall, AssistantToolCallDict, ChatMessage, @@ -406,83 +411,90 @@ class OpenAIProvider( ) -> ChatModelResponse[_T]: """Create a completion using the OpenAI API.""" - completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs) - tool_calls_compat_mode = functions and "tools" not in completion_kwargs - if "messages" in completion_kwargs: - model_prompt += completion_kwargs["messages"] - del completion_kwargs["messages"] + openai_messages, completion_kwargs = self._get_chat_completion_args( + model_prompt, model_name, functions, **kwargs + ) + tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs) - cost = 0.0 + total_cost = 0.0 attempts = 0 while True: - _response = await self._create_chat_completion( - messages=model_prompt, + _response, _cost, t_input, t_output = await self._create_chat_completion( + messages=openai_messages, **completion_kwargs, ) - - _assistant_msg = _response.choices[0].message - assistant_msg = AssistantChatMessage( - content=_assistant_msg.content, - tool_calls=( - [AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls] - if _assistant_msg.tool_calls - else None - ), - ) - response = ChatModelResponse( - response=assistant_msg, - model_info=OPEN_AI_CHAT_MODELS[model_name], - prompt_tokens_used=( - _response.usage.prompt_tokens if _response.usage else 0 - ), - completion_tokens_used=( - _response.usage.completion_tokens if _response.usage else 0 - ), - ) - cost += self._budget.update_usage_and_cost(response) - self._logger.debug( - f"Completion usage: {response.prompt_tokens_used} input, " - f"{response.completion_tokens_used} output - ${round(cost, 5)}" - ) + total_cost += _cost # If parsing the response fails, append the error to the prompt, and let the # LLM fix its mistake(s). - try: - attempts += 1 + attempts += 1 + parse_errors: list[Exception] = [] - if ( - tool_calls_compat_mode - and assistant_msg.content - and not assistant_msg.tool_calls - ): - assistant_msg.tool_calls = list( - _tool_calls_compat_extract_calls(assistant_msg.content) + _assistant_msg = _response.choices[0].message + + tool_calls, _errors = self._parse_assistant_tool_calls( + _assistant_msg, tool_calls_compat_mode + ) + parse_errors += _errors + + assistant_msg = AssistantChatMessage( + content=_assistant_msg.content, + tool_calls=tool_calls or None, + ) + + parsed_result: _T = None # type: ignore + if not parse_errors: + try: + parsed_result = completion_parser(assistant_msg) + except Exception as e: + parse_errors.append(e) + + if not parse_errors: + if attempts > 1: + self._logger.debug( + f"Total cost for {attempts} attempts: ${round(total_cost, 5)}" ) - response.parsed_result = completion_parser(assistant_msg) - break - except Exception as e: - self._logger.warning(f"Parsing attempt #{attempts} failed: {e}") - self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''") - sentry_sdk.capture_exception( - error=e, - extras={"assistant_msg": assistant_msg, "i_attempt": attempts}, + return ChatModelResponse( + response=AssistantChatMessage( + content=_assistant_msg.content, + tool_calls=tool_calls or None, + ), + parsed_result=parsed_result, + model_info=OPEN_AI_CHAT_MODELS[model_name], + prompt_tokens_used=t_input, + completion_tokens_used=t_output, ) - if attempts < self._configuration.fix_failed_parse_tries: - model_prompt.append(assistant_msg) - model_prompt.append( - ChatMessage.system( - "ERROR PARSING YOUR RESPONSE:\n\n" - f"{e.__class__.__name__}: {e}" - ) + + else: + self._logger.debug( + f"Parsing failed on response: '''{_assistant_msg}'''" + ) + self._logger.warning( + f"Parsing attempt #{attempts} failed: {parse_errors}" + ) + for e in parse_errors: + sentry_sdk.capture_exception( + error=e, + extras={"assistant_msg": _assistant_msg, "i_attempt": attempts}, ) + + if attempts < self._configuration.fix_failed_parse_tries: + openai_messages.append(_assistant_msg.dict(exclude_none=True)) + openai_messages.append( + { + "role": "system", + "content": ( + "ERROR PARSING YOUR RESPONSE:\n\n" + + "\n\n".join( + f"{e.__class__.__name__}: {e}" for e in parse_errors + ) + ), + } + ) + continue else: - raise - - if attempts > 1: - self._logger.debug(f"Total cost for {attempts} attempts: ${round(cost, 5)}") - - return response + raise parse_errors[0] async def create_embedding( self, @@ -504,21 +516,24 @@ class OpenAIProvider( self._budget.update_usage_and_cost(response) return response - def _get_completion_kwargs( + def _get_chat_completion_args( self, + model_prompt: list[ChatMessage], model_name: OpenAIModelName, functions: Optional[list[CompletionModelFunction]] = None, **kwargs, - ) -> dict: - """Get kwargs for completion API call. + ) -> tuple[list[ChatCompletionMessageParam], dict[str, Any]]: + """Prepare chat completion arguments and keyword arguments for API call. Args: - model: The model to use. - kwargs: Keyword arguments to override the default values. + model_prompt: List of ChatMessages. + model_name: The model to use. + functions: Optional list of functions available to the LLM. + kwargs: Additional keyword arguments. Returns: - The kwargs for the chat API call. - + list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call + dict[str, Any]: Any other kwargs for the OpenAI call """ kwargs.update(self._credentials.get_model_access_kwargs(model_name)) @@ -541,7 +556,19 @@ class OpenAIProvider( kwargs["extra_headers"] = kwargs.get("extra_headers", {}) kwargs["extra_headers"].update(extra_headers.copy()) - return kwargs + if "messages" in kwargs: + model_prompt += kwargs["messages"] + del kwargs["messages"] + + openai_messages: list[ChatCompletionMessageParam] = [ + message.dict( + include={"role", "content", "tool_calls", "name"}, + exclude_none=True, + ) + for message in model_prompt + ] + + return openai_messages, kwargs def _get_embedding_kwargs( self, @@ -566,28 +593,106 @@ class OpenAIProvider( return kwargs - def _create_chat_completion( - self, messages: list[ChatMessage], *_, **kwargs - ) -> Coroutine[None, None, ChatCompletion]: - """Create a chat completion using the OpenAI API with retry handling.""" + async def _create_chat_completion( + self, + messages: list[ChatCompletionMessageParam], + model: OpenAIModelName, + *_, + **kwargs, + ) -> tuple[ChatCompletion, float, int, int]: + """ + Create a chat completion using the OpenAI API with retry handling. + + Params: + openai_messages: List of OpenAI-consumable message dict objects + model: The model to use for the completion + + Returns: + ChatCompletion: The chat completion response object + float: The cost ($) of this completion + int: Number of prompt tokens used + int: Number of completion tokens used + """ @self._retry_api_request async def _create_chat_completion_with_retry( - messages: list[ChatMessage], *_, **kwargs + messages: list[ChatCompletionMessageParam], **kwargs ) -> ChatCompletion: - raw_messages = [ - message.dict( - include={"role", "content", "tool_calls", "name"}, - exclude_none=True, - ) - for message in messages - ] return await self._client.chat.completions.create( - messages=raw_messages, # type: ignore + messages=messages, # type: ignore **kwargs, ) - return _create_chat_completion_with_retry(messages, *_, **kwargs) + completion = await _create_chat_completion_with_retry( + messages, model=model, **kwargs + ) + + if completion.usage: + prompt_tokens_used = completion.usage.prompt_tokens + completion_tokens_used = completion.usage.completion_tokens + else: + prompt_tokens_used = completion_tokens_used = 0 + + cost = self._budget.update_usage_and_cost( + ChatModelResponse( + response=AssistantChatMessage(content=None), + model_info=OPEN_AI_CHAT_MODELS[model], + prompt_tokens_used=prompt_tokens_used, + completion_tokens_used=completion_tokens_used, + ) + ) + self._logger.debug( + f"Completion usage: {prompt_tokens_used} input, " + f"{completion_tokens_used} output - ${round(cost, 5)}" + ) + return completion, cost, prompt_tokens_used, completion_tokens_used + + def _parse_assistant_tool_calls( + self, assistant_message: ChatCompletionMessage, compat_mode: bool = False + ): + tool_calls: list[AssistantToolCall] = [] + parse_errors: list[Exception] = [] + + if assistant_message.tool_calls: + for _tc in assistant_message.tool_calls: + try: + parsed_arguments = json_loads(_tc.function.arguments) + except Exception as e: + err_message = ( + f"Decoding arguments for {_tc.function.name} failed: " + + str(e.args[0]) + ) + parse_errors.append( + type(e)(err_message, *e.args[1:]).with_traceback( + e.__traceback__ + ) + ) + continue + + tool_calls.append( + AssistantToolCall( + id=_tc.id, + type=_tc.type, + function=AssistantFunctionCall( + name=_tc.function.name, + arguments=parsed_arguments, + ), + ) + ) + + # If parsing of all tool calls succeeds in the end, we ignore any issues + if len(tool_calls) == len(assistant_message.tool_calls): + parse_errors = [] + + elif compat_mode and assistant_message.content: + try: + tool_calls = list( + _tool_calls_compat_extract_calls(assistant_message.content) + ) + except Exception as e: + parse_errors.append(e) + + return tool_calls, parse_errors def _create_embedding( self, text: str, *_, **kwargs diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py index 43d4bd296..cc0030995 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/schema.py @@ -2,6 +2,7 @@ import abc import enum import math from typing import ( + Any, Callable, ClassVar, Generic, @@ -68,12 +69,12 @@ class ChatMessageDict(TypedDict): class AssistantFunctionCall(BaseModel): name: str - arguments: str + arguments: dict[str, Any] class AssistantFunctionCallDict(TypedDict): name: str - arguments: str + arguments: dict[str, Any] class AssistantToolCall(BaseModel): diff --git a/autogpts/autogpt/autogpt/core/utils/json_utils.py b/autogpts/autogpt/autogpt/core/utils/json_utils.py index 664cb87c1..0374a85c1 100644 --- a/autogpts/autogpt/autogpt/core/utils/json_utils.py +++ b/autogpts/autogpt/autogpt/core/utils/json_utils.py @@ -1,4 +1,3 @@ -import io import logging import re from typing import Any @@ -32,16 +31,18 @@ def json_loads(json_str: str) -> Any: if match: json_str = match.group(1).strip() - error_buffer = io.StringIO() - json_result = demjson3.decode( - json_str, return_errors=True, write_errors=error_buffer - ) + json_result = demjson3.decode(json_str, return_errors=True) + assert json_result is not None # by virtue of return_errors=True - if error_buffer.getvalue(): - logger.debug(f"JSON parse errors:\n{error_buffer.getvalue()}") + if json_result.errors: + logger.debug( + "JSON parse errors:\n" + "\n".join(str(e) for e in json_result.errors) + ) - if json_result is None: - raise ValueError(f"Failed to parse JSON string: {json_str}") + if json_result.object is demjson3.undefined: + raise ValueError( + f"Failed to parse JSON string: {json_str}", *json_result.errors + ) return json_result.object