From 920f931a21fa989ece4a0f9ab3d54a16fa019936 Mon Sep 17 00:00:00 2001
From: Toran Bruce Richards <toran.richards@gmail.com>
Date: Tue, 16 Jul 2024 11:51:37 +0100
Subject: [PATCH] feat(blocks): Add summarizer block for recursive text
 summarization (#7431)

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
---
 .../autogpt_server/blocks/{ai.py => llm.py}   | 132 +++++++++++++++++-
 .../autogpt_server/usecases/block_autogen.py  |   2 +-
 .../usecases/reddit_marketing.py              |   2 +-
 3 files changed, 128 insertions(+), 8 deletions(-)
 rename rnd/autogpt_server/autogpt_server/blocks/{ai.py => llm.py} (59%)

diff --git a/rnd/autogpt_server/autogpt_server/blocks/ai.py b/rnd/autogpt_server/autogpt_server/blocks/llm.py
similarity index 59%
rename from rnd/autogpt_server/autogpt_server/blocks/ai.py
rename to rnd/autogpt_server/autogpt_server/blocks/llm.py
index 174bbf9c9..7f0cb9692 100644
--- a/rnd/autogpt_server/autogpt_server/blocks/ai.py
+++ b/rnd/autogpt_server/autogpt_server/blocks/llm.py
@@ -2,8 +2,8 @@ import logging
 from enum import Enum
 from typing import NamedTuple
 
-import openai
 import anthropic
+import openai
 from groq import Groq
 
 from autogpt_server.data.block import Block, BlockOutput, BlockSchema, BlockFieldSecret
@@ -92,15 +92,21 @@ class LlmCallBlock(Block):
         )
 
     @staticmethod
-    def llm_call(api_key: str, model: LlmModel, prompt: list[dict], json_format: bool) -> str:
+    def llm_call(
+            api_key: str,
+            model: LlmModel,
+            prompt: list[dict],
+            json_format: bool
+    ) -> str:
         provider = model.metadata.provider
-        
+
         if provider == "openai":
             openai.api_key = api_key
+            response_format = {"type": "json_object"} if json_format else None
             response = openai.chat.completions.create(
                 model=model.value,
                 messages=prompt,  # type: ignore
-                response_format={"type": "json_object"} if json_format else None, # type: ignore
+                response_format=response_format,  # type: ignore
             )
             return response.choices[0].message.content or ""
         elif provider == "anthropic":
@@ -116,10 +122,11 @@ class LlmCallBlock(Block):
             return response.content[0].text if response.content else ""
         elif provider == "groq":
             client = Groq(api_key=api_key)
+            response_format = {"type": "json_object"} if json_format else None
             response = client.chat.completions.create(
                 model=model.value,
                 messages=prompt,  # type: ignore
-                response_format={"type": "json_object"} if json_format else None,
+                response_format=response_format,  # type: ignore
             )
             return response.choices[0].message.content or ""
         else:
@@ -136,7 +143,10 @@ class LlmCallBlock(Block):
             prompt.append({"role": "system", "content": input_data.sys_prompt})
 
         if input_data.expected_format:
-            expected_format = [f'"{k}": "{v}"' for k, v in input_data.expected_format.items()]
+            expected_format = [
+                f'"{k}": "{v}"' for k, v in
+                input_data.expected_format.items()
+            ]
             format_prompt = ",\n  ".join(expected_format)
             sys_prompt = trim_prompt(f"""
               |Reply in json format:
@@ -199,3 +209,113 @@ class LlmCallBlock(Block):
                 retry_prompt = f"Error calling LLM: {e}"
 
         yield "error", retry_prompt
+
+
+class TextSummarizerBlock(Block):
+    class Input(BlockSchema):
+        text: str
+        model: LlmModel = LlmModel.GPT4_TURBO
+        api_key: BlockFieldSecret = BlockFieldSecret(value="")
+        # TODO: Make this dynamic
+        max_tokens: int = 4000  # Adjust based on the model's context window
+        chunk_overlap: int = 100  # Overlap between chunks to maintain context
+
+    class Output(BlockSchema):
+        summary: str
+        error: str
+
+    def __init__(self):
+        super().__init__(
+            id="c3d4e5f6-7g8h-9i0j-1k2l-m3n4o5p6q7r8",
+            input_schema=TextSummarizerBlock.Input,
+            output_schema=TextSummarizerBlock.Output,
+            test_input={"text": "Lorem ipsum..." * 100},
+            test_output=("summary", "Final summary of a long text"),
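+            # The mocked llm_call returns a chunk summary or the final summary,
+            # depending on which expected_format the block requests.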
+            test_mock={
+                "llm_call": lambda input_data: (
+                    {"final_summary": "Final summary of a long text"}
+                    if "final_summary" in input_data.expected_format
+                    else {"summary": "Summary of a chunk of text"}
+                )
+            }
+        )
+
+    def run(self, input_data: Input) -> BlockOutput:
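+        # Wrap _run so any exception surfaces as an "error" output instead of
+        # propagating out of the block.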
+        try:
+            for output in self._run(input_data):
+                yield output
+        except Exception as e:
+            yield "error", str(e)
+
+    def _run(self, input_data: Input) -> BlockOutput:
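+        # Map-reduce summarization: split the text into overlapping chunks,
+        # summarize each chunk, then combine the chunk summaries.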
+        chunks = self._split_text(
+            input_data.text,
+            input_data.max_tokens,
+            input_data.chunk_overlap
+        )
+        summaries = []
+
+        for chunk in chunks:
+            chunk_summary = self._summarize_chunk(chunk, input_data)
+            summaries.append(chunk_summary)
+
+        final_summary = self._combine_summaries(summaries, input_data)
+        yield "summary", final_summary
+
+    @staticmethod
+    def _split_text(text: str, max_tokens: int, overlap: int) -> list[str]:
+        words = text.split()
+        chunks = []
+        chunk_size = max_tokens - overlap
+
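+        # Slide a window of max_tokens words, stepping by (max_tokens - overlap)
+        # so consecutive chunks share `overlap` words of context.
+        # Note: tokens are approximated by whitespace-separated words here.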
+        for i in range(0, len(words), chunk_size):
+            chunk = " ".join(words[i:i + max_tokens])
+            chunks.append(chunk)
+
+        return chunks
+
+    @staticmethod
+    def llm_call(input_data: LlmCallBlock.Input) -> dict[str, str]:
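+        # Run a nested LlmCallBlock and return its "response" output, i.e. the
+        # parsed dict matching the requested expected_format.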
+        llm_block = LlmCallBlock()
+        for output_name, output_data in llm_block.run(input_data):
+            if output_name == "response":
+                return output_data
+        raise ValueError("Failed to get a response from the LLM.")
+
+    def _summarize_chunk(self, chunk: str, input_data: Input) -> str:
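+        # Ask the LLM for a {"summary": ...} object covering a single chunk.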
+        prompt = f"Summarize the following text concisely:\n\n{chunk}"
+
+        llm_response = self.llm_call(LlmCallBlock.Input(
+            prompt=prompt,
+            api_key=input_data.api_key,
+            model=input_data.model,
+            expected_format={"summary": "The summary of the given text."}
+        ))
+
+        return llm_response["summary"]
+
+    def _combine_summaries(self, summaries: list[str], input_data: Input) -> str:
+        combined_text = " ".join(summaries)
+
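+        # If the joined chunk summaries fit within one window, ask for the final
+        # summary in a single call; otherwise recurse on the joined text.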
+        if len(combined_text.split()) <= input_data.max_tokens:
+            prompt = ("Provide a final, concise summary of the following summaries:\n\n"
+                      + combined_text)
+
+            llm_response = self.llm_call(LlmCallBlock.Input(
+                prompt=prompt,
+                api_key=input_data.api_key,
+                model=input_data.model,
+                expected_format={
+                    "final_summary": "The final summary of all provided summaries."
+                }
+            ))
+
+            return llm_response["final_summary"]
+        else:
+            # Combined summaries are still too long; summarize them recursively
+            # and return the value of the first ("summary", ...) output.
+            return next(self._run(TextSummarizerBlock.Input(
+                text=combined_text,
+                api_key=input_data.api_key,
+                model=input_data.model,
+                max_tokens=input_data.max_tokens,
+                chunk_overlap=input_data.chunk_overlap
+            )))[1]
diff --git a/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py b/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py
index f00dc8322..1e3cd22fa 100644
--- a/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py
+++ b/rnd/autogpt_server/autogpt_server/usecases/block_autogen.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-from autogpt_server.blocks.ai import LlmCallBlock
+from autogpt_server.blocks.llm import LlmCallBlock
 from autogpt_server.blocks.basic import ValueBlock
 from autogpt_server.blocks.block import BlockInstallationBlock
 from autogpt_server.blocks.http import HttpRequestBlock
diff --git a/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py b/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py
index d0cc025bc..d9230bbdd 100644
--- a/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py
+++ b/rnd/autogpt_server/autogpt_server/usecases/reddit_marketing.py
@@ -1,5 +1,5 @@
 from autogpt_server.data.graph import Graph, Link, Node, create_graph
-from autogpt_server.blocks.ai import LlmCallBlock
+from autogpt_server.blocks.llm import LlmCallBlock
 from autogpt_server.blocks.reddit import (
     RedditGetPostsBlock,
     RedditPostCommentBlock,