From 920f931a21fa989ece4a0f9ab3d54a16fa019936 Mon Sep 17 00:00:00 2001 From: Toran Bruce Richards <toran.richards@gmail.com> Date: Tue, 16 Jul 2024 11:51:37 +0100 Subject: [PATCH] feat(blocks): Add summariser block for recursive text summarization functionality (#7431) Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co> --- .../autogpt_server/blocks/{ai.py => llm.py} | 132 +++++++++++++++++- .../autogpt_server/usecases/block_autogen.py | 2 +- .../usecases/reddit_marketing.py | 2 +- 3 files changed, 128 insertions(+), 8 deletions(-) rename rnd/autogpt_server/autogpt_server/blocks/{ai.py => llm.py} (59%) diff --git a/rnd/autogpt_server/autogpt_server/blocks/ai.py b/rnd/autogpt_server/autogpt_server/blocks/llm.py similarity index 59% rename from rnd/autogpt_server/autogpt_server/blocks/ai.py rename to rnd/autogpt_server/autogpt_server/blocks/llm.py index 174bbf9c9..7f0cb9692 100644 --- a/rnd/autogpt_server/autogpt_server/blocks/ai.py +++ b/rnd/autogpt_server/autogpt_server/blocks/llm.py @@ -2,8 +2,8 @@ import logging from enum import Enum from typing import NamedTuple -import openai import anthropic +import openai from groq import Groq from autogpt_server.data.block import Block, BlockOutput, BlockSchema, BlockFieldSecret @@ -92,15 +92,21 @@ class LlmCallBlock(Block): ) @staticmethod - def llm_call(api_key: str, model: LlmModel, prompt: list[dict], json_format: bool) -> str: + def llm_call( + api_key: str, + model: LlmModel, + prompt: list[dict], + json_format: bool + ) -> str: provider = model.metadata.provider - + if provider == "openai": openai.api_key = api_key + response_format = {"type": "json_object"} if json_format else None response = openai.chat.completions.create( model=model.value, messages=prompt, # type: ignore - response_format={"type": "json_object"} if json_format else None, # type: ignore + response_format=response_format, # type: ignore ) return response.choices[0].message.content or "" elif provider == "anthropic": @@ -116,10 +122,11 @@ class LlmCallBlock(Block): return response.content[0].text if response.content else "" elif provider == "groq": client = Groq(api_key=api_key) + response_format = {"type": "json_object"} if json_format else None response = client.chat.completions.create( model=model.value, messages=prompt, # type: ignore - response_format={"type": "json_object"} if json_format else None, + response_format=response_format, # type: ignore ) return response.choices[0].message.content or "" else: @@ -136,7 +143,10 @@ class LlmCallBlock(Block): prompt.append({"role": "system", "content": input_data.sys_prompt}) if input_data.expected_format: - expected_format = [f'"{k}": "{v}"' for k, v in input_data.expected_format.items()] + expected_format = [ + f'"{k}": "{v}"' for k, v in + input_data.expected_format.items() + ] format_prompt = ",\n ".join(expected_format) sys_prompt = trim_prompt(f""" |Reply in json format: @@ -199,3 +209,113 @@ class LlmCallBlock(Block): retry_prompt = f"Error calling LLM: {e}" yield "error", retry_prompt + + +class TextSummarizerBlock(Block): + class Input(BlockSchema): + text: str + model: LlmModel = LlmModel.GPT4_TURBO + api_key: BlockFieldSecret = BlockFieldSecret(value="") + # TODO: Make this dynamic + max_tokens: int = 4000 # Adjust based on the model's context window + chunk_overlap: int = 100 # Overlap between chunks to maintain context + + class Output(BlockSchema): + summary: str + error: str + + def __init__(self): + super().__init__( + id="c3d4e5f6-7g8h-9i0j-1k2l-m3n4o5p6q7r8", + input_schema=TextSummarizerBlock.Input, + 
output_schema=TextSummarizerBlock.Output, + test_input={"text": "Lorem ipsum..." * 100}, + test_output=("summary", "Final summary of a long text"), + test_mock={ + "llm_call": lambda input_data: + {"final_summary": "Final summary of a long text"} + if "final_summary" in input_data.expected_format + else {"summary": "Summary of a chunk of text"} + } + ) + + def run(self, input_data: Input) -> BlockOutput: + try: + for output in self._run(input_data): + yield output + except Exception as e: + yield "error", str(e) + + def _run(self, input_data: Input) -> BlockOutput: + chunks = self._split_text( + input_data.text, + input_data.max_tokens, + input_data.chunk_overlap + ) + summaries = [] + + for chunk in chunks: + chunk_summary = self._summarize_chunk(chunk, input_data) + summaries.append(chunk_summary) + + final_summary = self._combine_summaries(summaries, input_data) + yield "summary", final_summary + + @staticmethod + def _split_text(text: str, max_tokens: int, overlap: int) -> list[str]: + words = text.split() + chunks = [] + chunk_size = max_tokens - overlap + + for i in range(0, len(words), chunk_size): + chunk = " ".join(words[i:i + max_tokens]) + chunks.append(chunk) + + return chunks + + @staticmethod + def llm_call(input_data: LlmCallBlock.Input) -> dict[str, str]: + llm_block = LlmCallBlock() + for output_name, output_data in llm_block.run(input_data): + if output_name == "response": + return output_data + raise ValueError("Failed to get a response from the LLM.") + + def _summarize_chunk(self, chunk: str, input_data: Input) -> str: + prompt = f"Summarize the following text concisely:\n\n{chunk}" + + llm_response = self.llm_call(LlmCallBlock.Input( + prompt=prompt, + api_key=input_data.api_key, + model=input_data.model, + expected_format={"summary": "The summary of the given text."} + )) + + return llm_response["summary"] + + def _combine_summaries(self, summaries: list[str], input_data: Input) -> str: + combined_text = " ".join(summaries) + + if len(combined_text.split()) <= input_data.max_tokens: + prompt = ("Provide a final, concise summary of the following summaries:\n\n" + + combined_text) + + llm_response = self.llm_call(LlmCallBlock.Input( + prompt=prompt, + api_key=input_data.api_key, + model=input_data.model, + expected_format={ + "final_summary": "The final summary of all provided summaries." 
+ } + )) + + return llm_response["final_summary"] + else: + # If combined summaries are still too long, recursively summarize + return self._run(TextSummarizerBlock.Input( + text=combined_text, + api_key=input_data.api_key, + model=input_data.model, + max_tokens=input_data.max_tokens, + chunk_overlap=input_data.chunk_overlap + )).send(None)[1] # Get the first yielded value diff --git a/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py b/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py index f00dc8322..1e3cd22fa 100644 --- a/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py +++ b/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py @@ -1,6 +1,6 @@ from pathlib import Path -from autogpt_server.blocks.ai import LlmCallBlock +from autogpt_server.blocks.llm import LlmCallBlock from autogpt_server.blocks.basic import ValueBlock from autogpt_server.blocks.block import BlockInstallationBlock from autogpt_server.blocks.http import HttpRequestBlock diff --git a/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py b/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py index d0cc025bc..d9230bbdd 100644 --- a/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py +++ b/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py @@ -1,5 +1,5 @@ from autogpt_server.data.graph import Graph, Link, Node, create_graph -from autogpt_server.blocks.ai import LlmCallBlock +from autogpt_server.blocks.llm import LlmCallBlock from autogpt_server.blocks.reddit import ( RedditGetPostsBlock, RedditPostCommentBlock,
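
The word-based chunking in TextSummarizerBlock._split_text advances in strides of max_tokens - chunk_overlap words while each slice spans up to max_tokens words, so consecutive chunks share roughly chunk_overlap words of context. A standalone sketch of that arithmetic, using a copy of the patch's logic and small illustrative values that are not part of the patch:

    # Illustration of TextSummarizerBlock._split_text: stride = max_tokens - overlap,
    # slice width = max_tokens, so adjacent chunks share `overlap` words.
    def split_text(text: str, max_tokens: int, overlap: int) -> list[str]:
        words = text.split()
        chunk_size = max_tokens - overlap
        return [
            " ".join(words[i:i + max_tokens])
            for i in range(0, len(words), chunk_size)
        ]

    sample = " ".join(f"w{n}" for n in range(10))
    print(split_text(sample, max_tokens=4, overlap=1))
    # ['w0 w1 w2 w3', 'w3 w4 w5 w6', 'w6 w7 w8 w9', 'w9']

Note that the block counts whitespace-separated words rather than model tokens, which is why max_tokens carries a TODO to make it dynamic per model context window.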
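
In _combine_summaries, the recursive branch drives the _run generator directly: sending None into a just-created generator is equivalent to calling next() on it, so .send(None) returns the first yielded ("summary", final_summary) tuple and [1] extracts the text. A minimal illustration of that generator behaviour, using a hypothetical stand-in function:

    # .send(None) on a fresh generator behaves like next(), returning the first
    # yielded ("summary", ...) tuple; index 1 is the summary text itself.
    def run_like():
        yield "summary", "final summary text"

    gen = run_like()
    print(gen.send(None)[1])  # -> final summary text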
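
A minimal sketch of how the new block might be exercised directly, assuming the module path introduced by this rename (autogpt_server.blocks.llm), the BlockFieldSecret(value=...) constructor shown in the diff, and a placeholder API key; an actual run would call the configured LLM provider:

    # Sketch only: requires the autogpt_server package and a valid provider key.
    from autogpt_server.blocks.llm import LlmModel, TextSummarizerBlock
    from autogpt_server.data.block import BlockFieldSecret

    block = TextSummarizerBlock()
    input_data = TextSummarizerBlock.Input(
        text="Lorem ipsum dolor sit amet. " * 2000,  # long enough to force chunking
        model=LlmModel.GPT4_TURBO,
        api_key=BlockFieldSecret(value="sk-placeholder"),  # placeholder, not a real key
        max_tokens=4000,
        chunk_overlap=100,
    )

    # Blocks yield (output_name, value) pairs; "summary" carries the result,
    # "error" carries any exception message.
    for name, value in block.run(input_data):
        if name == "summary":
            print(value)
        elif name == "error":
            print("summarization failed:", value)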