feat(blocks): Add summariser block for recursive text summarization functionality (#7431)

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Branch: pull/7437/head
Author: Toran Bruce Richards, 2024-07-16 11:51:37 +01:00 (committed by GitHub)
Parent: e874318832
Commit: 920f931a21
3 changed files with 128 additions and 8 deletions
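
In outline, the new TextSummarizerBlock chunks long input, summarizes each chunk, then summarizes the concatenated chunk summaries, recursing while the combined text is still over budget. A minimal plain-Python sketch of that control flow, for orientation only: the Block API is omitted and summarize() is a stand-in for the LLM call, so every name here is illustrative rather than taken from the diff.

    def summarize(text: str) -> str:
        # Stand-in for the LLM call; just truncates so the sketch runs.
        return text[:60]

    def recursive_summarize(text: str, max_words: int = 4000) -> str:
        words = text.split()
        # Map: summarize fixed-size chunks of the input.
        chunks = [
            " ".join(words[i:i + max_words])
            for i in range(0, len(words), max_words)
        ]
        combined = " ".join(summarize(c) for c in chunks)
        # Reduce: one final pass, or recurse if still over budget.
        if len(combined.split()) <= max_words:
            return summarize(combined)
        return recursive_summarize(combined, max_words)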

View File

@@ -2,8 +2,8 @@ import logging
 from enum import Enum
 from typing import NamedTuple
 
-import openai
 import anthropic
+import openai
 from groq import Groq
 
 from autogpt_server.data.block import Block, BlockOutput, BlockSchema, BlockFieldSecret
@@ -92,15 +92,21 @@ class LlmCallBlock(Block):
         )
 
     @staticmethod
-    def llm_call(api_key: str, model: LlmModel, prompt: list[dict], json_format: bool) -> str:
+    def llm_call(
+            api_key: str,
+            model: LlmModel,
+            prompt: list[dict],
+            json_format: bool
+    ) -> str:
         provider = model.metadata.provider
 
         if provider == "openai":
             openai.api_key = api_key
+            response_format = {"type": "json_object"} if json_format else None
             response = openai.chat.completions.create(
                 model=model.value,
                 messages=prompt,  # type: ignore
-                response_format={"type": "json_object"} if json_format else None,  # type: ignore
+                response_format=response_format,  # type: ignore
             )
             return response.choices[0].message.content or ""
         elif provider == "anthropic":
@@ -116,10 +122,11 @@ class LlmCallBlock(Block):
             return response.content[0].text if response.content else ""
         elif provider == "groq":
             client = Groq(api_key=api_key)
+            response_format = {"type": "json_object"} if json_format else None
             response = client.chat.completions.create(
                 model=model.value,
                 messages=prompt,  # type: ignore
-                response_format={"type": "json_object"} if json_format else None,
+                response_format=response_format,  # type: ignore
             )
             return response.choices[0].message.content or ""
         else:
@@ -136,7 +143,10 @@ class LlmCallBlock(Block):
             prompt.append({"role": "system", "content": input_data.sys_prompt})
 
         if input_data.expected_format:
-            expected_format = [f'"{k}": "{v}"' for k, v in input_data.expected_format.items()]
+            expected_format = [
+                f'"{k}": "{v}"' for k, v in
+                input_data.expected_format.items()
+            ]
             format_prompt = ",\n  ".join(expected_format)
             sys_prompt = trim_prompt(f"""
                 |Reply in json format:
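
For context, the reflowed comprehension above builds the JSON-format instruction that gets appended to the system prompt; the hunk cuts off inside the trim_prompt f-string. A rough standalone sketch of what that instruction ends up looking like, assuming trim_prompt strips the leading | margin (an inference from its usage here, not shown in this diff):

    # Illustrative sketch only; the expected_format value and the
    # margin-stripping behaviour of trim_prompt are assumptions.
    expected_format = {"summary": "The summary of the given text."}

    pairs = [f'"{k}": "{v}"' for k, v in expected_format.items()]
    format_prompt = ",\n  ".join(pairs)

    print("Reply in json format:\n{\n  " + format_prompt + "\n}")
    # Reply in json format:
    # {
    #   "summary": "The summary of the given text."
    # }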
@@ -199,3 +209,113 @@ class LlmCallBlock(Block):
             retry_prompt = f"Error calling LLM: {e}"
             yield "error", retry_prompt
+
+
+class TextSummarizerBlock(Block):
+    class Input(BlockSchema):
+        text: str
+        model: LlmModel = LlmModel.GPT4_TURBO
+        api_key: BlockFieldSecret = BlockFieldSecret(value="")
+        # TODO: Make this dynamic
+        max_tokens: int = 4000  # Adjust based on the model's context window
+        chunk_overlap: int = 100  # Overlap between chunks to maintain context
+
+    class Output(BlockSchema):
+        summary: str
+        error: str
+
+    def __init__(self):
+        super().__init__(
+            id="c3d4e5f6-7g8h-9i0j-1k2l-m3n4o5p6q7r8",
+            input_schema=TextSummarizerBlock.Input,
+            output_schema=TextSummarizerBlock.Output,
+            test_input={"text": "Lorem ipsum..." * 100},
+            test_output=("summary", "Final summary of a long text"),
+            test_mock={
+                "llm_call": lambda input_data: (
+                    {"final_summary": "Final summary of a long text"}
+                    if "final_summary" in input_data.expected_format
+                    else {"summary": "Summary of a chunk of text"}
+                )
+            },
+        )
+
+    def run(self, input_data: Input) -> BlockOutput:
+        try:
+            for output in self._run(input_data):
+                yield output
+        except Exception as e:
+            yield "error", str(e)
+
+    def _run(self, input_data: Input) -> BlockOutput:
+        chunks = self._split_text(
+            input_data.text,
+            input_data.max_tokens,
+            input_data.chunk_overlap
+        )
+
+        summaries = []
+        for chunk in chunks:
+            chunk_summary = self._summarize_chunk(chunk, input_data)
+            summaries.append(chunk_summary)
+
+        final_summary = self._combine_summaries(summaries, input_data)
+        yield "summary", final_summary
+
+    @staticmethod
+    def _split_text(text: str, max_tokens: int, overlap: int) -> list[str]:
+        words = text.split()
+        chunks = []
+        chunk_size = max_tokens - overlap
+        for i in range(0, len(words), chunk_size):
+            chunk = " ".join(words[i:i + max_tokens])
+            chunks.append(chunk)
+        return chunks
+
+    @staticmethod
+    def llm_call(input_data: LlmCallBlock.Input) -> dict[str, str]:
+        llm_block = LlmCallBlock()
+        for output_name, output_data in llm_block.run(input_data):
+            if output_name == "response":
+                return output_data
+        raise ValueError("Failed to get a response from the LLM.")
+
+    def _summarize_chunk(self, chunk: str, input_data: Input) -> str:
+        prompt = f"Summarize the following text concisely:\n\n{chunk}"
+        llm_response = self.llm_call(LlmCallBlock.Input(
+            prompt=prompt,
+            api_key=input_data.api_key,
+            model=input_data.model,
+            expected_format={"summary": "The summary of the given text."},
+        ))
+        return llm_response["summary"]
+
+    def _combine_summaries(self, summaries: list[str], input_data: Input) -> str:
+        combined_text = " ".join(summaries)
+
+        if len(combined_text.split()) <= input_data.max_tokens:
+            prompt = (
+                "Provide a final, concise summary of the following summaries:\n\n"
+                + combined_text
+            )
+            llm_response = self.llm_call(LlmCallBlock.Input(
+                prompt=prompt,
+                api_key=input_data.api_key,
+                model=input_data.model,
+                expected_format={
+                    "final_summary": "The final summary of all provided summaries."
+                },
+            ))
+            return llm_response["final_summary"]
+        else:
+            # If the combined summaries are still too long, summarize recursively.
+            return self._run(TextSummarizerBlock.Input(
+                text=combined_text,
+                api_key=input_data.api_key,
+                model=input_data.model,
+                max_tokens=input_data.max_tokens,
+                chunk_overlap=input_data.chunk_overlap
+            )).send(None)[1]  # Get the first yielded value
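
Two notes on the code above. First, _split_text strides through the word list in steps of max_tokens - chunk_overlap while each slice takes up to max_tokens words, so consecutive chunks share chunk_overlap words; it also counts words rather than model tokens, which the TODO already flags. A standalone demo of the same logic on small numbers:

    def split_text(text: str, max_tokens: int, overlap: int) -> list[str]:
        # Same logic as TextSummarizerBlock._split_text: stride by
        # (max_tokens - overlap) words, slice max_tokens words, so each
        # consecutive pair of chunks shares `overlap` words.
        words = text.split()
        chunk_size = max_tokens - overlap
        return [
            " ".join(words[i:i + max_tokens])
            for i in range(0, len(words), chunk_size)
        ]

    text = " ".join(f"w{n}" for n in range(10))  # "w0 w1 ... w9"
    for chunk in split_text(text, max_tokens=4, overlap=2):
        print(chunk)
    # w0 w1 w2 w3
    # w2 w3 w4 w5
    # w4 w5 w6 w7
    # w6 w7 w8 w9
    # w8 w9

Second, in the recursive branch _run returns a fresh generator, and calling .send(None) on a just-started generator behaves like next(), so .send(None)[1] pulls the summary out of the first yielded ("summary", value) pair.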

View File

@@ -1,6 +1,6 @@
 from pathlib import Path
 
-from autogpt_server.blocks.ai import LlmCallBlock
+from autogpt_server.blocks.llm import LlmCallBlock
 from autogpt_server.blocks.basic import ValueBlock
 from autogpt_server.blocks.block import BlockInstallationBlock
 from autogpt_server.blocks.http import HttpRequestBlock

View File

@@ -1,5 +1,5 @@
 from autogpt_server.data.graph import Graph, Link, Node, create_graph
-from autogpt_server.blocks.ai import LlmCallBlock
+from autogpt_server.blocks.llm import LlmCallBlock
 from autogpt_server.blocks.reddit import (
     RedditGetPostsBlock,
     RedditPostCommentBlock,