Handle requirements for scripts (#2765)
parent
a03691455b
commit
aadf6a7750
|
@ -232,7 +232,7 @@ def from_config_dict(config: Dict[str, Any],
|
|||
if config_dir is not None:
|
||||
config_dir = os.path.abspath(config_dir)
|
||||
hass.config.config_dir = config_dir
|
||||
_mount_local_lib_path(config_dir)
|
||||
mount_local_lib_path(config_dir)
|
||||
|
||||
core_config = config.get(core.DOMAIN, {})
|
||||
|
||||
|
@ -300,7 +300,7 @@ def from_config_file(config_path: str,
|
|||
# Set config dir to directory holding config file
|
||||
config_dir = os.path.abspath(os.path.dirname(config_path))
|
||||
hass.config.config_dir = config_dir
|
||||
_mount_local_lib_path(config_dir)
|
||||
mount_local_lib_path(config_dir)
|
||||
|
||||
enable_logging(hass, verbose, log_rotate_days)
|
||||
|
||||
|
@ -371,11 +371,6 @@ def _ensure_loader_prepared(hass: core.HomeAssistant) -> None:
|
|||
loader.prepare(hass)
|
||||
|
||||
|
||||
def _mount_local_lib_path(config_dir: str) -> None:
|
||||
"""Add local library to Python Path."""
|
||||
sys.path.insert(0, os.path.join(config_dir, 'deps'))
|
||||
|
||||
|
||||
def _log_exception(ex, domain, config):
|
||||
"""Generate log exception for config validation."""
|
||||
message = 'Invalid config for [{}]: '.format(domain)
|
||||
|
@ -391,3 +386,11 @@ def _log_exception(ex, domain, config):
|
|||
config.__line__ or '?')
|
||||
|
||||
_LOGGER.error(message)
|
||||
|
||||
|
||||
def mount_local_lib_path(config_dir: str) -> str:
|
||||
"""Add local library to Python Path."""
|
||||
deps_dir = os.path.join(config_dir, 'deps')
|
||||
if deps_dir not in sys.path:
|
||||
sys.path.insert(0, os.path.join(config_dir, 'deps'))
|
||||
return deps_dir
|
||||
|
|
|
@ -1,9 +1,15 @@
|
|||
"""Home Assistant command line scripts."""
|
||||
import argparse
|
||||
import importlib
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from homeassistant.config import get_default_config_dir
|
||||
from homeassistant.util.package import install_package
|
||||
from homeassistant.bootstrap import mount_local_lib_path
|
||||
|
||||
|
||||
def run(args: str) -> int:
|
||||
def run(args: List) -> int:
|
||||
"""Run a script."""
|
||||
scripts = []
|
||||
path = os.path.dirname(__file__)
|
||||
|
@ -26,4 +32,21 @@ def run(args: str) -> int:
|
|||
return 1
|
||||
|
||||
script = importlib.import_module('homeassistant.scripts.' + args[0])
|
||||
|
||||
config_dir = extract_config_dir()
|
||||
deps_dir = mount_local_lib_path(config_dir)
|
||||
for req in getattr(script, 'REQUIREMENTS', []):
|
||||
if not install_package(req, target=deps_dir):
|
||||
print('Aborting scipt, could not install dependency', req)
|
||||
return 1
|
||||
|
||||
return script.run(args[1:]) # type: ignore
|
||||
|
||||
|
||||
def extract_config_dir(args=None) -> str:
|
||||
"""Extract the config dir from the arguments or get the default."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-c', '--config', default=None)
|
||||
args = parser.parse_known_args(args)[0]
|
||||
return (os.path.join(os.getcwd(), args.config) if args.config
|
||||
else get_default_config_dir())
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
"""Tests for the scripts."""
|
|
@ -0,0 +1,19 @@
|
|||
"""Test script init."""
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import homeassistant.scripts as scripts
|
||||
|
||||
|
||||
class TestScripts(unittest.TestCase):
|
||||
"""Tests homeassistant.scripts module."""
|
||||
|
||||
@patch('homeassistant.scripts.get_default_config_dir',
|
||||
return_value='/default')
|
||||
def test_config_per_platform(self, mock_def):
|
||||
"""Test config per platform method."""
|
||||
self.assertEquals(scripts.get_default_config_dir(), '/default')
|
||||
self.assertEqual(scripts.extract_config_dir(), '/default')
|
||||
self.assertEqual(scripts.extract_config_dir(['']), '/default')
|
||||
self.assertEqual(scripts.extract_config_dir(['-c', '/arg']), '/arg')
|
||||
self.assertEqual(scripts.extract_config_dir(['--config', '/a']), '/a')
|
|
@ -3,7 +3,7 @@ import os
|
|||
import tempfile
|
||||
import unittest
|
||||
|
||||
import homeassistant.bootstrap as bootstrap
|
||||
from homeassistant.bootstrap import mount_local_lib_path
|
||||
import homeassistant.util.package as package
|
||||
|
||||
RESOURCE_DIR = os.path.abspath(
|
||||
|
@ -21,7 +21,7 @@ class TestPackageUtil(unittest.TestCase):
|
|||
def setUp(self):
|
||||
"""Create local library for testing."""
|
||||
self.tmp_dir = tempfile.TemporaryDirectory()
|
||||
self.lib_dir = os.path.join(self.tmp_dir.name, 'deps')
|
||||
self.lib_dir = mount_local_lib_path(self.tmp_dir.name)
|
||||
|
||||
def tearDown(self):
|
||||
"""Stop everything that was started."""
|
||||
|
@ -49,8 +49,6 @@ class TestPackageUtil(unittest.TestCase):
|
|||
self.assertTrue(package.check_package_exists(
|
||||
TEST_NEW_REQ, self.lib_dir))
|
||||
|
||||
bootstrap._mount_local_lib_path(self.tmp_dir.name)
|
||||
|
||||
try:
|
||||
import pyhelloworld3
|
||||
except ImportError:
|
||||
|
|
Loading…
Reference in New Issue