feat(blocks): Add summariser block for recursive text summarization functionality (#7431)
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
pull/7437/head
parent e874318832
commit 920f931a21
@@ -2,8 +2,8 @@ import logging
 from enum import Enum
 from typing import NamedTuple
 
-import openai
 import anthropic
+import openai
 from groq import Groq
 
 from autogpt_server.data.block import Block, BlockOutput, BlockSchema, BlockFieldSecret
@@ -92,15 +92,21 @@ class LlmCallBlock(Block):
         )
 
     @staticmethod
-    def llm_call(api_key: str, model: LlmModel, prompt: list[dict], json_format: bool) -> str:
+    def llm_call(
+            api_key: str,
+            model: LlmModel,
+            prompt: list[dict],
+            json_format: bool
+    ) -> str:
         provider = model.metadata.provider
 
         if provider == "openai":
             openai.api_key = api_key
+            response_format = {"type": "json_object"} if json_format else None
             response = openai.chat.completions.create(
                 model=model.value,
                 messages=prompt,  # type: ignore
-                response_format={"type": "json_object"} if json_format else None,  # type: ignore
+                response_format=response_format,  # type: ignore
             )
             return response.choices[0].message.content or ""
         elif provider == "anthropic":
@@ -116,10 +122,11 @@ class LlmCallBlock(Block):
             return response.content[0].text if response.content else ""
         elif provider == "groq":
             client = Groq(api_key=api_key)
+            response_format = {"type": "json_object"} if json_format else None
             response = client.chat.completions.create(
                 model=model.value,
                 messages=prompt,  # type: ignore
-                response_format={"type": "json_object"} if json_format else None,
+                response_format=response_format,  # type: ignore
             )
             return response.choices[0].message.content or ""
         else:
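
[Review note] With the signature reflowed and response_format hoisted out of the provider calls, a direct invocation of the helper would read as below. This is a sketch only: the key string and message content are placeholders, and LlmModel.GPT4_TURBO is borrowed from the summarizer defaults added later in this diff.

    # Sketch; placeholder key, OpenAI-style chat message shape.
    reply = LlmCallBlock.llm_call(
        api_key="sk-...",  # placeholder, not a real key
        model=LlmModel.GPT4_TURBO,
        prompt=[{"role": "user", "content": "Reply with a JSON greeting."}],
        json_format=True,  # sends {"type": "json_object"} to the provider
    )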
@@ -136,7 +143,10 @@ class LlmCallBlock(Block):
             prompt.append({"role": "system", "content": input_data.sys_prompt})
 
         if input_data.expected_format:
-            expected_format = [f'"{k}": "{v}"' for k, v in input_data.expected_format.items()]
+            expected_format = [
+                f'"{k}": "{v}"' for k, v in
+                input_data.expected_format.items()
+            ]
             format_prompt = ",\n  ".join(expected_format)
             sys_prompt = trim_prompt(f"""
                |Reply in json format:
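
[Review note] The reflowed comprehension builds the same JSON-shaped prompt fragment as before. For example, with the expected_format that the summarizer block passes in later in this diff:

    # Each key/value pair is rendered as a quoted '"key": "value"' line,
    # joined with ",\n  " to form the body of the format prompt.
    expected_format = {"summary": "The summary of the given text."}
    fragment = ",\n  ".join(f'"{k}": "{v}"' for k, v in expected_format.items())
    # fragment == '"summary": "The summary of the given text."'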
@@ -199,3 +209,113 @@ class LlmCallBlock(Block):
             retry_prompt = f"Error calling LLM: {e}"
 
         yield "error", retry_prompt
+
+
+class TextSummarizerBlock(Block):
+    class Input(BlockSchema):
+        text: str
+        model: LlmModel = LlmModel.GPT4_TURBO
+        api_key: BlockFieldSecret = BlockFieldSecret(value="")
+        # TODO: Make this dynamic
+        max_tokens: int = 4000  # Adjust based on the model's context window
+        chunk_overlap: int = 100  # Overlap between chunks to maintain context
+
+    class Output(BlockSchema):
+        summary: str
+        error: str
+
+    def __init__(self):
+        super().__init__(
+            id="c3d4e5f6-7g8h-9i0j-1k2l-m3n4o5p6q7r8",
+            input_schema=TextSummarizerBlock.Input,
+            output_schema=TextSummarizerBlock.Output,
+            test_input={"text": "Lorem ipsum..." * 100},
+            test_output=("summary", "Final summary of a long text"),
+            test_mock={
+                "llm_call": lambda input_data:
+                {"final_summary": "Final summary of a long text"}
+                if "final_summary" in input_data.expected_format
+                else {"summary": "Summary of a chunk of text"}
+            }
+        )
+
+    def run(self, input_data: Input) -> BlockOutput:
+        try:
+            for output in self._run(input_data):
+                yield output
+        except Exception as e:
+            yield "error", str(e)
+
+    def _run(self, input_data: Input) -> BlockOutput:
+        chunks = self._split_text(
+            input_data.text,
+            input_data.max_tokens,
+            input_data.chunk_overlap
+        )
+        summaries = []
+
+        for chunk in chunks:
+            chunk_summary = self._summarize_chunk(chunk, input_data)
+            summaries.append(chunk_summary)
+
+        final_summary = self._combine_summaries(summaries, input_data)
+        yield "summary", final_summary
+
+    @staticmethod
+    def _split_text(text: str, max_tokens: int, overlap: int) -> list[str]:
+        words = text.split()
+        chunks = []
+        chunk_size = max_tokens - overlap
+
+        for i in range(0, len(words), chunk_size):
+            chunk = " ".join(words[i:i + max_tokens])
+            chunks.append(chunk)
+
+        return chunks
+
+    @staticmethod
+    def llm_call(input_data: LlmCallBlock.Input) -> dict[str, str]:
+        llm_block = LlmCallBlock()
+        for output_name, output_data in llm_block.run(input_data):
+            if output_name == "response":
+                return output_data
+        raise ValueError("Failed to get a response from the LLM.")
+
+    def _summarize_chunk(self, chunk: str, input_data: Input) -> str:
+        prompt = f"Summarize the following text concisely:\n\n{chunk}"
+
+        llm_response = self.llm_call(LlmCallBlock.Input(
+            prompt=prompt,
+            api_key=input_data.api_key,
+            model=input_data.model,
+            expected_format={"summary": "The summary of the given text."}
+        ))
+
+        return llm_response["summary"]
+
+    def _combine_summaries(self, summaries: list[str], input_data: Input) -> str:
+        combined_text = " ".join(summaries)
+
+        if len(combined_text.split()) <= input_data.max_tokens:
+            prompt = ("Provide a final, concise summary of the following summaries:\n\n"
+                      + combined_text)
+
+            llm_response = self.llm_call(LlmCallBlock.Input(
+                prompt=prompt,
+                api_key=input_data.api_key,
+                model=input_data.model,
+                expected_format={
+                    "final_summary": "The final summary of all provided summaries."
+                }
+            ))
+
+            return llm_response["final_summary"]
+        else:
+            # If combined summaries are still too long, recursively summarize
+            return self._run(TextSummarizerBlock.Input(
+                text=combined_text,
+                api_key=input_data.api_key,
+                model=input_data.model,
+                max_tokens=input_data.max_tokens,
+                chunk_overlap=input_data.chunk_overlap
+            )).send(None)[1]  # Get the first yielded value
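
[Review note] A worked example of _split_text's stride math, using the block's defaults (max_tokens=4000, chunk_overlap=100): the loop steps by max_tokens - overlap words while each slice takes max_tokens words, so consecutive chunks share exactly `overlap` words of context.

    # Defaults from TextSummarizerBlock.Input: 4000-word chunks, 100-word overlap.
    words = [f"w{n}" for n in range(10_000)]
    step = 4000 - 100                            # stride between chunk starts
    chunks = [words[i:i + 4000] for i in range(0, len(words), step)]
    assert [len(c) for c in chunks] == [4000, 4000, 2200]
    assert chunks[0][-100:] == chunks[1][:100]   # 100 shared words per boundary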
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-from autogpt_server.blocks.ai import LlmCallBlock
+from autogpt_server.blocks.llm import LlmCallBlock
 from autogpt_server.blocks.basic import ValueBlock
 from autogpt_server.blocks.block import BlockInstallationBlock
 from autogpt_server.blocks.http import HttpRequestBlock
@@ -1,5 +1,5 @@
 from autogpt_server.data.graph import Graph, Link, Node, create_graph
-from autogpt_server.blocks.ai import LlmCallBlock
+from autogpt_server.blocks.llm import LlmCallBlock
 from autogpt_server.blocks.reddit import (
     RedditGetPostsBlock,
     RedditPostCommentBlock,
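
[Review note] End-to-end, the new block would be driven like any other block. A minimal sketch, assuming run() yields (output_name, output_data) tuples as the hunks above suggest, with a placeholder API key:

    block = TextSummarizerBlock()
    input_data = TextSummarizerBlock.Input(
        text="Lorem ipsum..." * 100,
        model=LlmModel.GPT4_TURBO,
        api_key=BlockFieldSecret(value="sk-..."),  # placeholder secret
    )
    for name, data in block.run(input_data):
        if name == "summary":
            print(data)
        elif name == "error":
            raise RuntimeError(data)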