Merge pull request #1610 from adityaoke/adityaoke/fix_json_str

[1607] Sourcery is detecting linting issues in autogpt/json_fixes/aut…
pull/1455/head
Richard Beales 2023-04-15 18:33:34 +01:00 committed by GitHub
commit 51fc59b45f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 8 deletions

View File

@ -1,16 +1,17 @@
"""This module contains the function to fix JSON strings using GPT-3."""
import json
from autogpt.llm_utils import call_ai_function
from autogpt.logs import logger
from autogpt.config import Config
cfg = Config()
def fix_json(json_str: str, schema: str) -> str:
def fix_json(json_string: str, schema: str) -> str:
"""Fix the given JSON string to make it parseable and fully compliant with the provided schema."""
# Try to fix the JSON using GPT:
function_string = "def fix_json(json_str: str, schema:str=None) -> str:"
args = [f"'''{json_str}'''", f"'''{schema}'''"]
function_string = "def fix_json(json_string: str, schema:str=None) -> str:"
args = [f"'''{json_string}'''", f"'''{schema}'''"]
description_string = (
"Fixes the provided JSON string to make it parseable"
" and fully compliant with the provided schema.\n If an object or"
@ -20,13 +21,13 @@ def fix_json(json_str: str, schema: str) -> str:
)
# If it doesn't already start with a "`", add one:
if not json_str.startswith("`"):
json_str = "```json\n" + json_str + "\n```"
if not json_string.startswith("`"):
json_string = "```json\n" + json_string + "\n```"
result_string = call_ai_function(
function_string, args, description_string, model=cfg.fast_llm_model
)
logger.debug("------------ JSON FIX ATTEMPT ---------------")
logger.debug(f"Original JSON: {json_str}")
logger.debug(f"Original JSON: {json_string}")
logger.debug("-----------")
logger.debug(f"Fixed JSON: {result_string}")
logger.debug("----------- END OF FIX ATTEMPT ----------------")
@ -34,9 +35,9 @@ def fix_json(json_str: str, schema: str) -> str:
try:
json.loads(result_string) # just check the validity
return result_string
except: # noqa: E722
except json.JSONDecodeError: # noqa: E722
# Get the call stack:
# import traceback
# call_stack = traceback.format_exc()
# print(f"Failed to fix JSON: '{json_str}' "+call_stack)
# print(f"Failed to fix JSON: '{json_string}' "+call_stack)
return "failed"