fix(agent): Fix and windows-proof `scan_plugins`

- Improve error output for failure to load plugin
- Fix logic to determine qualified module name
- Use `importlib` rather than `__import__` magic function

This unbreaks `scan_plugins` on Windows.
pull/7033/head
Reinier van der Leer 2024-03-20 16:36:43 +01:00
parent 03ffb50dcf
commit a7c0440e9b
No known key found for this signature in database
GPG Key ID: CDC1180FDAE06193
1 changed files with 11 additions and 12 deletions

View File

@ -6,7 +6,6 @@ import inspect
import json import json
import logging import logging
import os import os
import sys
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
@ -220,21 +219,21 @@ def scan_plugins(config: Config) -> List[AutoGPTPluginTemplate]:
plugins_config = config.plugins_config plugins_config = config.plugins_config
# Directory-based plugins # Directory-based plugins
for plugin_path in [f.path for f in os.scandir(config.plugins_dir) if f.is_dir()]: for plugin_path in [f for f in Path(config.plugins_dir).iterdir() if f.is_dir()]:
# Avoid going into __pycache__ or other hidden directories # Avoid going into __pycache__ or other hidden directories
if plugin_path.startswith("__"): if plugin_path.name.startswith("__"):
continue continue
plugin_module_path = plugin_path.split(os.path.sep) plugin_module_name = plugin_path.name
plugin_module_name = plugin_module_path[-1] qualified_module_name = ".".join(plugin_path.parts)
qualified_module_name = ".".join(plugin_module_path)
try: try:
__import__(qualified_module_name) plugin = importlib.import_module(qualified_module_name)
except ImportError: except ImportError as e:
logger.error(f"Failed to load {qualified_module_name}") logger.error(
f"Failed to load {qualified_module_name} from {plugin_path}: {e}"
)
continue continue
plugin = sys.modules[qualified_module_name]
if not plugins_config.is_enabled(plugin_module_name): if not plugins_config.is_enabled(plugin_module_name):
logger.warning( logger.warning(
@ -261,8 +260,8 @@ def scan_plugins(config: Config) -> List[AutoGPTPluginTemplate]:
zipped_package = zipimporter(str(plugin)) zipped_package = zipimporter(str(plugin))
try: try:
zipped_module = zipped_package.load_module(str(module.parent)) zipped_module = zipped_package.load_module(str(module.parent))
except ZipImportError: except ZipImportError as e:
logger.error(f"Failed to load {str(module.parent)}") logger.error(f"Failed to load {module.parent} from {plugin}: {e}")
continue continue
for key in dir(zipped_module): for key in dir(zipped_module):