Add type hints to requirements script (#82075)

pull/82186/head
epenet 2022-11-16 13:00:35 +01:00 committed by GitHub
parent 1582d88957
commit 0538154767
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 44 additions and 26 deletions

View File

@ -1,5 +1,7 @@
#!/usr/bin/env python3
"""Generate an updated requirements_all.txt."""
"""Generate updated constraint and requirements files."""
from __future__ import annotations
import difflib
import importlib
import os
@ -7,6 +9,7 @@ from pathlib import Path
import pkgutil
import re
import sys
from typing import Any
from homeassistant.util.yaml.loader import load_yaml
from script.hassfest.model import Integration
@ -157,7 +160,7 @@ IGNORE_PRE_COMMIT_HOOK_ID = (
PACKAGE_REGEX = re.compile(r"^(?:--.+\s)?([-_\.\w\d]+).*==.+$")
def has_tests(module: str):
def has_tests(module: str) -> bool:
"""Test if a module has tests.
Module format: homeassistant.components.hue
@ -169,11 +172,11 @@ def has_tests(module: str):
return path.exists()
def explore_module(package, explore_children):
def explore_module(package: str, explore_children: bool) -> list[str]:
"""Explore the modules."""
module = importlib.import_module(package)
found = []
found: list[str] = []
if not hasattr(module, "__path__"):
return found
@ -187,14 +190,17 @@ def explore_module(package, explore_children):
return found
def core_requirements():
def core_requirements() -> list[str]:
"""Gather core requirements out of pyproject.toml."""
with open("pyproject.toml", "rb") as fp:
data = tomllib.load(fp)
return data["project"]["dependencies"]
dependencies: list[str] = data["project"]["dependencies"]
return dependencies
def gather_recursive_requirements(domain, seen=None):
def gather_recursive_requirements(
domain: str, seen: set[str] | None = None
) -> set[str]:
"""Recursively gather requirements from a module."""
if seen is None:
seen = set()
@ -221,18 +227,18 @@ def normalize_package_name(requirement: str) -> str:
return package
def comment_requirement(req):
def comment_requirement(req: str) -> bool:
"""Comment out requirement. Some don't install on all systems."""
return any(
normalize_package_name(req) == ign for ign in COMMENT_REQUIREMENTS_NORMALIZED
)
def gather_modules():
def gather_modules() -> dict[str, list[str]] | None:
"""Collect the information."""
reqs = {}
reqs: dict[str, list[str]] = {}
errors = []
errors: list[str] = []
gather_requirements_from_manifests(errors, reqs)
gather_requirements_from_modules(errors, reqs)
@ -248,7 +254,9 @@ def gather_modules():
return reqs
def gather_requirements_from_manifests(errors, reqs):
def gather_requirements_from_manifests(
errors: list[str], reqs: dict[str, list[str]]
) -> None:
"""Gather all of the requirements from manifests."""
integrations = Integration.load_dir(Path("homeassistant/components"))
for domain in sorted(integrations):
@ -266,7 +274,9 @@ def gather_requirements_from_manifests(errors, reqs):
)
def gather_requirements_from_modules(errors, reqs):
def gather_requirements_from_modules(
errors: list[str], reqs: dict[str, list[str]]
) -> None:
"""Collect the requirements from the modules directly."""
for package in sorted(
explore_module("homeassistant.scripts", True)
@ -283,7 +293,12 @@ def gather_requirements_from_modules(errors, reqs):
process_requirements(errors, module.REQUIREMENTS, package, reqs)
def process_requirements(errors, module_requirements, package, reqs):
def process_requirements(
errors: list[str],
module_requirements: list[str],
package: str,
reqs: dict[str, list[str]],
) -> None:
"""Process all of the requirements."""
for req in module_requirements:
if "://" in req:
@ -293,7 +308,7 @@ def process_requirements(errors, module_requirements, package, reqs):
reqs.setdefault(req, []).append(package)
def generate_requirements_list(reqs):
def generate_requirements_list(reqs: dict[str, list[str]]) -> str:
"""Generate a pip file based on requirements."""
output = []
for pkg, requirements in sorted(reqs.items(), key=lambda item: item[0]):
@ -307,7 +322,7 @@ def generate_requirements_list(reqs):
return "".join(output)
def requirements_output(reqs):
def requirements_output() -> str:
"""Generate output for requirements."""
output = [
"-c homeassistant/package_constraints.txt\n",
@ -320,7 +335,7 @@ def requirements_output(reqs):
return "".join(output)
def requirements_all_output(reqs):
def requirements_all_output(reqs: dict[str, list[str]]) -> str:
"""Generate output for requirements_all."""
output = [
"# Home Assistant Core, full dependency set\n",
@ -331,7 +346,7 @@ def requirements_all_output(reqs):
return "".join(output)
def requirements_test_all_output(reqs):
def requirements_test_all_output(reqs: dict[str, list[str]]) -> str:
"""Generate output for test_requirements."""
output = [
"# Home Assistant tests, full dependency set\n",
@ -356,15 +371,18 @@ def requirements_test_all_output(reqs):
return "".join(output)
def requirements_pre_commit_output():
def requirements_pre_commit_output() -> str:
"""Generate output for pre-commit dependencies."""
source = ".pre-commit-config.yaml"
pre_commit_conf = load_yaml(source)
reqs = []
pre_commit_conf: dict[str, list[dict[str, Any]]]
pre_commit_conf = load_yaml(source) # type: ignore[assignment]
reqs: list[str] = []
hook: dict[str, Any]
for repo in (x for x in pre_commit_conf["repos"] if x.get("rev")):
rev: str = repo["rev"]
for hook in repo["hooks"]:
if hook["id"] not in IGNORE_PRE_COMMIT_HOOK_ID:
reqs.append(f"{hook['id']}=={repo['rev'].lstrip('v')}")
reqs.append(f"{hook['id']}=={rev.lstrip('v')}")
reqs.extend(x for x in hook.get("additional_dependencies", ()))
output = [
f"# Automatically generated "
@ -375,7 +393,7 @@ def requirements_pre_commit_output():
return "\n".join(output) + "\n"
def gather_constraints():
def gather_constraints() -> str:
"""Construct output for constraint file."""
return (
"\n".join(
@ -392,7 +410,7 @@ def gather_constraints():
)
def diff_file(filename, content):
def diff_file(filename: str, content: str) -> list[str]:
"""Diff a file."""
return list(
difflib.context_diff(
@ -404,7 +422,7 @@ def diff_file(filename, content):
)
def main(validate):
def main(validate: bool) -> int:
"""Run the script."""
if not os.path.isfile("requirements_all.txt"):
print("Run this from HA root dir")
@ -415,7 +433,7 @@ def main(validate):
if data is None:
return 1
reqs_file = requirements_output(data)
reqs_file = requirements_output()
reqs_all_file = requirements_all_output(data)
reqs_test_all_file = requirements_test_all_output(data)
reqs_pre_commit_file = requirements_pre_commit_output()