From 23d58a3cc0b7473b1ff2362acf8392eeae73a1d8 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Fri, 16 Feb 2024 15:17:11 +0100 Subject: [PATCH] feat(benchmark/cli): Add `challenge list`, `challenge info` subcommands - Add `challenge list` command with options `--all`, `--names`, `--json` - Add `tabulate` dependency - Add `.utils.utils.sorted_by_enum_index` function to easily sort lists by an enum value/property based on the order of the enum's definition - Add `challenge info [name]` command with option `--json` - Add `.utils.utils.pretty_print_model` routine to pretty-print Pydantic models - Refactor `config` subcommand to use `pretty_print_model` --- benchmark/agbenchmark/__main__.py | 127 +++++++++++++++++++++- benchmark/agbenchmark/utils/data_types.py | 2 +- benchmark/agbenchmark/utils/utils.py | 79 +++++++++++++- benchmark/poetry.lock | 16 ++- benchmark/pyproject.toml | 1 + 5 files changed, 219 insertions(+), 6 deletions(-) diff --git a/benchmark/agbenchmark/__main__.py b/benchmark/agbenchmark/__main__.py index 9fff53523..571f19f35 100644 --- a/benchmark/agbenchmark/__main__.py +++ b/benchmark/agbenchmark/__main__.py @@ -202,15 +202,136 @@ def serve(port: Optional[int] = None): @cli.command() def config(): """Displays info regarding the present AGBenchmark config.""" + from .utils.utils import pretty_print_model + try: config = AgentBenchmarkConfig.load() except FileNotFoundError as e: click.echo(e, err=True) return 1 - k_col_width = max(len(k) for k in config.dict().keys()) - for k, v in config.dict().items(): - click.echo(f"{k: <{k_col_width}} = {v}") + pretty_print_model(config, include_header=False) + + +@cli.group() +def challenge(): + logging.getLogger().setLevel(logging.WARNING) + pass + + +@challenge.command("list") +@click.option( + "--all", "include_unavailable", is_flag=True, help="Include unavailable challenges." +) +@click.option( + "--names", "only_names", is_flag=True, help="List only the challenge names." 
+) +@click.option("--json", "output_json", is_flag=True) +def list_challenges(include_unavailable: bool, only_names: bool, output_json: bool): + """Lists [available|all] challenges.""" + import json + + from tabulate import tabulate + + from .challenges.builtin import load_builtin_challenges + from .challenges.webarena import load_webarena_challenges + from .utils.data_types import Category, DifficultyLevel + from .utils.utils import sorted_by_enum_index + + DIFFICULTY_COLORS = { + difficulty: color + for difficulty, color in zip( + DifficultyLevel, + ["black", "blue", "cyan", "green", "yellow", "red", "magenta", "white"], + ) + } + CATEGORY_COLORS = { + category: f"bright_{color}" + for category, color in zip( + Category, + ["blue", "cyan", "green", "yellow", "magenta", "red", "white", "black"], + ) + } + + # Load challenges + challenges = filter( + lambda c: c.info.available or include_unavailable, + [ + *load_builtin_challenges(), + *load_webarena_challenges(skip_unavailable=False), + ], + ) + challenges = sorted_by_enum_index( + challenges, DifficultyLevel, key=lambda c: c.info.difficulty + ) + + if only_names: + if output_json: + click.echo(json.dumps([c.info.name for c in challenges])) + return + + for c in challenges: + click.echo( + click.style(c.info.name, fg=None if c.info.available else "black") + ) + return + + if output_json: + click.echo(json.dumps([json.loads(c.info.json()) for c in challenges])) + return + + headers = tuple( + click.style(h, bold=True) for h in ("Name", "Difficulty", "Categories") + ) + table = [ + tuple( + v if challenge.info.available else click.style(v, fg="black") + for v in ( + challenge.info.name, + ( + click.style( + challenge.info.difficulty.value, + fg=DIFFICULTY_COLORS[challenge.info.difficulty], + ) + if challenge.info.difficulty + else click.style("-", fg="black") + ), + " ".join( + click.style(cat.value, fg=CATEGORY_COLORS[cat]) + for cat in sorted_by_enum_index(challenge.info.category, Category) + ), + ) + ) + for 
challenge in challenges + ] + click.echo(tabulate(table, headers=headers)) + + +@challenge.command() +@click.option("--json", is_flag=True) +@click.argument("name") +def info(name: str, json: bool): + from itertools import chain + + from .challenges.builtin import load_builtin_challenges + from .challenges.webarena import load_webarena_challenges + from .utils.utils import pretty_print_model + + for challenge in chain( + load_builtin_challenges(), + load_webarena_challenges(skip_unavailable=False), + ): + if challenge.info.name != name: + continue + + if json: + click.echo(challenge.info.json()) + break + + pretty_print_model(challenge.info) + break + else: + click.echo(click.style(f"Unknown challenge '{name}'", fg="red"), err=True) @cli.command() diff --git a/benchmark/agbenchmark/utils/data_types.py b/benchmark/agbenchmark/utils/data_types.py index 688209682..ac7444921 100644 --- a/benchmark/agbenchmark/utils/data_types.py +++ b/benchmark/agbenchmark/utils/data_types.py @@ -29,8 +29,8 @@ STRING_DIFFICULTY_MAP = {e.value: DIFFICULTY_MAP[e] for e in DifficultyLevel} class Category(str, Enum): - DATA = "data" GENERALIST = "general" + DATA = "data" CODING = "coding" SCRAPE_SYNTHESIZE = "scrape_synthesize" WEB = "web" diff --git a/benchmark/agbenchmark/utils/utils.py b/benchmark/agbenchmark/utils/utils.py index 93724de85..0f0ad56d9 100644 --- a/benchmark/agbenchmark/utils/utils.py +++ b/benchmark/agbenchmark/utils/utils.py @@ -3,10 +3,13 @@ import json import logging import os import re +from enum import Enum from pathlib import Path -from typing import Any, Optional +from typing import Any, Callable, Iterable, Optional, TypeVar, overload +import click from dotenv import load_dotenv +from pydantic import BaseModel from agbenchmark.reports.processing.report_types import Test from agbenchmark.utils.data_types import DIFFICULTY_MAP, DifficultyLevel @@ -17,6 +20,9 @@ AGENT_NAME = os.getenv("AGENT_NAME") logger = logging.getLogger(__name__) +T = TypeVar("T") +E = 
TypeVar("E", bound=Enum) + def replace_backslash(value: Any) -> Any: if isinstance(value, str): @@ -124,6 +130,42 @@ def write_pretty_json(data, json_file): f.write("\n") +def pretty_print_model(model: BaseModel, include_header: bool = True) -> None: + indent = "" + if include_header: + # Try to find the ID and/or name attribute of the model + id, name = None, None + for attr, value in model.dict().items(): + if attr == "id" or attr.endswith("_id"): + id = value + if attr.endswith("name"): + name = value + if id and name: + break + identifiers = [v for v in [name, id] if v] + click.echo( + f"{model.__repr_name__()}{repr(identifiers) if identifiers else ''}:" + ) + indent = " " * 2 + + k_col_width = max(len(k) for k in model.dict().keys()) + for k, v in model.dict().items(): + v_fmt = repr(v) + if v is None or v == "": + v_fmt = click.style(v_fmt, fg="black") + elif type(v) is bool: + v_fmt = click.style(v_fmt, fg="green" if v else "red") + elif type(v) is str and "\n" in v: + v_fmt = f"\n{v}".replace( + "\n", f"\n{indent} {click.style('|', fg='black')} " + ) + if isinstance(v, Enum): + v_fmt = click.style(v.value, fg="blue") + elif type(v) is list and len(v) > 0 and isinstance(v[0], Enum): + v_fmt = ", ".join(click.style(lv.value, fg="blue") for lv in v) + click.echo(f"{indent}{k: <{k_col_width}} = {v_fmt}") + + def deep_sort(obj): """ Recursively sort the keys in JSON object @@ -133,3 +175,38 @@ def deep_sort(obj): if isinstance(obj, list): return [deep_sort(elem) for elem in obj] return obj + + +@overload +def sorted_by_enum_index( + sortable: Iterable[E], + enum: type[E], + *, + reverse: bool = False, +) -> list[E]: + ... + + +@overload +def sorted_by_enum_index( + sortable: Iterable[T], + enum: type[Enum], + *, + key: Callable[[T], Enum | None], + reverse: bool = False, +) -> list[T]: + ... 
+ + +def sorted_by_enum_index( + sortable: Iterable[T], + enum: type[Enum], + *, + key: Callable[[T], Enum | None] = lambda x: x, # type: ignore + reverse: bool = False, +) -> list[T]: + return sorted( + sortable, + key=lambda x: enum._member_names_.index(e.name) if (e := key(x)) else 420e3, + reverse=reverse, + ) diff --git a/benchmark/poetry.lock b/benchmark/poetry.lock index 005086565..70bef01f6 100644 --- a/benchmark/poetry.lock +++ b/benchmark/poetry.lock @@ -2431,6 +2431,20 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] +[[package]] +name = "tabulate" +version = "0.9.0" +description = "Pretty-print tabular data" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, + {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, +] + +[package.extras] +widechars = ["wcwidth"] + [[package]] name = "toml" version = "0.10.2" @@ -2760,4 +2774,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "d7893a88906b5a8eda566e13e6a9492d012c910ded0da1b1ef12b69a14f8e047" +content-hash = "6eefdbbefb500de627cac39eb6eb1fdcecab76dd4c3599cf08ef6dc647cf71c9" diff --git a/benchmark/pyproject.toml b/benchmark/pyproject.toml index c659dcc8b..6c3976743 100644 --- a/benchmark/pyproject.toml +++ b/benchmark/pyproject.toml @@ -34,6 +34,7 @@ toml = "^0.10.2" httpx = "^0.24.0" agent-protocol-client = "^1.1.0" click-default-group = "^1.2.4" +tabulate = "^0.9.0" [tool.poetry.group.dev.dependencies] flake8 = "^3.9.2"