feat(blocks): Add summarizer block for recursive text summarization (#7431)

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
pull/7437/head
Toran Bruce Richards 2024-07-16 11:51:37 +01:00 committed by GitHub
parent e874318832
commit 920f931a21
3 changed files with 128 additions and 8 deletions

View File

@@ -2,8 +2,8 @@ import logging
from enum import Enum
from typing import NamedTuple
import openai
import anthropic
import openai
from groq import Groq
from autogpt_server.data.block import Block, BlockOutput, BlockSchema, BlockFieldSecret
@@ -92,15 +92,21 @@ class LlmCallBlock(Block):
)
@staticmethod
def llm_call(api_key: str, model: LlmModel, prompt: list[dict], json_format: bool) -> str:
def llm_call(
api_key: str,
model: LlmModel,
prompt: list[dict],
json_format: bool
) -> str:
provider = model.metadata.provider
if provider == "openai":
openai.api_key = api_key
response_format = {"type": "json_object"} if json_format else None
response = openai.chat.completions.create(
model=model.value,
messages=prompt, # type: ignore
response_format={"type": "json_object"} if json_format else None, # type: ignore
response_format=response_format, # type: ignore
)
return response.choices[0].message.content or ""
elif provider == "anthropic":
@@ -116,10 +122,11 @@ class LlmCallBlock(Block):
return response.content[0].text if response.content else ""
elif provider == "groq":
client = Groq(api_key=api_key)
response_format = {"type": "json_object"} if json_format else None
response = client.chat.completions.create(
model=model.value,
messages=prompt, # type: ignore
response_format={"type": "json_object"} if json_format else None,
response_format=response_format, # type: ignore
)
return response.choices[0].message.content or ""
else:
@@ -136,7 +143,10 @@ class LlmCallBlock(Block):
prompt.append({"role": "system", "content": input_data.sys_prompt})
if input_data.expected_format:
expected_format = [f'"{k}": "{v}"' for k, v in input_data.expected_format.items()]
expected_format = [
f'"{k}": "{v}"' for k, v in
input_data.expected_format.items()
]
format_prompt = ",\n ".join(expected_format)
sys_prompt = trim_prompt(f"""
|Reply in json format:
@@ -199,3 +209,113 @@ class LlmCallBlock(Block):
retry_prompt = f"Error calling LLM: {e}"
yield "error", retry_prompt
class TextSummarizerBlock(Block):
class Input(BlockSchema):
text: str
model: LlmModel = LlmModel.GPT4_TURBO
api_key: BlockFieldSecret = BlockFieldSecret(value="")
# TODO: Make this dynamic
max_tokens: int = 4000 # Adjust based on the model's context window
chunk_overlap: int = 100 # Overlap between chunks to maintain context
class Output(BlockSchema):
summary: str
error: str
def __init__(self):
super().__init__(
id="c3d4e5f6-7g8h-9i0j-1k2l-m3n4o5p6q7r8",
input_schema=TextSummarizerBlock.Input,
output_schema=TextSummarizerBlock.Output,
test_input={"text": "Lorem ipsum..." * 100},
test_output=("summary", "Final summary of a long text"),
test_mock={
"llm_call": lambda input_data:
{"final_summary": "Final summary of a long text"}
if "final_summary" in input_data.expected_format
else {"summary": "Summary of a chunk of text"}
}
)
def run(self, input_data: Input) -> BlockOutput:
try:
for output in self._run(input_data):
yield output
except Exception as e:
yield "error", str(e)
def _run(self, input_data: Input) -> BlockOutput:
chunks = self._split_text(
input_data.text,
input_data.max_tokens,
input_data.chunk_overlap
)
summaries = []
for chunk in chunks:
chunk_summary = self._summarize_chunk(chunk, input_data)
summaries.append(chunk_summary)
final_summary = self._combine_summaries(summaries, input_data)
yield "summary", final_summary
@staticmethod
def _split_text(text: str, max_tokens: int, overlap: int) -> list[str]:
words = text.split()
chunks = []
chunk_size = max_tokens - overlap
for i in range(0, len(words), chunk_size):
chunk = " ".join(words[i:i + max_tokens])
chunks.append(chunk)
return chunks
@staticmethod
def llm_call(input_data: LlmCallBlock.Input) -> dict[str, str]:
llm_block = LlmCallBlock()
for output_name, output_data in llm_block.run(input_data):
if output_name == "response":
return output_data
raise ValueError("Failed to get a response from the LLM.")
def _summarize_chunk(self, chunk: str, input_data: Input) -> str:
prompt = f"Summarize the following text concisely:\n\n{chunk}"
llm_response = self.llm_call(LlmCallBlock.Input(
prompt=prompt,
api_key=input_data.api_key,
model=input_data.model,
expected_format={"summary": "The summary of the given text."}
))
return llm_response["summary"]
def _combine_summaries(self, summaries: list[str], input_data: Input) -> str:
combined_text = " ".join(summaries)
if len(combined_text.split()) <= input_data.max_tokens:
prompt = ("Provide a final, concise summary of the following summaries:\n\n"
+ combined_text)
llm_response = self.llm_call(LlmCallBlock.Input(
prompt=prompt,
api_key=input_data.api_key,
model=input_data.model,
expected_format={
"final_summary": "The final summary of all provided summaries."
}
))
return llm_response["final_summary"]
else:
# If combined summaries are still too long, recursively summarize
return self._run(TextSummarizerBlock.Input(
text=combined_text,
api_key=input_data.api_key,
model=input_data.model,
max_tokens=input_data.max_tokens,
chunk_overlap=input_data.chunk_overlap
)).send(None)[1]  # advance the generator once; [1] takes the value from the ("summary", value) tuple
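The `_split_text` helper above measures chunks in whitespace-separated words as a rough proxy for tokens: each step advances by `max_tokens - chunk_overlap` words while each slice is up to `max_tokens` words long, so consecutive chunks share `chunk_overlap` words of context. A minimal standalone sketch of that arithmetic (the helper name and sample numbers are illustrative, not part of the commit):

```python
# Standalone sketch of the overlapping-chunk arithmetic in _split_text above.
# Chunks are counted in whitespace-separated words, used as a rough token proxy.
def split_words(text: str, max_tokens: int, overlap: int) -> list[str]:
    words = text.split()
    step = max_tokens - overlap  # stride between chunk starts
    return [
        " ".join(words[i : i + max_tokens])  # each chunk holds up to max_tokens words
        for i in range(0, len(words), step)
    ]

chunks = split_words("word " * 10_000, max_tokens=4000, overlap=100)
print(len(chunks))              # 3 chunks, starting at words 0, 3900, and 7800
print(len(chunks[0].split()))   # 4000 words in the first chunk
print(len(chunks[-1].split()))  # 2200 words remain for the final chunk
```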

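A hedged usage sketch of the new block when driven directly rather than from a graph. The module path is an assumption based on the import changes below (`autogpt_server.blocks.llm`), and the file name and API key are placeholders; the field names and the `(output_name, value)` generator contract come from the diff above:

```python
from autogpt_server.blocks.llm import LlmModel, TextSummarizerBlock  # assumed module path
from autogpt_server.data.block import BlockFieldSecret

block = TextSummarizerBlock()
input_data = TextSummarizerBlock.Input(
    text=open("long_report.txt").read(),       # placeholder input document
    model=LlmModel.GPT4_TURBO,
    api_key=BlockFieldSecret(value="sk-..."),  # placeholder secret
    max_tokens=4000,
    chunk_overlap=100,
)

# run() yields (output_name, value) pairs; failures surface on the "error" output.
for name, value in block.run(input_data):
    if name == "summary":
        print(value)
    elif name == "error":
        raise RuntimeError(f"summarization failed: {value}")
```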
View File

@@ -1,6 +1,6 @@
from pathlib import Path
from autogpt_server.blocks.ai import LlmCallBlock
from autogpt_server.blocks.llm import LlmCallBlock
from autogpt_server.blocks.basic import ValueBlock
from autogpt_server.blocks.block import BlockInstallationBlock
from autogpt_server.blocks.http import HttpRequestBlock

View File

@@ -1,5 +1,5 @@
from autogpt_server.data.graph import Graph, Link, Node, create_graph
from autogpt_server.blocks.ai import LlmCallBlock
from autogpt_server.blocks.llm import LlmCallBlock
from autogpt_server.blocks.reddit import (
RedditGetPostsBlock,
RedditPostCommentBlock,