refactor(agent): Refactor & improve `create_chat_completion` (#7082)
* refactor(agent/core): Rearrange and split up `OpenAIProvider.create_chat_completion` - Rearrange to reduce complexity, improve separation/abstraction of concerns, and allow multiple points of failure during parsing - Move conversion from `ChatMessage` to `openai.types.ChatCompletionMessageParam` to `_get_chat_completion_args` - Move token usage and cost tracking boilerplate code to `_create_chat_completion` - Move tool call conversion/parsing to `_parse_assistant_tool_calls` (new) * fix(agent/core): Handle decoding of function call arguments in `create_chat_completion` - Amend `model_providers.schema`: change type of `arguments` from `str` to `dict[str, Any]` on `AssistantFunctionCall` and `AssistantFunctionCallDict` - Implement robust and transparent parsing in `OpenAIProvider._parse_assistant_tool_calls` - Remove now unnecessary `json_loads` calls throughout codebase * feat(agent/utils): Improve conditions and errors in `json_loads` - Include all decoding errors when raising a ValueError on decode failure - Use errors returned by `return_errors` instead of an error buffer - Fix check for decode failurepull/7093/head
parent
d7f00a996f
commit
7082e63b11
|
@ -15,7 +15,6 @@ from autogpt.core.resource.model_providers.schema import (
|
|||
CompletionModelFunction,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -203,9 +202,7 @@ class AgentProfileGenerator(PromptStrategy):
|
|||
f"LLM did not call {self._create_agent_function.name} function; "
|
||||
"agent profile creation failed"
|
||||
)
|
||||
arguments: object = json_loads(
|
||||
response_content.tool_calls[0].function.arguments
|
||||
)
|
||||
arguments: object = response_content.tool_calls[0].function.arguments
|
||||
ai_profile = AIProfile(
|
||||
ai_name=arguments.get("name"),
|
||||
ai_role=arguments.get("description"),
|
||||
|
|
|
@ -26,7 +26,7 @@ from autogpt.core.resource.model_providers.schema import (
|
|||
CompletionModelFunction,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import extract_dict_from_json, json_loads
|
||||
from autogpt.core.utils.json_utils import extract_dict_from_json
|
||||
from autogpt.prompts.utils import format_numbered_list, indent
|
||||
|
||||
|
||||
|
@ -436,7 +436,7 @@ def extract_command(
|
|||
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
|
||||
assistant_reply_json["command"] = {
|
||||
"name": assistant_reply.tool_calls[0].function.name,
|
||||
"args": json_loads(assistant_reply.tool_calls[0].function.arguments),
|
||||
"args": assistant_reply.tool_calls[0].function.arguments,
|
||||
}
|
||||
try:
|
||||
if not isinstance(assistant_reply_json, dict):
|
||||
|
|
|
@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import (
|
|||
CompletionModelFunction,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -195,9 +194,7 @@ class InitialPlan(PromptStrategy):
|
|||
f"LLM did not call {self._create_plan_function.name} function; "
|
||||
"plan creation failed"
|
||||
)
|
||||
parsed_response: object = json_loads(
|
||||
response_content.tool_calls[0].function.arguments
|
||||
)
|
||||
parsed_response: object = response_content.tool_calls[0].function.arguments
|
||||
parsed_response["task_list"] = [
|
||||
Task.parse_obj(task) for task in parsed_response["task_list"]
|
||||
]
|
||||
|
|
|
@ -9,7 +9,6 @@ from autogpt.core.resource.model_providers import (
|
|||
CompletionModelFunction,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -141,9 +140,7 @@ class NameAndGoals(PromptStrategy):
|
|||
f"LLM did not call {self._create_agent_function} function; "
|
||||
"agent profile creation failed"
|
||||
)
|
||||
parsed_response = json_loads(
|
||||
response_content.tool_calls[0].function.arguments
|
||||
)
|
||||
parsed_response = response_content.tool_calls[0].function.arguments
|
||||
except KeyError:
|
||||
logger.debug(f"Failed to parse this response content: {response_content}")
|
||||
raise
|
||||
|
|
|
@ -11,7 +11,6 @@ from autogpt.core.resource.model_providers import (
|
|||
CompletionModelFunction,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -188,9 +187,7 @@ class NextAbility(PromptStrategy):
|
|||
raise ValueError("LLM did not call any function")
|
||||
|
||||
function_name = response_content.tool_calls[0].function.name
|
||||
function_arguments = json_loads(
|
||||
response_content.tool_calls[0].function.arguments
|
||||
)
|
||||
function_arguments = response_content.tool_calls[0].function.arguments
|
||||
parsed_response = {
|
||||
"motivation": function_arguments.pop("motivation"),
|
||||
"self_criticism": function_arguments.pop("self_criticism"),
|
||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
|
||||
from typing import Any, Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
|
||||
|
||||
import sentry_sdk
|
||||
import tenacity
|
||||
|
@ -11,12 +11,17 @@ import tiktoken
|
|||
import yaml
|
||||
from openai._exceptions import APIStatusError, RateLimitError
|
||||
from openai.types import CreateEmbeddingResponse
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
|
||||
from autogpt.core.configuration import Configurable, UserConfigurable
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
AssistantToolCall,
|
||||
AssistantToolCallDict,
|
||||
ChatMessage,
|
||||
|
@ -406,83 +411,90 @@ class OpenAIProvider(
|
|||
) -> ChatModelResponse[_T]:
|
||||
"""Create a completion using the OpenAI API."""
|
||||
|
||||
completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
|
||||
tool_calls_compat_mode = functions and "tools" not in completion_kwargs
|
||||
if "messages" in completion_kwargs:
|
||||
model_prompt += completion_kwargs["messages"]
|
||||
del completion_kwargs["messages"]
|
||||
openai_messages, completion_kwargs = self._get_chat_completion_args(
|
||||
model_prompt, model_name, functions, **kwargs
|
||||
)
|
||||
tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs)
|
||||
|
||||
cost = 0.0
|
||||
total_cost = 0.0
|
||||
attempts = 0
|
||||
while True:
|
||||
_response = await self._create_chat_completion(
|
||||
messages=model_prompt,
|
||||
_response, _cost, t_input, t_output = await self._create_chat_completion(
|
||||
messages=openai_messages,
|
||||
**completion_kwargs,
|
||||
)
|
||||
|
||||
_assistant_msg = _response.choices[0].message
|
||||
assistant_msg = AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
tool_calls=(
|
||||
[AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls]
|
||||
if _assistant_msg.tool_calls
|
||||
else None
|
||||
),
|
||||
)
|
||||
response = ChatModelResponse(
|
||||
response=assistant_msg,
|
||||
model_info=OPEN_AI_CHAT_MODELS[model_name],
|
||||
prompt_tokens_used=(
|
||||
_response.usage.prompt_tokens if _response.usage else 0
|
||||
),
|
||||
completion_tokens_used=(
|
||||
_response.usage.completion_tokens if _response.usage else 0
|
||||
),
|
||||
)
|
||||
cost += self._budget.update_usage_and_cost(response)
|
||||
self._logger.debug(
|
||||
f"Completion usage: {response.prompt_tokens_used} input, "
|
||||
f"{response.completion_tokens_used} output - ${round(cost, 5)}"
|
||||
)
|
||||
total_cost += _cost
|
||||
|
||||
# If parsing the response fails, append the error to the prompt, and let the
|
||||
# LLM fix its mistake(s).
|
||||
try:
|
||||
attempts += 1
|
||||
attempts += 1
|
||||
parse_errors: list[Exception] = []
|
||||
|
||||
if (
|
||||
tool_calls_compat_mode
|
||||
and assistant_msg.content
|
||||
and not assistant_msg.tool_calls
|
||||
):
|
||||
assistant_msg.tool_calls = list(
|
||||
_tool_calls_compat_extract_calls(assistant_msg.content)
|
||||
_assistant_msg = _response.choices[0].message
|
||||
|
||||
tool_calls, _errors = self._parse_assistant_tool_calls(
|
||||
_assistant_msg, tool_calls_compat_mode
|
||||
)
|
||||
parse_errors += _errors
|
||||
|
||||
assistant_msg = AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
|
||||
parsed_result: _T = None # type: ignore
|
||||
if not parse_errors:
|
||||
try:
|
||||
parsed_result = completion_parser(assistant_msg)
|
||||
except Exception as e:
|
||||
parse_errors.append(e)
|
||||
|
||||
if not parse_errors:
|
||||
if attempts > 1:
|
||||
self._logger.debug(
|
||||
f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
|
||||
)
|
||||
|
||||
response.parsed_result = completion_parser(assistant_msg)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
|
||||
self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''")
|
||||
sentry_sdk.capture_exception(
|
||||
error=e,
|
||||
extras={"assistant_msg": assistant_msg, "i_attempt": attempts},
|
||||
return ChatModelResponse(
|
||||
response=AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
tool_calls=tool_calls or None,
|
||||
),
|
||||
parsed_result=parsed_result,
|
||||
model_info=OPEN_AI_CHAT_MODELS[model_name],
|
||||
prompt_tokens_used=t_input,
|
||||
completion_tokens_used=t_output,
|
||||
)
|
||||
if attempts < self._configuration.fix_failed_parse_tries:
|
||||
model_prompt.append(assistant_msg)
|
||||
model_prompt.append(
|
||||
ChatMessage.system(
|
||||
"ERROR PARSING YOUR RESPONSE:\n\n"
|
||||
f"{e.__class__.__name__}: {e}"
|
||||
)
|
||||
|
||||
else:
|
||||
self._logger.debug(
|
||||
f"Parsing failed on response: '''{_assistant_msg}'''"
|
||||
)
|
||||
self._logger.warning(
|
||||
f"Parsing attempt #{attempts} failed: {parse_errors}"
|
||||
)
|
||||
for e in parse_errors:
|
||||
sentry_sdk.capture_exception(
|
||||
error=e,
|
||||
extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
|
||||
)
|
||||
|
||||
if attempts < self._configuration.fix_failed_parse_tries:
|
||||
openai_messages.append(_assistant_msg.dict(exclude_none=True))
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"ERROR PARSING YOUR RESPONSE:\n\n"
|
||||
+ "\n\n".join(
|
||||
f"{e.__class__.__name__}: {e}" for e in parse_errors
|
||||
)
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
if attempts > 1:
|
||||
self._logger.debug(f"Total cost for {attempts} attempts: ${round(cost, 5)}")
|
||||
|
||||
return response
|
||||
raise parse_errors[0]
|
||||
|
||||
async def create_embedding(
|
||||
self,
|
||||
|
@ -504,21 +516,24 @@ class OpenAIProvider(
|
|||
self._budget.update_usage_and_cost(response)
|
||||
return response
|
||||
|
||||
def _get_completion_kwargs(
|
||||
def _get_chat_completion_args(
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: OpenAIModelName,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Get kwargs for completion API call.
|
||||
) -> tuple[list[ChatCompletionMessageParam], dict[str, Any]]:
|
||||
"""Prepare chat completion arguments and keyword arguments for API call.
|
||||
|
||||
Args:
|
||||
model: The model to use.
|
||||
kwargs: Keyword arguments to override the default values.
|
||||
model_prompt: List of ChatMessages.
|
||||
model_name: The model to use.
|
||||
functions: Optional list of functions available to the LLM.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
The kwargs for the chat API call.
|
||||
|
||||
list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call
|
||||
dict[str, Any]: Any other kwargs for the OpenAI call
|
||||
"""
|
||||
kwargs.update(self._credentials.get_model_access_kwargs(model_name))
|
||||
|
||||
|
@ -541,7 +556,19 @@ class OpenAIProvider(
|
|||
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
|
||||
kwargs["extra_headers"].update(extra_headers.copy())
|
||||
|
||||
return kwargs
|
||||
if "messages" in kwargs:
|
||||
model_prompt += kwargs["messages"]
|
||||
del kwargs["messages"]
|
||||
|
||||
openai_messages: list[ChatCompletionMessageParam] = [
|
||||
message.dict(
|
||||
include={"role", "content", "tool_calls", "name"},
|
||||
exclude_none=True,
|
||||
)
|
||||
for message in model_prompt
|
||||
]
|
||||
|
||||
return openai_messages, kwargs
|
||||
|
||||
def _get_embedding_kwargs(
|
||||
self,
|
||||
|
@ -566,28 +593,106 @@ class OpenAIProvider(
|
|||
|
||||
return kwargs
|
||||
|
||||
def _create_chat_completion(
|
||||
self, messages: list[ChatMessage], *_, **kwargs
|
||||
) -> Coroutine[None, None, ChatCompletion]:
|
||||
"""Create a chat completion using the OpenAI API with retry handling."""
|
||||
async def _create_chat_completion(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model: OpenAIModelName,
|
||||
*_,
|
||||
**kwargs,
|
||||
) -> tuple[ChatCompletion, float, int, int]:
|
||||
"""
|
||||
Create a chat completion using the OpenAI API with retry handling.
|
||||
|
||||
Params:
|
||||
openai_messages: List of OpenAI-consumable message dict objects
|
||||
model: The model to use for the completion
|
||||
|
||||
Returns:
|
||||
ChatCompletion: The chat completion response object
|
||||
float: The cost ($) of this completion
|
||||
int: Number of prompt tokens used
|
||||
int: Number of completion tokens used
|
||||
"""
|
||||
|
||||
@self._retry_api_request
|
||||
async def _create_chat_completion_with_retry(
|
||||
messages: list[ChatMessage], *_, **kwargs
|
||||
messages: list[ChatCompletionMessageParam], **kwargs
|
||||
) -> ChatCompletion:
|
||||
raw_messages = [
|
||||
message.dict(
|
||||
include={"role", "content", "tool_calls", "name"},
|
||||
exclude_none=True,
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
return await self._client.chat.completions.create(
|
||||
messages=raw_messages, # type: ignore
|
||||
messages=messages, # type: ignore
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _create_chat_completion_with_retry(messages, *_, **kwargs)
|
||||
completion = await _create_chat_completion_with_retry(
|
||||
messages, model=model, **kwargs
|
||||
)
|
||||
|
||||
if completion.usage:
|
||||
prompt_tokens_used = completion.usage.prompt_tokens
|
||||
completion_tokens_used = completion.usage.completion_tokens
|
||||
else:
|
||||
prompt_tokens_used = completion_tokens_used = 0
|
||||
|
||||
cost = self._budget.update_usage_and_cost(
|
||||
ChatModelResponse(
|
||||
response=AssistantChatMessage(content=None),
|
||||
model_info=OPEN_AI_CHAT_MODELS[model],
|
||||
prompt_tokens_used=prompt_tokens_used,
|
||||
completion_tokens_used=completion_tokens_used,
|
||||
)
|
||||
)
|
||||
self._logger.debug(
|
||||
f"Completion usage: {prompt_tokens_used} input, "
|
||||
f"{completion_tokens_used} output - ${round(cost, 5)}"
|
||||
)
|
||||
return completion, cost, prompt_tokens_used, completion_tokens_used
|
||||
|
||||
def _parse_assistant_tool_calls(
|
||||
self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
|
||||
):
|
||||
tool_calls: list[AssistantToolCall] = []
|
||||
parse_errors: list[Exception] = []
|
||||
|
||||
if assistant_message.tool_calls:
|
||||
for _tc in assistant_message.tool_calls:
|
||||
try:
|
||||
parsed_arguments = json_loads(_tc.function.arguments)
|
||||
except Exception as e:
|
||||
err_message = (
|
||||
f"Decoding arguments for {_tc.function.name} failed: "
|
||||
+ str(e.args[0])
|
||||
)
|
||||
parse_errors.append(
|
||||
type(e)(err_message, *e.args[1:]).with_traceback(
|
||||
e.__traceback__
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
tool_calls.append(
|
||||
AssistantToolCall(
|
||||
id=_tc.id,
|
||||
type=_tc.type,
|
||||
function=AssistantFunctionCall(
|
||||
name=_tc.function.name,
|
||||
arguments=parsed_arguments,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# If parsing of all tool calls succeeds in the end, we ignore any issues
|
||||
if len(tool_calls) == len(assistant_message.tool_calls):
|
||||
parse_errors = []
|
||||
|
||||
elif compat_mode and assistant_message.content:
|
||||
try:
|
||||
tool_calls = list(
|
||||
_tool_calls_compat_extract_calls(assistant_message.content)
|
||||
)
|
||||
except Exception as e:
|
||||
parse_errors.append(e)
|
||||
|
||||
return tool_calls, parse_errors
|
||||
|
||||
def _create_embedding(
|
||||
self, text: str, *_, **kwargs
|
||||
|
|
|
@ -2,6 +2,7 @@ import abc
|
|||
import enum
|
||||
import math
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Generic,
|
||||
|
@ -68,12 +69,12 @@ class ChatMessageDict(TypedDict):
|
|||
|
||||
class AssistantFunctionCall(BaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
|
||||
class AssistantFunctionCallDict(TypedDict):
|
||||
name: str
|
||||
arguments: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
|
||||
class AssistantToolCall(BaseModel):
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import io
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
@ -32,16 +31,18 @@ def json_loads(json_str: str) -> Any:
|
|||
if match:
|
||||
json_str = match.group(1).strip()
|
||||
|
||||
error_buffer = io.StringIO()
|
||||
json_result = demjson3.decode(
|
||||
json_str, return_errors=True, write_errors=error_buffer
|
||||
)
|
||||
json_result = demjson3.decode(json_str, return_errors=True)
|
||||
assert json_result is not None # by virtue of return_errors=True
|
||||
|
||||
if error_buffer.getvalue():
|
||||
logger.debug(f"JSON parse errors:\n{error_buffer.getvalue()}")
|
||||
if json_result.errors:
|
||||
logger.debug(
|
||||
"JSON parse errors:\n" + "\n".join(str(e) for e in json_result.errors)
|
||||
)
|
||||
|
||||
if json_result is None:
|
||||
raise ValueError(f"Failed to parse JSON string: {json_str}")
|
||||
if json_result.object is demjson3.undefined:
|
||||
raise ValueError(
|
||||
f"Failed to parse JSON string: {json_str}", *json_result.errors
|
||||
)
|
||||
|
||||
return json_result.object
|
||||
|
||||
|
|
Loading…
Reference in New Issue