fix: Implement self-correction for invalid LLM responses
- Handle invalid LLM responses by appending the parse error to the prompt and letting the LLM fix its mistakes.
- Update the `OpenAIProvider` to run the self-correction process and limit the number of attempts to fix parsing errors.
- Update the `BaseAgent` to benefit from the new parsing and parse-fixing mechanism.

This change ensures that the system can handle and recover from errors in parsing LLM responses. Hopefully this fixes #1407 once and for all.
parent 6b0d0d4dc8
commit acf4df9f87
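In rough outline, the change wraps the model call and the `completion_parser` in a bounded retry loop: when parsing raises, the error text is appended to the prompt as a system message ("ERROR PARSING YOUR RESPONSE: ...") and the request is re-sent, up to `fix_failed_parse_tries` attempts. Below is a minimal, self-contained sketch of that pattern, independent of the AutoGPT classes; the `ask_llm` and `parse` callables are illustrative placeholders, not real AutoGPT APIs.

from typing import Callable, TypeVar

T = TypeVar("T")


def complete_with_self_correction(
    messages: list[dict],                  # chat history sent to the model
    ask_llm: Callable[[list[dict]], str],  # placeholder: returns the raw assistant reply
    parse: Callable[[str], T],             # raises if the reply is malformed
    max_tries: int = 3,                    # mirrors fix_failed_parse_tries
) -> T:
    for attempt in range(1, max_tries + 1):
        reply = ask_llm(messages)
        try:
            return parse(reply)
        except Exception as e:
            if attempt == max_tries:
                raise
            # Feed the parse error back so the model can correct itself on the next try.
            messages.append(
                {"role": "system", "content": f"ERROR PARSING YOUR RESPONSE:\n\n{e}"}
            )

The actual `OpenAIProvider` follows the same idea, but builds the final `ChatModelResponse` from the last successfully parsed result, as the diff below shows.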
@@ -15,9 +15,9 @@ from pydantic import Field
 from autogpt.core.configuration import Configurable
 from autogpt.core.prompting import ChatPrompt
 from autogpt.core.resource.model_providers import (
+    AssistantChatMessageDict,
     ChatMessage,
     ChatModelProvider,
-    ChatModelResponse,
 )
 from autogpt.llm.api_manager import ApiManager
 from autogpt.logs.log_cycle import (
@@ -44,7 +44,12 @@ from .prompt_strategies.one_shot import (
     OneShotAgentPromptConfiguration,
     OneShotAgentPromptStrategy,
 )
-from .utils.exceptions import AgentException, CommandExecutionError, UnknownCommandError
+from .utils.exceptions import (
+    AgentException,
+    AgentTerminated,
+    CommandExecutionError,
+    UnknownCommandError,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -76,6 +81,8 @@ class Agent(
         description=__doc__,
     )
 
+    prompt_strategy: OneShotAgentPromptStrategy
+
     def __init__(
         self,
         settings: AgentSettings,
@@ -164,20 +171,20 @@ class Agent(
         return prompt
 
     def parse_and_process_response(
-        self, llm_response: ChatModelResponse, *args, **kwargs
+        self, llm_response: AssistantChatMessageDict, *args, **kwargs
     ) -> Agent.ThoughtProcessOutput:
         for plugin in self.config.plugins:
             if not plugin.can_handle_post_planning():
                 continue
-            llm_response.response["content"] = plugin.post_planning(
-                llm_response.response.get("content", "")
+            llm_response["content"] = plugin.post_planning(
+                llm_response.get("content", "")
             )
 
         (
             command_name,
             arguments,
             assistant_reply_dict,
-        ) = self.prompt_strategy.parse_response_content(llm_response.response)
+        ) = self.prompt_strategy.parse_response_content(llm_response)
 
         self.log_cycle_handler.log_cycle(
             self.ai_profile.ai_name,
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
     from autogpt.config import Config
     from autogpt.core.prompting.base import PromptStrategy
     from autogpt.core.resource.model_providers.schema import (
+        AssistantChatMessageDict,
         ChatModelInfo,
         ChatModelProvider,
         ChatModelResponse,
@@ -247,7 +248,7 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
         prompt = self.on_before_think(prompt, scratchpad=self._prompt_scratchpad)
 
         logger.debug(f"Executing prompt:\n{dump_prompt(prompt)}")
-        raw_response = await self.llm_provider.create_chat_completion(
+        response = await self.llm_provider.create_chat_completion(
             prompt.messages,
             functions=get_openai_command_specs(
                 self.command_registry.list_available_commands(self)
@@ -256,11 +257,16 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
             if self.config.use_functions_api
             else [],
             model_name=self.llm.name,
+            completion_parser=lambda r: self.parse_and_process_response(
+                r,
+                prompt,
+                scratchpad=self._prompt_scratchpad,
+            ),
         )
         self.config.cycle_count += 1
 
         return self.on_response(
-            llm_response=raw_response,
+            llm_response=response,
             prompt=prompt,
             scratchpad=self._prompt_scratchpad,
         )
@@ -397,18 +403,14 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
             The parsed command name and command args, if any, and the agent thoughts.
         """
 
-        return self.parse_and_process_response(
-            llm_response,
-            prompt,
-            scratchpad=scratchpad,
-        )
+        return llm_response.parsed_result
 
         # TODO: update memory/context
 
     @abstractmethod
     def parse_and_process_response(
         self,
-        llm_response: ChatModelResponse,
+        llm_response: AssistantChatMessageDict,
         prompt: ChatPrompt,
         scratchpad: PromptScratchpad,
     ) -> ThoughtProcessOutput:
@@ -165,6 +165,7 @@ OPEN_AI_MODELS = {
 
 
 class OpenAIConfiguration(ModelProviderConfiguration):
+    fix_failed_parse_tries: int = UserConfigurable(3)
     pass
 
 
@@ -363,24 +364,45 @@ class OpenAIProvider(
         model_prompt += completion_kwargs["messages"]
         del completion_kwargs["messages"]
 
-        response = await self._create_chat_completion(
-            messages=model_prompt,
-            **completion_kwargs,
-        )
-        response_args = {
-            "model_info": OPEN_AI_CHAT_MODELS[model_name],
-            "prompt_tokens_used": response.usage.prompt_tokens,
-            "completion_tokens_used": response.usage.completion_tokens,
-        }
-
-        response_message = response.choices[0].message.to_dict_recursive()
-        if tool_calls_compat_mode:
-            response_message["tool_calls"] = _tool_calls_compat_extract_calls(
-                response_message["content"]
-            )
+        attempts = 0
+        while True:
+            response = await self._create_chat_completion(
+                messages=model_prompt,
+                **completion_kwargs,
+            )
+            response_args = {
+                "model_info": OPEN_AI_CHAT_MODELS[model_name],
+                "prompt_tokens_used": response.usage.prompt_tokens,
+                "completion_tokens_used": response.usage.completion_tokens,
+            }
+
+            response_message = response.choices[0].message.to_dict_recursive()
+            if tool_calls_compat_mode:
+                response_message["tool_calls"] = _tool_calls_compat_extract_calls(
+                    response_message["content"]
+                )
+
+            # If parsing the response fails, append the error to the prompt, and let the
+            # LLM fix its mistake(s).
+            try:
+                attempts += 1
+                parsed_response = completion_parser(response_message)
+                break
+            except Exception as e:
+                self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
+                self._logger.debug(
+                    f"Parsing failed on response: '''{response_message}'''"
+                )
+                if attempts < self._configuration.fix_failed_parse_tries:
+                    model_prompt.append(
+                        ChatMessage.system(f"ERROR PARSING YOUR RESPONSE:\n\n{e}")
+                    )
+                else:
+                    raise
+
         response = ChatModelResponse(
             response=response_message,
-            parsed_result=completion_parser(response_message),
+            parsed_result=parsed_response,
             **response_args,
         )
         self._budget.update_usage_and_cost(response)
@@ -103,18 +103,8 @@ class JSONSchema(BaseModel):
         validator = Draft7Validator(self.to_dict())
 
         if errors := sorted(validator.iter_errors(object), key=lambda e: e.path):
-            for error in errors:
-                logger.debug(f"JSON Validation Error: {error}")
-
-            logger.error(json.dumps(object, indent=4))
-            logger.error("The following issues were found:")
-
-            for error in errors:
-                logger.error(f"Error: {error.message}")
             return False, errors
 
-        logger.debug("The JSON object is valid.")
-
         return True, None
 
     def to_typescript_object_interface(self, interface_name: str = "") -> str:
@@ -26,10 +26,10 @@ def extract_dict_from_response(response_content: str) -> dict[str, Any]:
 
     # Response content comes from OpenAI as a Python `str(content_dict)`.
    # `literal_eval` does the reverse of `str(dict)`.
-    try:
-        return ast.literal_eval(response_content)
-    except BaseException as e:
-        logger.info(f"Error parsing JSON response with literal_eval {e}")
-        logger.debug(f"Invalid JSON received in response:\n{response_content}")
-        # TODO: How to raise an error here without causing the program to exit?
-        return {}
+    result = ast.literal_eval(response_content)
+    if not isinstance(result, dict):
+        raise ValueError(
+            f"Response '''{response_content}''' evaluated to "
+            f"non-dict value {repr(result)}"
+        )
+    return result