Handle requirements for scripts (#2765)

pull/2779/head
Johann Kellerman 2016-08-10 08:54:34 +02:00 committed by Paulus Schoutsen
parent a03691455b
commit aadf6a7750
5 changed files with 56 additions and 12 deletions

View File

@ -232,7 +232,7 @@ def from_config_dict(config: Dict[str, Any],
if config_dir is not None:
config_dir = os.path.abspath(config_dir)
hass.config.config_dir = config_dir
_mount_local_lib_path(config_dir)
mount_local_lib_path(config_dir)
core_config = config.get(core.DOMAIN, {})
@ -300,7 +300,7 @@ def from_config_file(config_path: str,
# Set config dir to directory holding config file
config_dir = os.path.abspath(os.path.dirname(config_path))
hass.config.config_dir = config_dir
_mount_local_lib_path(config_dir)
mount_local_lib_path(config_dir)
enable_logging(hass, verbose, log_rotate_days)
@ -371,11 +371,6 @@ def _ensure_loader_prepared(hass: core.HomeAssistant) -> None:
loader.prepare(hass)
def _mount_local_lib_path(config_dir: str) -> None:
"""Add local library to Python Path."""
sys.path.insert(0, os.path.join(config_dir, 'deps'))
def _log_exception(ex, domain, config):
"""Generate log exception for config validation."""
message = 'Invalid config for [{}]: '.format(domain)
@ -391,3 +386,11 @@ def _log_exception(ex, domain, config):
config.__line__ or '?')
_LOGGER.error(message)
def mount_local_lib_path(config_dir: str) -> str:
"""Add local library to Python Path."""
deps_dir = os.path.join(config_dir, 'deps')
if deps_dir not in sys.path:
sys.path.insert(0, os.path.join(config_dir, 'deps'))
return deps_dir

View File

@ -1,9 +1,15 @@
"""Home Assistant command line scripts."""
import argparse
import importlib
import os
from typing import List
from homeassistant.config import get_default_config_dir
from homeassistant.util.package import install_package
from homeassistant.bootstrap import mount_local_lib_path
def run(args: str) -> int:
def run(args: List) -> int:
"""Run a script."""
scripts = []
path = os.path.dirname(__file__)
@ -26,4 +32,21 @@ def run(args: str) -> int:
return 1
script = importlib.import_module('homeassistant.scripts.' + args[0])
config_dir = extract_config_dir()
deps_dir = mount_local_lib_path(config_dir)
for req in getattr(script, 'REQUIREMENTS', []):
if not install_package(req, target=deps_dir):
print('Aborting scipt, could not install dependency', req)
return 1
return script.run(args[1:]) # type: ignore
def extract_config_dir(args=None) -> str:
"""Extract the config dir from the arguments or get the default."""
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', default=None)
args = parser.parse_known_args(args)[0]
return (os.path.join(os.getcwd(), args.config) if args.config
else get_default_config_dir())

View File

@ -0,0 +1 @@
"""Tests for the scripts."""

View File

@ -0,0 +1,19 @@
"""Test script init."""
import unittest
from unittest.mock import patch
import homeassistant.scripts as scripts
class TestScripts(unittest.TestCase):
"""Tests homeassistant.scripts module."""
@patch('homeassistant.scripts.get_default_config_dir',
return_value='/default')
def test_config_per_platform(self, mock_def):
"""Test config per platform method."""
self.assertEquals(scripts.get_default_config_dir(), '/default')
self.assertEqual(scripts.extract_config_dir(), '/default')
self.assertEqual(scripts.extract_config_dir(['']), '/default')
self.assertEqual(scripts.extract_config_dir(['-c', '/arg']), '/arg')
self.assertEqual(scripts.extract_config_dir(['--config', '/a']), '/a')

View File

@ -3,7 +3,7 @@ import os
import tempfile
import unittest
import homeassistant.bootstrap as bootstrap
from homeassistant.bootstrap import mount_local_lib_path
import homeassistant.util.package as package
RESOURCE_DIR = os.path.abspath(
@ -21,7 +21,7 @@ class TestPackageUtil(unittest.TestCase):
def setUp(self):
"""Create local library for testing."""
self.tmp_dir = tempfile.TemporaryDirectory()
self.lib_dir = os.path.join(self.tmp_dir.name, 'deps')
self.lib_dir = mount_local_lib_path(self.tmp_dir.name)
def tearDown(self):
"""Stop everything that was started."""
@ -49,8 +49,6 @@ class TestPackageUtil(unittest.TestCase):
self.assertTrue(package.check_package_exists(
TEST_NEW_REQ, self.lib_dir))
bootstrap._mount_local_lib_path(self.tmp_dir.name)
try:
import pyhelloworld3
except ImportError: