Fix plugin loading issues (#4888)

* Fix Config model initialization

* Fix basedir determination in install_plugin_dependencies

* Add logging to install_plugin_dependencies()

---------

Co-authored-by: collijk <collijk@uw.edu>
pull/4893/head^2
Reinier van der Leer 2023-07-06 01:05:07 +02:00 committed by GitHub
parent 0c8288b5e1
commit 9cf35010c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 68 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import contextlib import contextlib
import os import os
import re import re
from typing import Dict from typing import Dict, Union
import yaml import yaml
from colorama import Fore from colorama import Fore
@ -83,18 +83,6 @@ class Config(SystemSettings):
plugins: list[str] plugins: list[str]
authorise_key: str authorise_key: str
def __init__(self, **kwargs):
super().__init__(**kwargs)
# Hotfix: Call model_post_init explictly as it doesn't seem to be called for pydantic<2.0.0
# https://github.com/pydantic/pydantic/issues/1729#issuecomment-1300576214
self.model_post_init(**kwargs)
# Executed immediately after init by Pydantic
def model_post_init(self, **kwargs) -> None:
if not self.plugins_config.plugins:
self.plugins_config = PluginsConfig.load_config(self)
class ConfigBuilder(Configurable[Config]): class ConfigBuilder(Configurable[Config]):
default_plugins_config_file = os.path.join( default_plugins_config_file = os.path.join(
@ -213,21 +201,16 @@ class ConfigBuilder(Configurable[Config]):
"chat_messages_enabled": os.getenv("CHAT_MESSAGES_ENABLED") == "True", "chat_messages_enabled": os.getenv("CHAT_MESSAGES_ENABLED") == "True",
} }
# Converting to a list from comma-separated string config_dict["disabled_command_categories"] = _safe_split(
disabled_command_categories = os.getenv("DISABLED_COMMAND_CATEGORIES") os.getenv("DISABLED_COMMAND_CATEGORIES")
if disabled_command_categories: )
config_dict[
"disabled_command_categories"
] = disabled_command_categories.split(",")
# Converting to a list from comma-separated string config_dict["shell_denylist"] = _safe_split(
shell_denylist = os.getenv("SHELL_DENYLIST", os.getenv("DENY_COMMANDS")) os.getenv("SHELL_DENYLIST", os.getenv("DENY_COMMANDS"))
if shell_denylist: )
config_dict["shell_denylist"] = shell_denylist.split(",") config_dict["shell_allowlist"] = _safe_split(
os.getenv("SHELL_ALLOWLIST", os.getenv("ALLOW_COMMANDS"))
shell_allowlist = os.getenv("SHELL_ALLOWLIST", os.getenv("ALLOW_COMMANDS")) )
if shell_allowlist:
config_dict["shell_allowlist"] = shell_allowlist.split(",")
config_dict["google_custom_search_engine_id"] = os.getenv( config_dict["google_custom_search_engine_id"] = os.getenv(
"GOOGLE_CUSTOM_SEARCH_ENGINE_ID", os.getenv("CUSTOM_SEARCH_ENGINE_ID") "GOOGLE_CUSTOM_SEARCH_ENGINE_ID", os.getenv("CUSTOM_SEARCH_ENGINE_ID")
@ -237,13 +220,13 @@ class ConfigBuilder(Configurable[Config]):
"ELEVENLABS_VOICE_ID", os.getenv("ELEVENLABS_VOICE_1_ID") "ELEVENLABS_VOICE_ID", os.getenv("ELEVENLABS_VOICE_1_ID")
) )
plugins_allowlist = os.getenv("ALLOWLISTED_PLUGINS") config_dict["plugins_allowlist"] = _safe_split(os.getenv("ALLOWLISTED_PLUGINS"))
if plugins_allowlist: config_dict["plugins_denylist"] = _safe_split(os.getenv("DENYLISTED_PLUGINS"))
config_dict["plugins_allowlist"] = plugins_allowlist.split(",") config_dict["plugins_config"] = PluginsConfig.load_config(
config_dict["plugins_config_file"],
plugins_denylist = os.getenv("DENYLISTED_PLUGINS") config_dict["plugins_denylist"],
if plugins_denylist: config_dict["plugins_allowlist"],
config_dict["plugins_denylist"] = plugins_denylist.split(",") )
with contextlib.suppress(TypeError): with contextlib.suppress(TypeError):
config_dict["image_size"] = int(os.getenv("IMAGE_SIZE")) config_dict["image_size"] = int(os.getenv("IMAGE_SIZE"))
@ -325,3 +308,10 @@ def check_openai_api_key(config: Config) -> None:
else: else:
print("Invalid OpenAI API key!") print("Invalid OpenAI API key!")
exit(1) exit(1)
def _safe_split(s: Union[str, None], sep: str = ",") -> list[str]:
"""Split a string by a separator. Return an empty list if the string is None."""
if s is None:
return []
return s.split(sep)

View File

@ -1,13 +1,9 @@
from __future__ import annotations from __future__ import annotations
import os import os
from typing import TYPE_CHECKING, Union from typing import Union
import yaml import yaml
if TYPE_CHECKING:
from autogpt.config import Config
from pydantic import BaseModel from pydantic import BaseModel
from autogpt.logs import logger from autogpt.logs import logger
@ -30,11 +26,20 @@ class PluginsConfig(BaseModel):
return plugin_config is not None and plugin_config.enabled return plugin_config is not None and plugin_config.enabled
@classmethod @classmethod
def load_config(cls, global_config: Config) -> "PluginsConfig": def load_config(
cls,
plugins_config_file: str,
plugins_denylist: list[str],
plugins_allowlist: list[str],
) -> "PluginsConfig":
empty_config = cls(plugins={}) empty_config = cls(plugins={})
try: try:
config_data = cls.deserialize_config_file(global_config=global_config) config_data = cls.deserialize_config_file(
plugins_config_file,
plugins_denylist,
plugins_allowlist,
)
if type(config_data) != dict: if type(config_data) != dict:
logger.error( logger.error(
f"Expected plugins config to be a dict, got {type(config_data)}, continuing without plugins" f"Expected plugins config to be a dict, got {type(config_data)}, continuing without plugins"
@ -49,13 +54,21 @@ class PluginsConfig(BaseModel):
return empty_config return empty_config
@classmethod @classmethod
def deserialize_config_file(cls, global_config: Config) -> dict[str, PluginConfig]: def deserialize_config_file(
plugins_config_path = global_config.plugins_config_file cls,
if not os.path.exists(plugins_config_path): plugins_config_file: str,
plugins_denylist: list[str],
plugins_allowlist: list[str],
) -> dict[str, PluginConfig]:
if not os.path.exists(plugins_config_file):
logger.warn("plugins_config.yaml does not exist, creating base config.") logger.warn("plugins_config.yaml does not exist, creating base config.")
cls.create_empty_plugins_config(global_config=global_config) cls.create_empty_plugins_config(
plugins_config_file,
plugins_denylist,
plugins_allowlist,
)
with open(plugins_config_path, "r") as f: with open(plugins_config_file, "r") as f:
plugins_config = yaml.load(f, Loader=yaml.FullLoader) plugins_config = yaml.load(f, Loader=yaml.FullLoader)
plugins = {} plugins = {}
@ -73,23 +86,27 @@ class PluginsConfig(BaseModel):
return plugins return plugins
@staticmethod @staticmethod
def create_empty_plugins_config(global_config: Config): def create_empty_plugins_config(
plugins_config_file: str,
plugins_denylist: list[str],
plugins_allowlist: list[str],
):
"""Create an empty plugins_config.yaml file. Fill it with values from old env variables.""" """Create an empty plugins_config.yaml file. Fill it with values from old env variables."""
base_config = {} base_config = {}
logger.debug(f"Legacy plugin denylist: {global_config.plugins_denylist}") logger.debug(f"Legacy plugin denylist: {plugins_denylist}")
logger.debug(f"Legacy plugin allowlist: {global_config.plugins_allowlist}") logger.debug(f"Legacy plugin allowlist: {plugins_allowlist}")
# Backwards-compatibility shim # Backwards-compatibility shim
for plugin_name in global_config.plugins_denylist: for plugin_name in plugins_denylist:
base_config[plugin_name] = {"enabled": False, "config": {}} base_config[plugin_name] = {"enabled": False, "config": {}}
for plugin_name in global_config.plugins_allowlist: for plugin_name in plugins_allowlist:
base_config[plugin_name] = {"enabled": True, "config": {}} base_config[plugin_name] = {"enabled": True, "config": {}}
logger.debug(f"Constructed base plugins config: {base_config}") logger.debug(f"Constructed base plugins config: {base_config}")
logger.debug(f"Creating plugin config file {global_config.plugins_config_file}") logger.debug(f"Creating plugin config file {plugins_config_file}")
with open(global_config.plugins_config_file, "w+") as f: with open(plugins_config_file, "w+") as f:
f.write(yaml.dump(base_config)) f.write(yaml.dump(base_config))
return base_config return base_config

View File

@ -5,6 +5,8 @@ import zipfile
from glob import glob from glob import glob
from pathlib import Path from pathlib import Path
from autogpt.logs import logger
def install_plugin_dependencies(): def install_plugin_dependencies():
""" """
@ -18,28 +20,46 @@ def install_plugin_dependencies():
""" """
plugins_dir = Path(os.getenv("PLUGINS_DIR", "plugins")) plugins_dir = Path(os.getenv("PLUGINS_DIR", "plugins"))
logger.debug(f"Checking for dependencies in zipped plugins...")
# Install zip-based plugins # Install zip-based plugins
for plugin in plugins_dir.glob("*.zip"): for plugin_archive in plugins_dir.glob("*.zip"):
with zipfile.ZipFile(str(plugin), "r") as zfile: logger.debug(f"Checking for requirements in '{plugin_archive}'...")
try: with zipfile.ZipFile(str(plugin_archive), "r") as zfile:
basedir = zfile.namelist()[0] if not zfile.namelist():
basereqs = os.path.join(basedir, "requirements.txt")
extracted = zfile.extract(basereqs, path=plugins_dir)
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-r", extracted]
)
os.remove(extracted)
os.rmdir(os.path.join(plugins_dir, basedir))
except KeyError:
continue continue
# Assume the first entry in the list will be (in) the lowest common dir
first_entry = zfile.namelist()[0]
basedir = first_entry.rsplit("/", 1)[0] if "/" in first_entry else ""
logger.debug(f"Looking for requirements.txt in '{basedir}'")
basereqs = os.path.join(basedir, "requirements.txt")
try:
extracted = zfile.extract(basereqs, path=plugins_dir)
except KeyError as e:
logger.debug(e.args[0])
continue
logger.debug(f"Installing dependencies from '{basereqs}'...")
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-r", extracted]
)
os.remove(extracted)
os.rmdir(os.path.join(plugins_dir, basedir))
logger.debug(f"Checking for dependencies in other plugin folders...")
# Install directory-based plugins # Install directory-based plugins
for requirements_file in glob(f"{plugins_dir}/*/requirements.txt"): for requirements_file in glob(f"{plugins_dir}/*/requirements.txt"):
logger.debug(f"Installing dependencies from '{requirements_file}'...")
subprocess.check_call( subprocess.check_call(
[sys.executable, "-m", "pip", "install", "-r", requirements_file], [sys.executable, "-m", "pip", "install", "-r", requirements_file],
stdout=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
) )
logger.debug("Finished installing plugin dependencies")
if __name__ == "__main__": if __name__ == "__main__":
install_plugin_dependencies() install_plugin_dependencies()

View File

@ -59,7 +59,11 @@ def config(
# avoid circular dependency # avoid circular dependency
from autogpt.plugins.plugins_config import PluginsConfig from autogpt.plugins.plugins_config import PluginsConfig
config.plugins_config = PluginsConfig.load_config(global_config=config) config.plugins_config = PluginsConfig.load_config(
plugins_config_file=config.plugins_config_file,
plugins_denylist=config.plugins_denylist,
plugins_allowlist=config.plugins_allowlist,
)
# Do a little setup and teardown since the config object is a singleton # Do a little setup and teardown since the config object is a singleton
mocker.patch.multiple( mocker.patch.multiple(

View File

@ -70,7 +70,11 @@ def test_create_base_config(config: Config):
config.plugins_denylist = ["c", "d"] config.plugins_denylist = ["c", "d"]
os.remove(config.plugins_config_file) os.remove(config.plugins_config_file)
plugins_config = PluginsConfig.load_config(global_config=config) plugins_config = PluginsConfig.load_config(
plugins_config_file=config.plugins_config_file,
plugins_denylist=config.plugins_denylist,
plugins_allowlist=config.plugins_allowlist,
)
# Check the structure of the plugins config data # Check the structure of the plugins config data
assert len(plugins_config.plugins) == 4 assert len(plugins_config.plugins) == 4
@ -102,7 +106,11 @@ def test_load_config(config: Config):
f.write(yaml.dump(test_config)) f.write(yaml.dump(test_config))
# Load the config from disk # Load the config from disk
plugins_config = PluginsConfig.load_config(global_config=config) plugins_config = PluginsConfig.load_config(
plugins_config_file=config.plugins_config_file,
plugins_denylist=config.plugins_denylist,
plugins_allowlist=config.plugins_allowlist,
)
# Check that the loaded config is equal to the test config # Check that the loaded config is equal to the test config
assert len(plugins_config.plugins) == 2 assert len(plugins_config.plugins) == 2