"""Capture and report three different conditions for boolean traps.

Currently detects three conditions:

* `BooleanTrapReason.TYPE_HINT`: boolean type hints on positional args in function defs.
* `BooleanTrapReason.DEFAULT_VALUE`: boolean default values for positional args in function defs.
* `BooleanTrapReason.FUNCTION_CALL`: positional boolean args in function calls.
"""

from __future__ import annotations

import ast
from enum import Enum
from typing import Any
from typing import Generator
from typing import List
from typing import Tuple
from typing import Type
from typing import Union

try:
    import importlib.metadata as importlib_metadata
except ImportError:
    import importlib_metadata  # type: ignore


class BooleanTrapReason(Enum):
    """Defines boolean trap detection reasons.

    Each member's value is a message template meant to be completed with
    ``str.format``: the two function-definition templates take ``func`` and
    ``arg``; the function-call template takes ``func`` and ``position``.
    """

    TYPE_HINT = (
        "FBT001 do not define boolean positional args."
        " Hint: in `def {func}(...)` , define `{arg}` as kw-only"
    )
    """Function definition contains boolean type hint in positional arg."""

    DEFAULT_VALUE = (
        "FBT002 do not set boolean defaults for positional args."
        " Hint: in `def {func}(...)`, define `{arg}` as kw-only"
    )
    """Function definition contains boolean default in positional arg."""

    FUNCTION_CALL = (
        "FBT003 do not use boolean positional args."
        " Hint: in `{func}(..)`, refactor positional arg #{position}"
        " to include its argument name"
    )
    """Function call uses boolean positional arg."""


# One registry entry per detection:
# (line, column, argument name or call position, function name, reason).
BooleanTrapRegistry = List[Tuple[int, int, Union[str, int], str, BooleanTrapReason]]
# Items produced by the extraction helpers: (ast node, line, column, argument name).
CandidateGen = Generator[Tuple[ast.expr, int, int, str], None, None]
# A reported lint error: (line, column, message, reporting plugin type).
LintError = Tuple[int, int, str, Type[Any]]
LintErrorGen = Generator[LintError, None, None]


def positional_hints(func_node: ast.FunctionDef) -> CandidateGen:
    """Extract type hints from positional arguments.

    Args:
        func_node: Function ast containing the arguments to search.

    Yields:
        ``(hint, lineno, col_offset, name)`` for each annotated argument that
        could be passed positionally, i.e. positional-only arguments followed
        by regular arguments. Arguments without an annotation are skipped.
    """
    for arg in func_node.args.posonlyargs + func_node.args.args:
        hint = arg.annotation
        if hint is not None:
            yield hint, arg.lineno, arg.col_offset, arg.arg


def is_boolean_typehint(hint: ast.expr) -> bool:
    """Detect whether a type hint corresponds to a boolean.

    Args:
        hint: Type hint ast.

    Returns:
        True if the type hint is the bare name ``bool`` or the string
        constant ``"bool"``; False for any other node.
    """
    # Plain annotation, e.g. ``flag: bool``.
    if isinstance(hint, ast.Name):
        return hint.id == "bool"
    # Quoted annotation, e.g. ``flag: "bool"``.
    return isinstance(hint, ast.Constant) and hint.value == "bool"


def default_values(func_node: ast.FunctionDef) -> CandidateGen:
    """Retrieve positional function args and their respective default values.

    Args are iterated in reversed order because sizes could differ:

        "defaults is a list of default values for arguments that can be passed
        positionally. If there are fewer defaults, they correspond to the last n
        arguments."

    Args:
        func_node: The function ast to parse its args.

    Yields:
        ``(default, lineno, col_offset, name)`` for each positional argument
        that has a default value, last argument first. The location is that
        of the default value node, not of the argument.

    See Also:
        The ``ast.arguments`` entry of the standard library ``ast`` module
        documentation.
    """
    # `defaults` is shared by positional-only and regular arguments, so both
    # lists must be paired with it; using `args.args` alone would drop (or
    # mismatch) defaults that belong to positional-only parameters.
    positional = func_node.args.posonlyargs + func_node.args.args
    arg_defaults = zip(reversed(positional), reversed(func_node.args.defaults))
    for arg, default in arg_defaults:
        yield default, default.lineno, default.col_offset, arg.arg


def is_boolean_default(default: ast.expr) -> bool:
    """Detect whether a default argument value is a boolean.

    Args:
        default: Default parameter value ast.

    Returns:
        True if the default value is a literal ``True`` or ``False``.
    """
    # The isinstance check on the value excludes ints such as 0 and 1, which
    # compare equal to booleans but are not of type bool.
    return isinstance(default, ast.Constant) and isinstance(default.value, bool)


class Visitor(ast.NodeVisitor):
    """Ast visitor to group violations in a registry.

    Currently visits only function definitions and calls, where boolean traps
    might be located. See `Visitor.visit_FunctionDef` and `Visitor.visit_Call`,
    respectively.
    """

    def __init__(self) -> None:
        """Initialize the node visitor with an empty registry."""
        self.registry: BooleanTrapRegistry = []

    def visit_FunctionDef(self, func_node: ast.FunctionDef) -> None:  # noqa: N802
        """Capture boolean traps in function definitions.

        Each detection is appended to the registry as
        ``(line, column, argument name, function name, reason)``.

        Args:
            func_node: the function ast to inspect.

        See Also:
            * `BooleanTrapReason.TYPE_HINT`: `positional_hints`, `is_boolean_typehint`
            * `BooleanTrapReason.DEFAULT_VALUE`: `default_values`, `is_boolean_default`
        """
        for hint, lineno, col, argname in positional_hints(func_node):
            if is_boolean_typehint(hint):
                self.registry.append(
                    (lineno, col, argname, func_node.name, BooleanTrapReason.TYPE_HINT)
                )
        for default, lineno, col, argname in default_values(func_node):
            if is_boolean_default(default):
                self.registry.append(
                    (
                        lineno,
                        col,
                        argname,
                        func_node.name,
                        BooleanTrapReason.DEFAULT_VALUE,
                    )
                )
        # Keep descending so nested definitions and calls are inspected too.
        self.generic_visit(func_node)

    def visit_Call(self, call: ast.Call) -> None:  # noqa: N802
        """Capture boolean traps in function calls.

        Only calls whose callee is a plain name are considered. Each detection
        is appended to the registry as
        ``(line, column, argument position, function name, reason)``, where the
        position is the zero-based index of the argument and the location is
        that of the call.

        Args:
            call: the call ast to inspect.

        See Also:
            `BooleanTrapReason.FUNCTION_CALL`
        """
        # The callee check does not depend on the argument, so do it once.
        if isinstance(call.func, ast.Name):
            for position, arg in enumerate(call.args):
                if isinstance(arg, ast.Constant) and isinstance(arg.value, bool):
                    self.registry.append(
                        (
                            call.lineno,
                            call.col_offset,
                            position,
                            call.func.id,
                            BooleanTrapReason.FUNCTION_CALL,
                        )
                    )
        # Keep descending so calls nested in the arguments are inspected too.
        self.generic_visit(call)


class Plugin:  # noqa: R0903
    """Flake8 plugin implementation.

    Just reports the violations stored in `Visitor.registry`.
    """

    name = __name__
    # Looked up once, when the class is created, from the installed
    # distribution metadata for this module's name.
    version = importlib_metadata.version(__name__)

    def __init__(self, tree: ast.AST) -> None:
        """Initialize flake8 plugin by storing the ast tree at instance level.

        Args:
            tree: the whole module ast tree to parse.
        """
        self._tree = tree

    def run(self) -> LintErrorGen:
        """Report lint errors as parsed from `Visitor`.

        Yields:
            Boolean trap detections as ``(line, column, message, plugin type)``,
            in the order the visitor registered them.
        """
        visitor = Visitor()
        visitor.visit(self._tree)
        for line, col, arg_or_pos, func, reason in visitor.registry:
            # The call template expects `position`; the definition templates
            # expect `arg`.
            if reason is BooleanTrapReason.FUNCTION_CALL:
                message = reason.value.format(position=arg_or_pos, func=func)
            else:
                message = reason.value.format(arg=arg_or_pos, func=func)
            yield line, col, message, type(self)