Add ability to mark type hints as compulsory on specific functions (#139730)

pull/139784/head
epenet 2025-05-19 10:27:07 +02:00 committed by GitHub
parent 5f2425f421
commit 2bb0843c30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 15 deletions

View File

@ -50,6 +50,9 @@ class TypeHintMatch:
kwargs_type: str | None = None kwargs_type: str | None = None
"""kwargs_type is for the special case `**kwargs`""" """kwargs_type is for the special case `**kwargs`"""
has_async_counterpart: bool = False has_async_counterpart: bool = False
"""`function_name` and `async_function_name` share arguments and return type"""
mandatory: bool = False
"""bypass ignore_missing_annotations"""
def need_to_check_function(self, node: nodes.FunctionDef) -> bool: def need_to_check_function(self, node: nodes.FunctionDef) -> bool:
"""Confirm if function should be checked.""" """Confirm if function should be checked."""
@ -184,6 +187,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
}, },
return_type="bool", return_type="bool",
has_async_counterpart=True, has_async_counterpart=True,
mandatory=True,
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_setup_entry", function_name="async_setup_entry",
@ -192,6 +196,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
1: "ConfigEntry", 1: "ConfigEntry",
}, },
return_type="bool", return_type="bool",
mandatory=True,
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_remove_entry", function_name="async_remove_entry",
@ -200,6 +205,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
1: "ConfigEntry", 1: "ConfigEntry",
}, },
return_type=None, return_type=None,
mandatory=True,
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_unload_entry", function_name="async_unload_entry",
@ -208,6 +214,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
1: "ConfigEntry", 1: "ConfigEntry",
}, },
return_type="bool", return_type="bool",
mandatory=True,
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_migrate_entry", function_name="async_migrate_entry",
@ -216,6 +223,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
1: "ConfigEntry", 1: "ConfigEntry",
}, },
return_type="bool", return_type="bool",
mandatory=True,
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_remove_config_entry_device", function_name="async_remove_config_entry_device",
@ -225,6 +233,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
2: "DeviceEntry", 2: "DeviceEntry",
}, },
return_type="bool", return_type="bool",
mandatory=True,
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_reset_platform", function_name="async_reset_platform",
@ -233,6 +242,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
1: "str", 1: "str",
}, },
return_type=None, return_type=None,
mandatory=True,
), ),
], ],
"__any_platform__": [ "__any_platform__": [
@ -246,6 +256,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
}, },
return_type=None, return_type=None,
has_async_counterpart=True, has_async_counterpart=True,
mandatory=True,
), ),
TypeHintMatch( TypeHintMatch(
function_name="async_setup_entry", function_name="async_setup_entry",
@ -255,6 +266,7 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
2: "AddConfigEntryEntitiesCallback", 2: "AddConfigEntryEntitiesCallback",
}, },
return_type=None, return_type=None,
mandatory=True,
), ),
], ],
"application_credentials": [ "application_credentials": [
@ -3195,8 +3207,11 @@ class HassTypeHintChecker(BaseChecker):
self._class_matchers.reverse() self._class_matchers.reverse()
def _ignore_function( def _ignore_function_match(
self, node: nodes.FunctionDef, annotations: list[nodes.NodeNG | None] self,
node: nodes.FunctionDef,
annotations: list[nodes.NodeNG | None],
match: TypeHintMatch,
) -> bool: ) -> bool:
"""Check if we can skip the function validation.""" """Check if we can skip the function validation."""
return ( return (
@ -3204,6 +3219,8 @@ class HassTypeHintChecker(BaseChecker):
not self._in_test_module not self._in_test_module
# some modules have checks forced # some modules have checks forced
and self._module_platform not in _FORCE_ANNOTATION_PLATFORMS and self._module_platform not in _FORCE_ANNOTATION_PLATFORMS
# some matches have checks forced
and not match.mandatory
# other modules are only checked ignore_missing_annotations # other modules are only checked ignore_missing_annotations
and self.linter.config.ignore_missing_annotations and self.linter.config.ignore_missing_annotations
and node.returns is None and node.returns is None
@ -3246,7 +3263,7 @@ class HassTypeHintChecker(BaseChecker):
continue continue
annotations = _get_all_annotations(function_node) annotations = _get_all_annotations(function_node)
if self._ignore_function(function_node, annotations): if self._ignore_function_match(function_node, annotations, match):
continue continue
self._check_function(function_node, match, annotations) self._check_function(function_node, match, annotations)
@ -3255,8 +3272,6 @@ class HassTypeHintChecker(BaseChecker):
def visit_functiondef(self, node: nodes.FunctionDef) -> None: def visit_functiondef(self, node: nodes.FunctionDef) -> None:
"""Apply relevant type hint checks on a FunctionDef node.""" """Apply relevant type hint checks on a FunctionDef node."""
annotations = _get_all_annotations(node) annotations = _get_all_annotations(node)
if self._ignore_function(node, annotations):
return
# Check method or function matchers. # Check method or function matchers.
if node.is_method(): if node.is_method():
@ -3277,14 +3292,15 @@ class HassTypeHintChecker(BaseChecker):
matchers = self._function_matchers matchers = self._function_matchers
# Check that common arguments are correctly typed. # Check that common arguments are correctly typed.
for arg_name, expected_type in _COMMON_ARGUMENTS.items(): if not self.linter.config.ignore_missing_annotations:
arg_node, annotation = _get_named_annotation(node, arg_name) for arg_name, expected_type in _COMMON_ARGUMENTS.items():
if arg_node and not _is_valid_type(expected_type, annotation): arg_node, annotation = _get_named_annotation(node, arg_name)
self.add_message( if arg_node and not _is_valid_type(expected_type, annotation):
"hass-argument-type", self.add_message(
node=arg_node, "hass-argument-type",
args=(arg_name, expected_type, node.name), node=arg_node,
) args=(arg_name, expected_type, node.name),
)
for match in matchers: for match in matchers:
if not match.need_to_check_function(node): if not match.need_to_check_function(node):
@ -3299,6 +3315,8 @@ class HassTypeHintChecker(BaseChecker):
match: TypeHintMatch, match: TypeHintMatch,
annotations: list[nodes.NodeNG | None], annotations: list[nodes.NodeNG | None],
) -> None: ) -> None:
if self._ignore_function_match(node, annotations, match):
return
# Check that all positional arguments are correctly annotated. # Check that all positional arguments are correctly annotated.
if match.arg_types: if match.arg_types:
for key, expected_type in match.arg_types.items(): for key, expected_type in match.arg_types.items():

View File

@ -99,7 +99,7 @@ def test_regex_a_or_b(
"code", "code",
[ [
""" """
async def setup( #@ async def async_turn_on( #@
arg1, arg2 arg1, arg2
): ):
pass pass
@ -115,7 +115,7 @@ def test_ignore_no_annotations(
func_node = astroid.extract_node( func_node = astroid.extract_node(
code, code,
"homeassistant.components.pylint_test", "homeassistant.components.pylint_test.light",
) )
type_hint_checker.visit_module(func_node.parent) type_hint_checker.visit_module(func_node.parent)