feat(benchmark/cli): Add `challenge list`, `challenge info` subcommands
- Add `challenge list` command with options `--all`, `--names`, `--json`
- Add `tabulate` dependency
- Add `.utils.utils.sorted_by_enum_index` function to easily sort lists by an enum value/property based on the order of the enum's definition
- Add `challenge info [name]` command with option `--json`
- Add `.utils.utils.pretty_print_model` routine to pretty-print Pydantic models
- Refactor `config` subcommand to use `pretty_print_model`
pull/6857/head
parent
70e345b2ce
commit
23d58a3cc0
|
@ -202,15 +202,136 @@ def serve(port: Optional[int] = None):
|
|||
@cli.command()
def config():
    """Displays info regarding the present AGBenchmark config."""
    from .utils.utils import pretty_print_model

    try:
        # Named `benchmark_config` so the local does not shadow this command.
        benchmark_config = AgentBenchmarkConfig.load()
    except FileNotFoundError as e:
        # Report the problem on stderr and stop without printing a config.
        click.echo(e, err=True)
        return 1

    # The model printer is the single place that renders the settings; the
    # previous hand-rolled `key = value` loop is gone so that every setting is
    # printed exactly once. No header line is printed for the config.
    pretty_print_model(benchmark_config, include_header=False)
|
||||
|
||||
|
||||
@cli.group()
def challenge():
    """Commands for listing and inspecting challenges."""
    # Set the root logger's level to WARNING whenever this group's callback
    # runs (presumably to keep the subcommands' output free of log noise).
    logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@challenge.command("list")
@click.option(
    "--all", "include_unavailable", is_flag=True, help="Include unavailable challenges."
)
@click.option(
    "--names", "only_names", is_flag=True, help="List only the challenge names."
)
@click.option("--json", "output_json", is_flag=True)
def list_challenges(include_unavailable: bool, only_names: bool, output_json: bool):
    """Lists [available|all] challenges."""
    import json

    from tabulate import tabulate

    from .challenges.builtin import load_builtin_challenges
    from .challenges.webarena import load_webarena_challenges
    from .utils.data_types import Category, DifficultyLevel
    from .utils.utils import sorted_by_enum_index

    # Pair every enum member with a colour name, following definition order.
    difficulty_palette = [
        "black",
        "blue",
        "cyan",
        "green",
        "yellow",
        "red",
        "magenta",
        "white",
    ]
    category_palette = [
        "blue",
        "cyan",
        "green",
        "yellow",
        "magenta",
        "red",
        "white",
        "black",
    ]
    difficulty_colors = dict(zip(DifficultyLevel, difficulty_palette))
    category_colors = {
        category: f"bright_{color}"
        for category, color in zip(Category, category_palette)
    }

    # Collect built-in and WebArena challenges, drop unavailable ones unless
    # `--all` was given, then order the remainder by difficulty.
    every_challenge = [
        *load_builtin_challenges(),
        *load_webarena_challenges(skip_unavailable=False),
    ]
    selected = [
        c for c in every_challenge if c.info.available or include_unavailable
    ]
    ordered = sorted_by_enum_index(
        selected, DifficultyLevel, key=lambda c: c.info.difficulty
    )

    # --names --json: a JSON array of names.
    if only_names and output_json:
        click.echo(json.dumps([c.info.name for c in ordered]))
        return

    # --names: one name per line, unavailable challenges in black.
    if only_names:
        for entry in ordered:
            name_color = None if entry.info.available else "black"
            click.echo(click.style(entry.info.name, fg=name_color))
        return

    # --json: a JSON array holding each challenge's info object.
    if output_json:
        click.echo(json.dumps([json.loads(c.info.json()) for c in ordered]))
        return

    def difficulty_cell(meta):
        # Coloured difficulty value, or a black dash when none is set.
        if meta.difficulty:
            return click.style(
                meta.difficulty.value, fg=difficulty_colors[meta.difficulty]
            )
        return click.style("-", fg="black")

    def categories_cell(meta):
        # Space-separated, individually coloured categories in enum order.
        return " ".join(
            click.style(cat.value, fg=category_colors[cat])
            for cat in sorted_by_enum_index(meta.category, Category)
        )

    def dim_if_unavailable(meta, cell):
        # Cells of unavailable challenges are additionally styled black.
        return cell if meta.available else click.style(cell, fg="black")

    # Default output: a table with one row per challenge.
    rows = []
    for entry in ordered:
        meta = entry.info
        cells = (meta.name, difficulty_cell(meta), categories_cell(meta))
        rows.append(tuple(dim_if_unavailable(meta, cell) for cell in cells))

    header_row = tuple(
        click.style(title, bold=True)
        for title in ("Name", "Difficulty", "Categories")
    )
    click.echo(tabulate(rows, headers=header_row))
|
||||
|
||||
|
||||
@challenge.command()
@click.option("--json", is_flag=True)
@click.argument("name")
def info(name: str, json: bool):
    """Shows the details of the challenge with the given name."""
    from itertools import chain

    from .challenges.builtin import load_builtin_challenges
    from .challenges.webarena import load_webarena_challenges
    from .utils.utils import pretty_print_model

    # Search built-in challenges first, then WebArena ones (including those
    # marked unavailable), stopping at the first exact name match.
    candidates = chain(
        load_builtin_challenges(),
        load_webarena_challenges(skip_unavailable=False),
    )
    found = next((c for c in candidates if c.info.name == name), None)

    if found is None:
        click.echo(click.style(f"Unknown challenge '{name}'", fg="red"), err=True)
        return

    # `json` here is the --json flag (a bool), not the `json` module.
    if json:
        click.echo(found.info.json())
    else:
        pretty_print_model(found.info)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
|
|
@ -29,8 +29,8 @@ STRING_DIFFICULTY_MAP = {e.value: DIFFICULTY_MAP[e] for e in DifficultyLevel}
|
|||
|
||||
|
||||
class Category(str, Enum):
|
||||
DATA = "data"
|
||||
GENERALIST = "general"
|
||||
DATA = "data"
|
||||
CODING = "coding"
|
||||
SCRAPE_SYNTHESIZE = "scrape_synthesize"
|
||||
WEB = "web"
|
||||
|
|
|
@ -3,10 +3,13 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Iterable, Optional, TypeVar, overload
|
||||
|
||||
import click
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agbenchmark.reports.processing.report_types import Test
|
||||
from agbenchmark.utils.data_types import DIFFICULTY_MAP, DifficultyLevel
|
||||
|
@ -17,6 +20,9 @@ AGENT_NAME = os.getenv("AGENT_NAME")
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Unconstrained type variable for generic signatures.
T = TypeVar("T")
|
||||
# Type variable restricted, via `bound`, to `Enum` and its subclasses.
E = TypeVar("E", bound=Enum)
|
||||
|
||||
|
||||
def replace_backslash(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
|
@ -124,6 +130,42 @@ def write_pretty_json(data, json_file):
|
|||
f.write("\n")
|
||||
|
||||
|
||||
def pretty_print_model(model: BaseModel, include_header: bool = True) -> None:
    """
    Print the fields of a Pydantic model as aligned `key = value` lines.

    Values are rendered with `repr()`, except that:
    - enum members (and lists made up solely of enum members) are shown by
      their `.value`, in blue;
    - `None` and empty strings are shown in black;
    - booleans are shown in green (True) or red (False);
    - multi-line strings are printed as a block below the key, each line
      prefixed with a `|` marker.

    Params:
        model: The model whose `.dict()` items are printed.
        include_header: If true, first print a `<ModelName>[identifiers]:` line
            and indent the fields below it by two spaces.
    """
    # Build the field mapping once instead of once per use.
    fields = model.dict()

    indent = ""
    if include_header:
        # Try to find the ID and/or name attribute of the model
        model_id, model_name = None, None
        for attr, value in fields.items():
            if attr == "id" or attr.endswith("_id"):
                model_id = value
            if attr.endswith("name"):
                model_name = value
            if model_id and model_name:
                break
        identifiers = [v for v in [model_name, model_id] if v]
        click.echo(
            f"{model.__repr_name__()}{repr(identifiers) if identifiers else ''}:"
        )
        indent = " " * 2

    # `default=0` keeps a model without any fields from raising ValueError.
    k_col_width = max((len(k) for k in fields), default=0)
    for k, v in fields.items():
        # Enum checks come first: they take precedence over every other rule.
        if isinstance(v, Enum):
            v_fmt = click.style(v.value, fg="blue")
        elif (
            isinstance(v, list)
            and len(v) > 0
            and all(isinstance(lv, Enum) for lv in v)
        ):
            # Every element is checked, so a mixed list falls back to repr()
            # below instead of failing on a missing `.value`.
            v_fmt = ", ".join(click.style(lv.value, fg="blue") for lv in v)
        elif v is None or v == "":
            v_fmt = click.style(repr(v), fg="black")
        elif isinstance(v, bool):
            v_fmt = click.style(repr(v), fg="green" if v else "red")
        elif isinstance(v, str) and "\n" in v:
            # Start the block on a fresh line and prefix each line of the text.
            v_fmt = f"\n{v}".replace(
                "\n", f"\n{indent} {click.style('|', fg='black')} "
            )
        else:
            v_fmt = repr(v)
        click.echo(f"{indent}{k: <{k_col_width}} = {v_fmt}")
|
||||
|
||||
|
||||
def deep_sort(obj):
|
||||
"""
|
||||
Recursively sort the keys in JSON object
|
||||
|
@ -133,3 +175,38 @@ def deep_sort(obj):
|
|||
if isinstance(obj, list):
|
||||
return [deep_sort(elem) for elem in obj]
|
||||
return obj
|
||||
|
||||
|
||||
@overload
def sorted_by_enum_index(
    sortable: Iterable[E],
    enum: type[E],
    *,
    reverse: bool = False,
) -> list[E]:
    ...


@overload
def sorted_by_enum_index(
    sortable: Iterable[T],
    enum: type[Enum],
    *,
    key: Callable[[T], Enum | None],
    reverse: bool = False,
) -> list[T]:
    ...


def sorted_by_enum_index(
    sortable: Iterable[T],
    enum: type[Enum],
    *,
    key: Callable[[T], Enum | None] = lambda x: x,  # type: ignore
    reverse: bool = False,
) -> list[T]:
    """
    Sort items by the definition order of the enum member associated with them.

    Params:
        sortable: The items to sort.
        enum: The enum class whose member definition order determines the
            sort order.
        key: Returns the enum member for an item, or `None` if it has none.
            Defaults to the item itself, i.e. the items are enum members.
        reverse: Passed through to `sorted()`.

    Returns:
        A new list. Items for which `key` returns `None` are ranked after all
        enum members (so they come first when `reverse=True`). The sort is
        stable, as with `sorted()`.

    Raises:
        ValueError: If `key` returns a member whose name is not defined on
            `enum`.
    """
    # Name -> position in definition order, built once for O(1) lookups.
    position = {name: index for index, name in enumerate(enum.__members__)}
    unranked = len(position)  # greater than every real position

    def enum_index(item: T) -> int:
        member = key(item)
        # Compare against None explicitly: members of mixed-in enums can be
        # falsy (e.g. an int-based member with value 0) and must still be
        # ranked by their position.
        if member is None:
            return unranked
        try:
            return position[member.name]
        except KeyError:
            raise ValueError(
                f"{member!r} is not a member of {enum.__name__}"
            ) from None

    return sorted(sortable, key=enum_index, reverse=reverse)
|
||||
|
|
|
@ -2431,6 +2431,20 @@ anyio = ">=3.4.0,<5"
|
|||
[package.extras]
|
||||
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
|
||||
|
||||
[[package]]
|
||||
name = "tabulate"
|
||||
version = "0.9.0"
|
||||
description = "Pretty-print tabular data"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
|
||||
{file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
widechars = ["wcwidth"]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.10.2"
|
||||
|
@ -2760,4 +2774,4 @@ multidict = ">=4.0"
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "d7893a88906b5a8eda566e13e6a9492d012c910ded0da1b1ef12b69a14f8e047"
|
||||
content-hash = "6eefdbbefb500de627cac39eb6eb1fdcecab76dd4c3599cf08ef6dc647cf71c9"
|
||||
|
|
|
@ -34,6 +34,7 @@ toml = "^0.10.2"
|
|||
httpx = "^0.24.0"
|
||||
agent-protocol-client = "^1.1.0"
|
||||
click-default-group = "^1.2.4"
|
||||
tabulate = "^0.9.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
flake8 = "^3.9.2"
|
||||
|
|
Loading…
Reference in New Issue