fix: Implement self-correction for invalid LLM responses

- Recover from invalid LLM responses by appending the parse error to the prompt and letting the LLM fix its mistake(s).
- Update the `OpenAIProvider` to handle the self-correction process and limit the number of attempts to fix parsing errors.
- Update the `BaseAgent` to use the new parsing and parse-fixing mechanism.

This change lets the system detect and recover from errors when parsing LLM responses.

Hopefully this fixes #1407 once and for all.
pull/6563/head
Reinier van der Leer 2023-12-13 22:41:55 +01:00
parent 6b0d0d4dc8
commit acf4df9f87
5 changed files with 67 additions and 46 deletions
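
At a glance, the mechanism added here is a parse-retry loop around the chat completion call. A minimal sketch, distilled from the OpenAIProvider diff below (call_llm and the function name are placeholders, not part of this commit):

async def chat_with_self_correction(prompt, completion_parser, max_tries=3):
    attempts = 0
    while True:
        response_message = await call_llm(prompt)  # placeholder for the actual API call
        try:
            attempts += 1
            # completion_parser raises if the response is malformed
            return completion_parser(response_message)
        except Exception as e:
            if attempts >= max_tries:
                raise
            # Show the model its own error so it can correct itself on the next try
            prompt.append(
                {"role": "system", "content": f"ERROR PARSING YOUR RESPONSE:\n\n{e}"}
            )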


@@ -15,9 +15,9 @@ from pydantic import Field
 from autogpt.core.configuration import Configurable
 from autogpt.core.prompting import ChatPrompt
 from autogpt.core.resource.model_providers import (
+    AssistantChatMessageDict,
     ChatMessage,
     ChatModelProvider,
-    ChatModelResponse,
 )
 from autogpt.llm.api_manager import ApiManager
 from autogpt.logs.log_cycle import (
@@ -44,7 +44,12 @@ from .prompt_strategies.one_shot import (
     OneShotAgentPromptConfiguration,
     OneShotAgentPromptStrategy,
 )
-from .utils.exceptions import AgentException, CommandExecutionError, UnknownCommandError
+from .utils.exceptions import (
+    AgentException,
+    AgentTerminated,
+    CommandExecutionError,
+    UnknownCommandError,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -76,6 +81,8 @@ class Agent(
         description=__doc__,
     )
 
+    prompt_strategy: OneShotAgentPromptStrategy
+
     def __init__(
         self,
         settings: AgentSettings,
@@ -164,20 +171,20 @@ class Agent(
         return prompt
 
     def parse_and_process_response(
-        self, llm_response: ChatModelResponse, *args, **kwargs
+        self, llm_response: AssistantChatMessageDict, *args, **kwargs
     ) -> Agent.ThoughtProcessOutput:
         for plugin in self.config.plugins:
             if not plugin.can_handle_post_planning():
                 continue
-            llm_response.response["content"] = plugin.post_planning(
-                llm_response.response.get("content", "")
+            llm_response["content"] = plugin.post_planning(
+                llm_response.get("content", "")
             )
 
         (
             command_name,
             arguments,
             assistant_reply_dict,
-        ) = self.prompt_strategy.parse_response_content(llm_response.response)
+        ) = self.prompt_strategy.parse_response_content(llm_response)
 
         self.log_cycle_handler.log_cycle(
             self.ai_profile.ai_name,

@@ -12,6 +12,7 @@ if TYPE_CHECKING:
     from autogpt.config import Config
     from autogpt.core.prompting.base import PromptStrategy
     from autogpt.core.resource.model_providers.schema import (
+        AssistantChatMessageDict,
         ChatModelInfo,
         ChatModelProvider,
         ChatModelResponse,
@@ -247,7 +248,7 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
         prompt = self.on_before_think(prompt, scratchpad=self._prompt_scratchpad)
 
         logger.debug(f"Executing prompt:\n{dump_prompt(prompt)}")
-        raw_response = await self.llm_provider.create_chat_completion(
+        response = await self.llm_provider.create_chat_completion(
             prompt.messages,
             functions=get_openai_command_specs(
                 self.command_registry.list_available_commands(self)
@@ -256,11 +257,16 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
             if self.config.use_functions_api
             else [],
             model_name=self.llm.name,
+            completion_parser=lambda r: self.parse_and_process_response(
+                r,
+                prompt,
+                scratchpad=self._prompt_scratchpad,
+            ),
         )
         self.config.cycle_count += 1
 
         return self.on_response(
-            llm_response=raw_response,
+            llm_response=response,
             prompt=prompt,
             scratchpad=self._prompt_scratchpad,
         )
@@ -397,18 +403,14 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
             The parsed command name and command args, if any, and the agent thoughts.
         """
 
-        return self.parse_and_process_response(
-            llm_response,
-            prompt,
-            scratchpad=scratchpad,
-        )
+        return llm_response.parsed_result
 
         # TODO: update memory/context
 
     @abstractmethod
     def parse_and_process_response(
         self,
-        llm_response: ChatModelResponse,
+        llm_response: AssistantChatMessageDict,
         prompt: ChatPrompt,
         scratchpad: PromptScratchpad,
     ) -> ThoughtProcessOutput:
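
The net effect: parsing moves into the provider via the `completion_parser` callback, and `on_response` simply returns the already-parsed result. A simplified sketch of the new contract (types condensed; not the exact class definitions):

from typing import Callable, Generic, TypeVar

T = TypeVar("T")

class ChatModelResponse(Generic[T]):
    response: dict       # the raw assistant message (AssistantChatMessageDict)
    parsed_result: T     # produced (and validated) by completion_parser

# create_chat_completion now accepts the parser and retries internally:
#     completion_parser: Callable[[dict], T]
# so by the time the agent sees the response, parsed_result is already valid.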


@@ -165,6 +165,7 @@ OPEN_AI_MODELS = {
 class OpenAIConfiguration(ModelProviderConfiguration):
+    fix_failed_parse_tries: int = UserConfigurable(3)
     pass
@@ -363,24 +364,45 @@ class OpenAIProvider(
         model_prompt += completion_kwargs["messages"]
         del completion_kwargs["messages"]
 
-        response = await self._create_chat_completion(
-            messages=model_prompt,
-            **completion_kwargs,
-        )
-        response_args = {
-            "model_info": OPEN_AI_CHAT_MODELS[model_name],
-            "prompt_tokens_used": response.usage.prompt_tokens,
-            "completion_tokens_used": response.usage.completion_tokens,
-        }
-
-        response_message = response.choices[0].message.to_dict_recursive()
-        if tool_calls_compat_mode:
-            response_message["tool_calls"] = _tool_calls_compat_extract_calls(
-                response_message["content"]
-            )
+        attempts = 0
+        while True:
+            response = await self._create_chat_completion(
+                messages=model_prompt,
+                **completion_kwargs,
+            )
+            response_args = {
+                "model_info": OPEN_AI_CHAT_MODELS[model_name],
+                "prompt_tokens_used": response.usage.prompt_tokens,
+                "completion_tokens_used": response.usage.completion_tokens,
+            }
+
+            response_message = response.choices[0].message.to_dict_recursive()
+            if tool_calls_compat_mode:
+                response_message["tool_calls"] = _tool_calls_compat_extract_calls(
+                    response_message["content"]
+                )
+
+            # If parsing the response fails, append the error to the prompt, and let the
+            # LLM fix its mistake(s).
+            try:
+                attempts += 1
+                parsed_response = completion_parser(response_message)
+                break
+            except Exception as e:
+                self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
+                self._logger.debug(
+                    f"Parsing failed on response: '''{response_message}'''"
+                )
+                if attempts < self._configuration.fix_failed_parse_tries:
+                    model_prompt.append(
+                        ChatMessage.system(f"ERROR PARSING YOUR RESPONSE:\n\n{e}")
+                    )
+                else:
+                    raise
+
         response = ChatModelResponse(
             response=response_message,
-            parsed_result=completion_parser(response_message),
+            parsed_result=parsed_response,
             **response_args,
         )
         self._budget.update_usage_and_cost(response)
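
`fix_failed_parse_tries` is declared `UserConfigurable(3)`, so the retry budget should be adjustable like any other provider setting. A hypothetical override (assuming the usual settings-based construction of `OpenAIProvider`):

import logging

settings = OpenAIProvider.default_settings.copy(deep=True)
settings.configuration.fix_failed_parse_tries = 5  # allow more repair rounds
provider = OpenAIProvider(settings=settings, logger=logging.getLogger(__name__))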


@@ -103,18 +103,8 @@ class JSONSchema(BaseModel):
         validator = Draft7Validator(self.to_dict())
 
         if errors := sorted(validator.iter_errors(object), key=lambda e: e.path):
-            for error in errors:
-                logger.debug(f"JSON Validation Error: {error}")
-
-            logger.error(json.dumps(object, indent=4))
-            logger.error("The following issues were found:")
-
-            for error in errors:
-                logger.error(f"Error: {error.message}")
-
             return False, errors
 
-        logger.debug("The JSON object is valid.")
         return True, None
 
     def to_typescript_object_interface(self, interface_name: str = "") -> str:
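
Error reporting is now the caller's job: the method returns the validation errors instead of logging them. A hedged caller-side example (the enclosing method's name is not visible in this hunk; `validate_object` is assumed):

valid, errors = schema.validate_object(assistant_reply_dict)
if not valid:
    # Surface the issues to the LLM via the new self-correction loop,
    # rather than just logging them.
    raise ValueError("\n".join(e.message for e in errors))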


@@ -26,10 +26,10 @@ def extract_dict_from_response(response_content: str) -> dict[str, Any]:
 
     # Response content comes from OpenAI as a Python `str(content_dict)`.
    # `literal_eval` does the reverse of `str(dict)`.
-    try:
-        return ast.literal_eval(response_content)
-    except BaseException as e:
-        logger.info(f"Error parsing JSON response with literal_eval {e}")
-        logger.debug(f"Invalid JSON received in response:\n{response_content}")
-        # TODO: How to raise an error here without causing the program to exit?
-        return {}
+    result = ast.literal_eval(response_content)
+    if not isinstance(result, dict):
+        raise ValueError(
+            f"Response '''{response_content}''' evaluated to "
+            f"non-dict value {repr(result)}"
+        )
+    return result
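
With this change, a response that `literal_eval`s to a non-dict raises instead of silently returning `{}`, which is exactly what the provider's new retry loop needs in order to trigger a correction. For example:

extract_dict_from_response("{'command': {'name': 'finish', 'args': {}}}")
# -> {'command': {'name': 'finish', 'args': {}}}

extract_dict_from_response("['not', 'a', 'dict']")
# -> raises ValueError; caught upstream and fed back to the LLM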