AutoGPT: Fix decorator typings

pull/5612/head
Reinier van der Leer 2023-10-17 17:20:21 -07:00
parent 12d959f780
commit 6a05e11239
No known key found for this signature in database
GPG Key ID: CDC1180FDAE06193
3 changed files with 19 additions and 12 deletions

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import functools
import inspect
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, ParamSpec, TypeVar
if TYPE_CHECKING:
from autogpt.agents.base import BaseAgent
@ -14,6 +14,9 @@ from autogpt.models.command import Command, CommandOutput, CommandParameter
# Unique identifier for AutoGPT commands
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
P = ParamSpec("P")
CO = TypeVar("CO", bound=CommandOutput)
def command(
name: str,
@ -23,10 +26,10 @@ def command(
disabled_reason: Optional[str] = None,
aliases: list[str] = [],
available: Literal[True] | Callable[[BaseAgent], bool] = True,
) -> Callable[..., CommandOutput]:
) -> Callable[[Callable[P, CO]], Callable[P, CO]]:
"""The command decorator is used to create Command objects from ordinary functions."""
def decorator(func: Callable[..., CommandOutput]):
def decorator(func: Callable[P, CO]) -> Callable[P, CO]:
typed_parameters = [
CommandParameter(
name=param_name,
@ -48,13 +51,13 @@ def command(
if inspect.iscoroutinefunction(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs) -> Any:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
return await func(*args, **kwargs)
else:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
return func(*args, **kwargs)
setattr(wrapper, "command", cmd)

View File

@ -2,16 +2,19 @@ import functools
import logging
import re
from pathlib import Path
from typing import Callable
from typing import Callable, ParamSpec, TypeVar
from autogpt.agents.agent import Agent
P = ParamSpec("P")
T = TypeVar("T")
logger = logging.getLogger(__name__)
def sanitize_path_arg(
arg_name: str, make_relative: bool = False
) -> Callable[[Callable], Callable]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Sanitizes the specified path (str | Path) argument, resolving it to a Path"""
def decorator(func: Callable) -> Callable:
@ -32,7 +35,7 @@ def sanitize_path_arg(
)
@functools.wraps(func)
def wrapper(*args, **kwargs): # type: ignore
def wrapper(*args, **kwargs):
logger.debug(f"Sanitizing arg '{arg_name}' on function '{func.__name__}'")
# Get Agent from the called function's arguments
@ -47,7 +50,7 @@ def sanitize_path_arg(
arg_name, len(args) > arg_index and args[arg_index] or None
)
if given_path:
if type(given_path) == str:
if type(given_path) is str:
# Fix workspace path from output in docker environment
given_path = re.sub(r"^\/workspace", ".", given_path)

View File

@ -1,12 +1,13 @@
import functools
import re
from typing import Any, Callable
from typing import Any, Callable, ParamSpec, TypeVar
from urllib.parse import urljoin, urlparse
from requests.compat import urljoin
P = ParamSpec("P")
T = TypeVar("T")
def validate_url(func: Callable[..., Any]) -> Any:
def validate_url(func: Callable[P, T]) -> Callable[P, T]:
"""The method decorator validate_url is used to validate urls for any command that requires
a url as an argument"""