Improve the accuracy of the extract_dict_from_response method's JSON extraction (#5458)
parent
787c71a9de
commit
6664eec8ce
|
@ -1,4 +1,5 @@
|
|||
"""Utilities for the json_fixes package."""
|
||||
import re
|
||||
import ast
|
||||
import logging
|
||||
from typing import Any
|
||||
|
@ -8,12 +9,20 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
def extract_dict_from_response(response_content: str) -> dict[str, Any]:
|
||||
# Sometimes the response includes the JSON in a code block with ```
|
||||
if response_content.startswith("```") and response_content.endswith("```"):
|
||||
# Discard the first and last ```, then re-join in case the response naturally included ```
|
||||
response_content = "```".join(response_content.split("```")[1:-1]).strip()
|
||||
pattern = r'```([\s\S]*?)```'
|
||||
match = re.search(pattern, response_content)
|
||||
|
||||
if (ob_pos := response_content.index("{")) > 0:
|
||||
response_content = response_content[ob_pos:]
|
||||
if match:
|
||||
response_content = match.group(1).strip()
|
||||
# Remove language names in code blocks
|
||||
response_content = response_content.lstrip("json")
|
||||
else:
|
||||
# The string may contain JSON.
|
||||
json_pattern = r'{.*}'
|
||||
match = re.search(json_pattern, response_content)
|
||||
|
||||
if match:
|
||||
response_content = match.group()
|
||||
|
||||
# response content comes from OpenAI as a Python `str(content_dict)`, literal_eval reverses this
|
||||
try:
|
||||
|
|
|
@ -193,3 +193,15 @@ def test_extract_json_from_response_wrapped_in_code_block(valid_json_response: d
|
|||
assert (
|
||||
extract_dict_from_response(emulated_response_from_openai) == valid_json_response
|
||||
)
|
||||
|
||||
def test_extract_json_from_response_wrapped_in_code_block_with_language(
    valid_json_response: dict,
):
    """Extraction must see through a fence that carries a language tag.

    The emulated response is ``str(valid_json_response)`` placed between an
    opening "```json" marker and a closing "```" marker; the extracted
    value must compare equal to the original dict.
    """
    fenced_payload = "".join(("```json", str(valid_json_response), "```"))
    extracted = extract_dict_from_response(fenced_payload)
    assert extracted == valid_json_response
|
||||
|
||||
def test_extract_json_from_response_json_contained_in_string(
    valid_json_response: dict,
):
    """Extraction must find a dict embedded in surrounding text.

    The emulated response is ``str(valid_json_response)`` with the literal
    text "sentence1" directly before it and "sentence2" directly after it;
    the extracted value must compare equal to the original dict.
    """
    surrounded_payload = "".join(("sentence1", str(valid_json_response), "sentence2"))
    extracted = extract_dict_from_response(surrounded_payload)
    assert extracted == valid_json_response
|
Loading…
Reference in New Issue