feat(benchmark/cli): Add `challenge list`, `challenge info` subcommands

- Add `challenge list` command with options `--all`, `--names`, `--json`
   - Add `tabular` dependency
   - Add `.utils.utils.sorted_by_enum_index` function to easily sort lists by an enum value/property based on the order of the enum's definition
- Add `challenge info [name]` command with option `--json`
   - Add `.utils.utils.pretty_print_model` routine to pretty-print Pydantic models
- Refactor `config` subcommand to use `pretty_print_model`
pull/6857/head
Reinier van der Leer 2024-02-16 15:17:11 +01:00
parent 70e345b2ce
commit 23d58a3cc0
No known key found for this signature in database
GPG Key ID: CDC1180FDAE06193
5 changed files with 219 additions and 6 deletions

View File

@ -202,15 +202,136 @@ def serve(port: Optional[int] = None):
@cli.command()
def config():
"""Displays info regarding the present AGBenchmark config."""
from .utils.utils import pretty_print_model
try:
config = AgentBenchmarkConfig.load()
except FileNotFoundError as e:
click.echo(e, err=True)
return 1
k_col_width = max(len(k) for k in config.dict().keys())
for k, v in config.dict().items():
click.echo(f"{k: <{k_col_width}} = {v}")
pretty_print_model(config, include_header=False)
@cli.group()
def challenge():
logging.getLogger().setLevel(logging.WARNING)
pass
@challenge.command("list")
@click.option(
"--all", "include_unavailable", is_flag=True, help="Include unavailable challenges."
)
@click.option(
"--names", "only_names", is_flag=True, help="List only the challenge names."
)
@click.option("--json", "output_json", is_flag=True)
def list_challenges(include_unavailable: bool, only_names: bool, output_json: bool):
"""Lists [available|all] challenges."""
import json
from tabulate import tabulate
from .challenges.builtin import load_builtin_challenges
from .challenges.webarena import load_webarena_challenges
from .utils.data_types import Category, DifficultyLevel
from .utils.utils import sorted_by_enum_index
DIFFICULTY_COLORS = {
difficulty: color
for difficulty, color in zip(
DifficultyLevel,
["black", "blue", "cyan", "green", "yellow", "red", "magenta", "white"],
)
}
CATEGORY_COLORS = {
category: f"bright_{color}"
for category, color in zip(
Category,
["blue", "cyan", "green", "yellow", "magenta", "red", "white", "black"],
)
}
# Load challenges
challenges = filter(
lambda c: c.info.available or include_unavailable,
[
*load_builtin_challenges(),
*load_webarena_challenges(skip_unavailable=False),
],
)
challenges = sorted_by_enum_index(
challenges, DifficultyLevel, key=lambda c: c.info.difficulty
)
if only_names:
if output_json:
click.echo(json.dumps([c.info.name for c in challenges]))
return
for c in challenges:
click.echo(
click.style(c.info.name, fg=None if c.info.available else "black")
)
return
if output_json:
click.echo(json.dumps([json.loads(c.info.json()) for c in challenges]))
return
headers = tuple(
click.style(h, bold=True) for h in ("Name", "Difficulty", "Categories")
)
table = [
tuple(
v if challenge.info.available else click.style(v, fg="black")
for v in (
challenge.info.name,
(
click.style(
challenge.info.difficulty.value,
fg=DIFFICULTY_COLORS[challenge.info.difficulty],
)
if challenge.info.difficulty
else click.style("-", fg="black")
),
" ".join(
click.style(cat.value, fg=CATEGORY_COLORS[cat])
for cat in sorted_by_enum_index(challenge.info.category, Category)
),
)
)
for challenge in challenges
]
click.echo(tabulate(table, headers=headers))
@challenge.command()
@click.option("--json", is_flag=True)
@click.argument("name")
def info(name: str, json: bool):
from itertools import chain
from .challenges.builtin import load_builtin_challenges
from .challenges.webarena import load_webarena_challenges
from .utils.utils import pretty_print_model
for challenge in chain(
load_builtin_challenges(),
load_webarena_challenges(skip_unavailable=False),
):
if challenge.info.name != name:
continue
if json:
click.echo(challenge.info.json())
break
pretty_print_model(challenge.info)
break
else:
click.echo(click.style(f"Unknown challenge '{name}'", fg="red"), err=True)
@cli.command()

View File

@ -29,8 +29,8 @@ STRING_DIFFICULTY_MAP = {e.value: DIFFICULTY_MAP[e] for e in DifficultyLevel}
class Category(str, Enum):
DATA = "data"
GENERALIST = "general"
DATA = "data"
CODING = "coding"
SCRAPE_SYNTHESIZE = "scrape_synthesize"
WEB = "web"

View File

@ -3,10 +3,13 @@ import json
import logging
import os
import re
from enum import Enum
from pathlib import Path
from typing import Any, Optional
from typing import Any, Callable, Iterable, Optional, TypeVar, overload
import click
from dotenv import load_dotenv
from pydantic import BaseModel
from agbenchmark.reports.processing.report_types import Test
from agbenchmark.utils.data_types import DIFFICULTY_MAP, DifficultyLevel
@ -17,6 +20,9 @@ AGENT_NAME = os.getenv("AGENT_NAME")
logger = logging.getLogger(__name__)
T = TypeVar("T")
E = TypeVar("E", bound=Enum)
def replace_backslash(value: Any) -> Any:
if isinstance(value, str):
@ -124,6 +130,42 @@ def write_pretty_json(data, json_file):
f.write("\n")
def pretty_print_model(model: BaseModel, include_header: bool = True) -> None:
indent = ""
if include_header:
# Try to find the ID and/or name attribute of the model
id, name = None, None
for attr, value in model.dict().items():
if attr == "id" or attr.endswith("_id"):
id = value
if attr.endswith("name"):
name = value
if id and name:
break
identifiers = [v for v in [name, id] if v]
click.echo(
f"{model.__repr_name__()}{repr(identifiers) if identifiers else ''}:"
)
indent = " " * 2
k_col_width = max(len(k) for k in model.dict().keys())
for k, v in model.dict().items():
v_fmt = repr(v)
if v is None or v == "":
v_fmt = click.style(v_fmt, fg="black")
elif type(v) is bool:
v_fmt = click.style(v_fmt, fg="green" if v else "red")
elif type(v) is str and "\n" in v:
v_fmt = f"\n{v}".replace(
"\n", f"\n{indent} {click.style('|', fg='black')} "
)
if isinstance(v, Enum):
v_fmt = click.style(v.value, fg="blue")
elif type(v) is list and len(v) > 0 and isinstance(v[0], Enum):
v_fmt = ", ".join(click.style(lv.value, fg="blue") for lv in v)
click.echo(f"{indent}{k: <{k_col_width}} = {v_fmt}")
def deep_sort(obj):
"""
Recursively sort the keys in JSON object
@ -133,3 +175,38 @@ def deep_sort(obj):
if isinstance(obj, list):
return [deep_sort(elem) for elem in obj]
return obj
@overload
def sorted_by_enum_index(
sortable: Iterable[E],
enum: type[E],
*,
reverse: bool = False,
) -> list[E]:
...
@overload
def sorted_by_enum_index(
sortable: Iterable[T],
enum: type[Enum],
*,
key: Callable[[T], Enum | None],
reverse: bool = False,
) -> list[T]:
...
def sorted_by_enum_index(
sortable: Iterable[T],
enum: type[Enum],
*,
key: Callable[[T], Enum | None] = lambda x: x, # type: ignore
reverse: bool = False,
) -> list[T]:
return sorted(
sortable,
key=lambda x: enum._member_names_.index(e.name) if (e := key(x)) else 420e3,
reverse=reverse,
)

16
benchmark/poetry.lock generated
View File

@ -2431,6 +2431,20 @@ anyio = ">=3.4.0,<5"
[package.extras]
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
[[package]]
name = "tabulate"
version = "0.9.0"
description = "Pretty-print tabular data"
optional = false
python-versions = ">=3.7"
files = [
{file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
{file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
]
[package.extras]
widechars = ["wcwidth"]
[[package]]
name = "toml"
version = "0.10.2"
@ -2760,4 +2774,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "d7893a88906b5a8eda566e13e6a9492d012c910ded0da1b1ef12b69a14f8e047"
content-hash = "6eefdbbefb500de627cac39eb6eb1fdcecab76dd4c3599cf08ef6dc647cf71c9"

View File

@ -34,6 +34,7 @@ toml = "^0.10.2"
httpx = "^0.24.0"
agent-protocol-client = "^1.1.0"
click-default-group = "^1.2.4"
tabulate = "^0.9.0"
[tool.poetry.group.dev.dependencies]
flake8 = "^3.9.2"