Improve the accuracy of the extract_dict_from_response method's JSON extraction (#5458)

pull/6192/merge
HawkClaws 2023-11-16 23:19:09 +09:00 committed by GitHub
parent 787c71a9de
commit 6664eec8ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 5 deletions

View File

@ -1,4 +1,5 @@
"""Utilities for the json_fixes package."""
import re
import ast
import logging
from typing import Any
@ -8,12 +9,20 @@ logger = logging.getLogger(__name__)
def extract_dict_from_response(response_content: str) -> dict[str, Any]:
# Sometimes the response includes the JSON in a code block with ```
if response_content.startswith("```") and response_content.endswith("```"):
# Discard the first and last ```, then re-join in case the response naturally included ```
response_content = "```".join(response_content.split("```")[1:-1]).strip()
pattern = r'```([\s\S]*?)```'
match = re.search(pattern, response_content)
if (ob_pos := response_content.index("{")) > 0:
response_content = response_content[ob_pos:]
if match:
response_content = match.group(1).strip()
# Remove language names in code blocks
response_content = response_content.lstrip("json")
else:
# The string may contain JSON.
json_pattern = r'{.*}'
match = re.search(json_pattern, response_content)
if match:
response_content = match.group()
# response content comes from OpenAI as a Python `str(content_dict)`, literal_eval reverses this
try:

View File

@ -193,3 +193,15 @@ def test_extract_json_from_response_wrapped_in_code_block(valid_json_response: d
assert (
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
)
def test_extract_json_from_response_wrapped_in_code_block_with_language(valid_json_response: dict):
emulated_response_from_openai = "```json" + str(valid_json_response) + "```"
assert (
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
)
def test_extract_json_from_response_json_contained_in_string(valid_json_response: dict):
emulated_response_from_openai = "sentence1" + str(valid_json_response) + "sentence2"
assert (
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
)