refactor(agent): Refactor & improve `create_chat_completion` (#7082)

* refactor(agent/core): Rearrange and split up `OpenAIProvider.create_chat_completion`
   - Rearrange to reduce complexity, improve separation/abstraction of concerns, and allow multiple points of failure during parsing
   - Move conversion from `ChatMessage` to `openai.types.ChatCompletionMessageParam` to `_get_chat_completion_args`
   - Move token usage and cost tracking boilerplate code to `_create_chat_completion`
   - Move tool call conversion/parsing to `_parse_assistant_tool_calls` (new)

* fix(agent/core): Handle decoding of function call arguments in `create_chat_completion`
   - Amend `model_providers.schema`: change type of `arguments` from `str` to `dict[str, Any]` on `AssistantFunctionCall` and `AssistantFunctionCallDict`
   - Implement robust and transparent parsing in `OpenAIProvider._parse_assistant_tool_calls`
   - Remove now unnecessary `json_loads` calls throughout codebase

* feat(agent/utils): Improve conditions and errors in `json_loads`
   - Include all decoding errors when raising a ValueError on decode failure
   - Use errors returned by `return_errors` instead of an error buffer
   - Fix check for decode failure
pull/7093/head
Reinier van der Leer 2024-04-16 10:38:49 +02:00 committed by GitHub
parent d7f00a996f
commit 7082e63b11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 211 additions and 116 deletions

View File

@ -15,7 +15,6 @@ from autogpt.core.resource.model_providers.schema import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
@ -203,9 +202,7 @@ class AgentProfileGenerator(PromptStrategy):
f"LLM did not call {self._create_agent_function.name} function; "
"agent profile creation failed"
)
arguments: object = json_loads(
response_content.tool_calls[0].function.arguments
)
arguments: object = response_content.tool_calls[0].function.arguments
ai_profile = AIProfile(
ai_name=arguments.get("name"),
ai_role=arguments.get("description"),

View File

@ -26,7 +26,7 @@ from autogpt.core.resource.model_providers.schema import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import extract_dict_from_json, json_loads
from autogpt.core.utils.json_utils import extract_dict_from_json
from autogpt.prompts.utils import format_numbered_list, indent
@ -436,7 +436,7 @@ def extract_command(
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
assistant_reply_json["command"] = {
"name": assistant_reply.tool_calls[0].function.name,
"args": json_loads(assistant_reply.tool_calls[0].function.arguments),
"args": assistant_reply.tool_calls[0].function.arguments,
}
try:
if not isinstance(assistant_reply_json, dict):

View File

@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
@ -195,9 +194,7 @@ class InitialPlan(PromptStrategy):
f"LLM did not call {self._create_plan_function.name} function; "
"plan creation failed"
)
parsed_response: object = json_loads(
response_content.tool_calls[0].function.arguments
)
parsed_response: object = response_content.tool_calls[0].function.arguments
parsed_response["task_list"] = [
Task.parse_obj(task) for task in parsed_response["task_list"]
]

View File

@ -9,7 +9,6 @@ from autogpt.core.resource.model_providers import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
@ -141,9 +140,7 @@ class NameAndGoals(PromptStrategy):
f"LLM did not call {self._create_agent_function} function; "
"agent profile creation failed"
)
parsed_response = json_loads(
response_content.tool_calls[0].function.arguments
)
parsed_response = response_content.tool_calls[0].function.arguments
except KeyError:
logger.debug(f"Failed to parse this response content: {response_content}")
raise

View File

@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import (
CompletionModelFunction,
)
from autogpt.core.utils.json_schema import JSONSchema
from autogpt.core.utils.json_utils import json_loads
logger = logging.getLogger(__name__)
@ -188,9 +187,7 @@ class NextAbility(PromptStrategy):
raise ValueError("LLM did not call any function")
function_name = response_content.tool_calls[0].function.name
function_arguments = json_loads(
response_content.tool_calls[0].function.arguments
)
function_arguments = response_content.tool_calls[0].function.arguments
parsed_response = {
"motivation": function_arguments.pop("motivation"),
"self_criticism": function_arguments.pop("self_criticism"),

View File

@ -3,7 +3,7 @@ import logging
import math
import os
from pathlib import Path
from typing import Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
from typing import Any, Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
import sentry_sdk
import tenacity
@ -11,12 +11,17 @@ import tiktoken
import yaml
from openai._exceptions import APIStatusError, RateLimitError
from openai.types import CreateEmbeddingResponse
from openai.types.chat import ChatCompletion
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessage,
ChatCompletionMessageParam,
)
from pydantic import SecretStr
from autogpt.core.configuration import Configurable, UserConfigurable
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessage,
AssistantFunctionCall,
AssistantToolCall,
AssistantToolCallDict,
ChatMessage,
@ -406,83 +411,90 @@ class OpenAIProvider(
) -> ChatModelResponse[_T]:
"""Create a completion using the OpenAI API."""
completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
tool_calls_compat_mode = functions and "tools" not in completion_kwargs
if "messages" in completion_kwargs:
model_prompt += completion_kwargs["messages"]
del completion_kwargs["messages"]
openai_messages, completion_kwargs = self._get_chat_completion_args(
model_prompt, model_name, functions, **kwargs
)
tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs)
cost = 0.0
total_cost = 0.0
attempts = 0
while True:
_response = await self._create_chat_completion(
messages=model_prompt,
_response, _cost, t_input, t_output = await self._create_chat_completion(
messages=openai_messages,
**completion_kwargs,
)
_assistant_msg = _response.choices[0].message
assistant_msg = AssistantChatMessage(
content=_assistant_msg.content,
tool_calls=(
[AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls]
if _assistant_msg.tool_calls
else None
),
)
response = ChatModelResponse(
response=assistant_msg,
model_info=OPEN_AI_CHAT_MODELS[model_name],
prompt_tokens_used=(
_response.usage.prompt_tokens if _response.usage else 0
),
completion_tokens_used=(
_response.usage.completion_tokens if _response.usage else 0
),
)
cost += self._budget.update_usage_and_cost(response)
self._logger.debug(
f"Completion usage: {response.prompt_tokens_used} input, "
f"{response.completion_tokens_used} output - ${round(cost, 5)}"
)
total_cost += _cost
# If parsing the response fails, append the error to the prompt, and let the
# LLM fix its mistake(s).
try:
attempts += 1
attempts += 1
parse_errors: list[Exception] = []
if (
tool_calls_compat_mode
and assistant_msg.content
and not assistant_msg.tool_calls
):
assistant_msg.tool_calls = list(
_tool_calls_compat_extract_calls(assistant_msg.content)
_assistant_msg = _response.choices[0].message
tool_calls, _errors = self._parse_assistant_tool_calls(
_assistant_msg, tool_calls_compat_mode
)
parse_errors += _errors
assistant_msg = AssistantChatMessage(
content=_assistant_msg.content,
tool_calls=tool_calls or None,
)
parsed_result: _T = None # type: ignore
if not parse_errors:
try:
parsed_result = completion_parser(assistant_msg)
except Exception as e:
parse_errors.append(e)
if not parse_errors:
if attempts > 1:
self._logger.debug(
f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
)
response.parsed_result = completion_parser(assistant_msg)
break
except Exception as e:
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''")
sentry_sdk.capture_exception(
error=e,
extras={"assistant_msg": assistant_msg, "i_attempt": attempts},
return ChatModelResponse(
response=AssistantChatMessage(
content=_assistant_msg.content,
tool_calls=tool_calls or None,
),
parsed_result=parsed_result,
model_info=OPEN_AI_CHAT_MODELS[model_name],
prompt_tokens_used=t_input,
completion_tokens_used=t_output,
)
if attempts < self._configuration.fix_failed_parse_tries:
model_prompt.append(assistant_msg)
model_prompt.append(
ChatMessage.system(
"ERROR PARSING YOUR RESPONSE:\n\n"
f"{e.__class__.__name__}: {e}"
)
else:
self._logger.debug(
f"Parsing failed on response: '''{_assistant_msg}'''"
)
self._logger.warning(
f"Parsing attempt #{attempts} failed: {parse_errors}"
)
for e in parse_errors:
sentry_sdk.capture_exception(
error=e,
extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
)
if attempts < self._configuration.fix_failed_parse_tries:
openai_messages.append(_assistant_msg.dict(exclude_none=True))
openai_messages.append(
{
"role": "system",
"content": (
"ERROR PARSING YOUR RESPONSE:\n\n"
+ "\n\n".join(
f"{e.__class__.__name__}: {e}" for e in parse_errors
)
),
}
)
continue
else:
raise
if attempts > 1:
self._logger.debug(f"Total cost for {attempts} attempts: ${round(cost, 5)}")
return response
raise parse_errors[0]
async def create_embedding(
self,
@ -504,21 +516,24 @@ class OpenAIProvider(
self._budget.update_usage_and_cost(response)
return response
def _get_completion_kwargs(
def _get_chat_completion_args(
self,
model_prompt: list[ChatMessage],
model_name: OpenAIModelName,
functions: Optional[list[CompletionModelFunction]] = None,
**kwargs,
) -> dict:
"""Get kwargs for completion API call.
) -> tuple[list[ChatCompletionMessageParam], dict[str, Any]]:
"""Prepare chat completion arguments and keyword arguments for API call.
Args:
model: The model to use.
kwargs: Keyword arguments to override the default values.
model_prompt: List of ChatMessages.
model_name: The model to use.
functions: Optional list of functions available to the LLM.
kwargs: Additional keyword arguments.
Returns:
The kwargs for the chat API call.
list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call
dict[str, Any]: Any other kwargs for the OpenAI call
"""
kwargs.update(self._credentials.get_model_access_kwargs(model_name))
@ -541,7 +556,19 @@ class OpenAIProvider(
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
kwargs["extra_headers"].update(extra_headers.copy())
return kwargs
if "messages" in kwargs:
model_prompt += kwargs["messages"]
del kwargs["messages"]
openai_messages: list[ChatCompletionMessageParam] = [
message.dict(
include={"role", "content", "tool_calls", "name"},
exclude_none=True,
)
for message in model_prompt
]
return openai_messages, kwargs
def _get_embedding_kwargs(
self,
@ -566,28 +593,106 @@ class OpenAIProvider(
return kwargs
def _create_chat_completion(
self, messages: list[ChatMessage], *_, **kwargs
) -> Coroutine[None, None, ChatCompletion]:
"""Create a chat completion using the OpenAI API with retry handling."""
async def _create_chat_completion(
self,
messages: list[ChatCompletionMessageParam],
model: OpenAIModelName,
*_,
**kwargs,
) -> tuple[ChatCompletion, float, int, int]:
"""
Create a chat completion using the OpenAI API with retry handling.
Params:
openai_messages: List of OpenAI-consumable message dict objects
model: The model to use for the completion
Returns:
ChatCompletion: The chat completion response object
float: The cost ($) of this completion
int: Number of prompt tokens used
int: Number of completion tokens used
"""
@self._retry_api_request
async def _create_chat_completion_with_retry(
messages: list[ChatMessage], *_, **kwargs
messages: list[ChatCompletionMessageParam], **kwargs
) -> ChatCompletion:
raw_messages = [
message.dict(
include={"role", "content", "tool_calls", "name"},
exclude_none=True,
)
for message in messages
]
return await self._client.chat.completions.create(
messages=raw_messages, # type: ignore
messages=messages, # type: ignore
**kwargs,
)
return _create_chat_completion_with_retry(messages, *_, **kwargs)
completion = await _create_chat_completion_with_retry(
messages, model=model, **kwargs
)
if completion.usage:
prompt_tokens_used = completion.usage.prompt_tokens
completion_tokens_used = completion.usage.completion_tokens
else:
prompt_tokens_used = completion_tokens_used = 0
cost = self._budget.update_usage_and_cost(
ChatModelResponse(
response=AssistantChatMessage(content=None),
model_info=OPEN_AI_CHAT_MODELS[model],
prompt_tokens_used=prompt_tokens_used,
completion_tokens_used=completion_tokens_used,
)
)
self._logger.debug(
f"Completion usage: {prompt_tokens_used} input, "
f"{completion_tokens_used} output - ${round(cost, 5)}"
)
return completion, cost, prompt_tokens_used, completion_tokens_used
def _parse_assistant_tool_calls(
    self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
) -> tuple[list[AssistantToolCall], list[Exception]]:
    """Convert the tool calls on an OpenAI assistant message into
    `AssistantToolCall` objects, decoding each call's JSON `arguments` string.

    Errors are collected and returned instead of raised, so the caller can
    decide whether to retry / feed them back to the LLM.

    Args:
        assistant_message: The raw assistant message from the OpenAI API.
        compat_mode: If True and the message has no native tool calls,
            attempt to extract tool calls from the message *content* instead
            (fallback for models without tool-call support).

    Returns:
        list[AssistantToolCall]: The successfully parsed tool calls.
        list[Exception]: Any errors that occurred while parsing.
    """
    tool_calls: list[AssistantToolCall] = []
    parse_errors: list[Exception] = []

    if assistant_message.tool_calls:
        for _tc in assistant_message.tool_calls:
            try:
                # Arguments arrive as a JSON string; decode to a dict.
                parsed_arguments = json_loads(_tc.function.arguments)
            except Exception as e:
                # Re-raise type `type(e)` with the function name prepended to
                # the message, preserving the original traceback and any
                # extra args, then keep going with the remaining tool calls.
                err_message = (
                    f"Decoding arguments for {_tc.function.name} failed: "
                    + str(e.args[0])
                )
                parse_errors.append(
                    type(e)(err_message, *e.args[1:]).with_traceback(
                        e.__traceback__
                    )
                )
                continue

            tool_calls.append(
                AssistantToolCall(
                    id=_tc.id,
                    type=_tc.type,
                    function=AssistantFunctionCall(
                        name=_tc.function.name,
                        arguments=parsed_arguments,
                    ),
                )
            )

        # If parsing of all tool calls succeeds in the end, we ignore any issues
        if len(tool_calls) == len(assistant_message.tool_calls):
            parse_errors = []
    elif compat_mode and assistant_message.content:
        # No native tool calls: try to extract them from the text content.
        try:
            tool_calls = list(
                _tool_calls_compat_extract_calls(assistant_message.content)
            )
        except Exception as e:
            parse_errors.append(e)

    return tool_calls, parse_errors
def _create_embedding(
self, text: str, *_, **kwargs

View File

@ -2,6 +2,7 @@ import abc
import enum
import math
from typing import (
Any,
Callable,
ClassVar,
Generic,
@ -68,12 +69,12 @@ class ChatMessageDict(TypedDict):
class AssistantFunctionCall(BaseModel):
name: str
arguments: str
arguments: dict[str, Any]
class AssistantFunctionCallDict(TypedDict):
name: str
arguments: str
arguments: dict[str, Any]
class AssistantToolCall(BaseModel):

View File

@ -1,4 +1,3 @@
import io
import logging
import re
from typing import Any
@ -32,16 +31,18 @@ def json_loads(json_str: str) -> Any:
if match:
json_str = match.group(1).strip()
error_buffer = io.StringIO()
json_result = demjson3.decode(
json_str, return_errors=True, write_errors=error_buffer
)
json_result = demjson3.decode(json_str, return_errors=True)
assert json_result is not None # by virtue of return_errors=True
if error_buffer.getvalue():
logger.debug(f"JSON parse errors:\n{error_buffer.getvalue()}")
if json_result.errors:
logger.debug(
"JSON parse errors:\n" + "\n".join(str(e) for e in json_result.errors)
)
if json_result is None:
raise ValueError(f"Failed to parse JSON string: {json_str}")
if json_result.object is demjson3.undefined:
raise ValueError(
f"Failed to parse JSON string: {json_str}", *json_result.errors
)
return json_result.object