feat(blocks): Add summariser block for recursive text summarization functionality (#7431)
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
parent e874318832
commit 920f931a21
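
For reference, a minimal sketch of how the new TextSummarizerBlock introduced in this commit might be driven directly (illustration only: in the server, blocks are executed by the graph executor, and the API key below is a placeholder, not a real credential):

from autogpt_server.blocks.llm import TextSummarizerBlock
from autogpt_server.data.block import BlockFieldSecret

# Hypothetical direct invocation; Input and BlockFieldSecret are constructed
# following the schema shown in the diff below.
block = TextSummarizerBlock()
input_data = TextSummarizerBlock.Input(
    text="Lorem ipsum dolor sit amet. " * 2000,       # long text that needs chunking
    api_key=BlockFieldSecret(value="sk-placeholder"),  # placeholder secret
)
for output_name, output_data in block.run(input_data):
    if output_name == "summary":
        print(output_data)  # final, recursively combined summary
    elif output_name == "error":
        print("summarization failed:", output_data)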
@@ -2,8 +2,8 @@ import logging
 from enum import Enum
 from typing import NamedTuple
 
-import openai
 import anthropic
+import openai
 from groq import Groq
 
 from autogpt_server.data.block import Block, BlockOutput, BlockSchema, BlockFieldSecret
@@ -92,15 +92,21 @@ class LlmCallBlock(Block):
         )
 
     @staticmethod
-    def llm_call(api_key: str, model: LlmModel, prompt: list[dict], json_format: bool) -> str:
+    def llm_call(
+        api_key: str,
+        model: LlmModel,
+        prompt: list[dict],
+        json_format: bool
+    ) -> str:
         provider = model.metadata.provider
 
         if provider == "openai":
            openai.api_key = api_key
+            response_format = {"type": "json_object"} if json_format else None
            response = openai.chat.completions.create(
                model=model.value,
                messages=prompt,  # type: ignore
-                response_format={"type": "json_object"} if json_format else None,  # type: ignore
+                response_format=response_format,  # type: ignore
            )
            return response.choices[0].message.content or ""
         elif provider == "anthropic":
@@ -116,10 +122,11 @@ class LlmCallBlock(Block):
            return response.content[0].text if response.content else ""
         elif provider == "groq":
            client = Groq(api_key=api_key)
+            response_format = {"type": "json_object"} if json_format else None
            response = client.chat.completions.create(
                model=model.value,
                messages=prompt,  # type: ignore
-                response_format={"type": "json_object"} if json_format else None,
+                response_format=response_format,  # type: ignore
            )
            return response.choices[0].message.content or ""
         else:
@@ -136,7 +143,10 @@ class LlmCallBlock(Block):
            prompt.append({"role": "system", "content": input_data.sys_prompt})
 
         if input_data.expected_format:
-            expected_format = [f'"{k}": "{v}"' for k, v in input_data.expected_format.items()]
+            expected_format = [
+                f'"{k}": "{v}"' for k, v in
+                input_data.expected_format.items()
+            ]
            format_prompt = ",\n  ".join(expected_format)
            sys_prompt = trim_prompt(f"""
                |Reply in json format:
@@ -199,3 +209,113 @@ class LlmCallBlock(Block):
                retry_prompt = f"Error calling LLM: {e}"
 
         yield "error", retry_prompt
+
+
+class TextSummarizerBlock(Block):
+    class Input(BlockSchema):
+        text: str
+        model: LlmModel = LlmModel.GPT4_TURBO
+        api_key: BlockFieldSecret = BlockFieldSecret(value="")
+        # TODO: Make this dynamic
+        max_tokens: int = 4000  # Adjust based on the model's context window
+        chunk_overlap: int = 100  # Overlap between chunks to maintain context
+
+    class Output(BlockSchema):
+        summary: str
+        error: str
+
+    def __init__(self):
+        super().__init__(
+            id="c3d4e5f6-7g8h-9i0j-1k2l-m3n4o5p6q7r8",
+            input_schema=TextSummarizerBlock.Input,
+            output_schema=TextSummarizerBlock.Output,
+            test_input={"text": "Lorem ipsum..." * 100},
+            test_output=("summary", "Final summary of a long text"),
+            test_mock={
+                "llm_call": lambda input_data:
+                {"final_summary": "Final summary of a long text"}
+                if "final_summary" in input_data.expected_format
+                else {"summary": "Summary of a chunk of text"}
+            }
+        )
+
+    def run(self, input_data: Input) -> BlockOutput:
+        try:
+            for output in self._run(input_data):
+                yield output
+        except Exception as e:
+            yield "error", str(e)
+
+    def _run(self, input_data: Input) -> BlockOutput:
+        chunks = self._split_text(
+            input_data.text,
+            input_data.max_tokens,
+            input_data.chunk_overlap
+        )
+        summaries = []
+
+        for chunk in chunks:
+            chunk_summary = self._summarize_chunk(chunk, input_data)
+            summaries.append(chunk_summary)
+
+        final_summary = self._combine_summaries(summaries, input_data)
+        yield "summary", final_summary
+
+    @staticmethod
+    def _split_text(text: str, max_tokens: int, overlap: int) -> list[str]:
+        words = text.split()
+        chunks = []
+        chunk_size = max_tokens - overlap
+
+        for i in range(0, len(words), chunk_size):
+            chunk = " ".join(words[i:i + max_tokens])
+            chunks.append(chunk)
+
+        return chunks
+
+    @staticmethod
+    def llm_call(input_data: LlmCallBlock.Input) -> dict[str, str]:
+        llm_block = LlmCallBlock()
+        for output_name, output_data in llm_block.run(input_data):
+            if output_name == "response":
+                return output_data
+        raise ValueError("Failed to get a response from the LLM.")
+
+    def _summarize_chunk(self, chunk: str, input_data: Input) -> str:
+        prompt = f"Summarize the following text concisely:\n\n{chunk}"
+
+        llm_response = self.llm_call(LlmCallBlock.Input(
+            prompt=prompt,
+            api_key=input_data.api_key,
+            model=input_data.model,
+            expected_format={"summary": "The summary of the given text."}
+        ))
+
+        return llm_response["summary"]
+
+    def _combine_summaries(self, summaries: list[str], input_data: Input) -> str:
+        combined_text = " ".join(summaries)
+
+        if len(combined_text.split()) <= input_data.max_tokens:
+            prompt = ("Provide a final, concise summary of the following summaries:\n\n"
+                      + combined_text)
+
+            llm_response = self.llm_call(LlmCallBlock.Input(
+                prompt=prompt,
+                api_key=input_data.api_key,
+                model=input_data.model,
+                expected_format={
+                    "final_summary": "The final summary of all provided summaries."
+                }
+            ))
+
+            return llm_response["final_summary"]
+        else:
+            # If combined summaries are still too long, recursively summarize
+            return self._run(TextSummarizerBlock.Input(
+                text=combined_text,
+                api_key=input_data.api_key,
+                model=input_data.model,
+                max_tokens=input_data.max_tokens,
+                chunk_overlap=input_data.chunk_overlap
+            )).send(None)[1]  # Get the first yielded value
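
A quick standalone check of the chunking arithmetic in _split_text above (a sketch assuming plain whitespace tokenization, not the committed code): chunk starts advance by max_tokens - chunk_overlap words, each chunk spans up to max_tokens words, so adjacent chunks share about chunk_overlap words of context.

def split_words(text: str, max_tokens: int = 4000, overlap: int = 100) -> list[str]:
    # Same word-based slicing as TextSummarizerBlock._split_text
    words = text.split()
    step = max_tokens - overlap  # 3900 with the defaults above
    return [" ".join(words[i:i + max_tokens]) for i in range(0, len(words), step)]

sample = " ".join(f"w{i}" for i in range(8000))
chunks = split_words(sample)
print(len(chunks))             # 3 chunks, starting at words 0, 3900, 7800
print(len(chunks[0].split()))  # 4000
print(len(chunks[1].split()))  # 4000 (words 3900-7899, overlapping the first by 100)
print(len(chunks[2].split()))  # 200  (words 7800-7999)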
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-from autogpt_server.blocks.ai import LlmCallBlock
+from autogpt_server.blocks.llm import LlmCallBlock
 from autogpt_server.blocks.basic import ValueBlock
 from autogpt_server.blocks.block import BlockInstallationBlock
 from autogpt_server.blocks.http import HttpRequestBlock
@@ -1,5 +1,5 @@
 from autogpt_server.data.graph import Graph, Link, Node, create_graph
-from autogpt_server.blocks.ai import LlmCallBlock
+from autogpt_server.blocks.llm import LlmCallBlock
 from autogpt_server.blocks.reddit import (
     RedditGetPostsBlock,
     RedditPostCommentBlock,