Skip to content

Commit

Permalink
Add UsePrimitiveTypes rule
Browse files Browse the repository at this point in the history
  • Loading branch information
Anuar Navarro Hawach authored and amyreese committed Nov 1, 2024
1 parent 8430224 commit c8b3bbd
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/fixit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _fixit_file_wrapper(
autofix: bool = False,
options: Optional[Options] = None,
metrics_hook: Optional[MetricsHook] = None,
) -> List[Result]:
) -> list[Result]:
"""
Wrapper because generators can't be pickled or used directly via multiprocessing
TODO: replace this with some sort of queue or whatever
Expand Down
2 changes: 1 addition & 1 deletion src/fixit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


def splash(
visited: Set[Path], dirty: Set[Path], autofixes: int = 0, fixed: int = 0
visited: set[Path], dirty: set[Path], autofixes: int = 0, fixed: int = 0
) -> None:
def f(v: int) -> str:
return "file" if v == 1 else "files"
Expand Down
12 changes: 6 additions & 6 deletions src/fixit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, msg: str, rule: QualifiedRule):
super().__init__(msg)
self.rule = rule

def __reduce__(self) -> Tuple[Type[RuntimeError], Any]:
def __reduce__(self) -> tuple[Type[RuntimeError], Any]:
return type(self), (*self.args, self.rule)


Expand Down Expand Up @@ -174,7 +174,7 @@ def find_rules(rule: QualifiedRule) -> Iterable[Type[LintRule]]:
raise CollectionError(f"could not import rule(s) {rule}", rule) from e


def walk_module(module: ModuleType) -> Dict[str, Type[LintRule]]:
def walk_module(module: ModuleType) -> dict[str, Type[LintRule]]:
"""
Given a module object, return a mapping of all rule names to classes.
Expand Down Expand Up @@ -272,7 +272,7 @@ def collect_rules(
return materialized_rules


def locate_configs(path: Path, root: Optional[Path] = None) -> List[Path]:
def locate_configs(path: Path, root: Optional[Path] = None) -> list[Path]:
"""
Given a file path, locate all relevant config files in priority order.
Expand Down Expand Up @@ -307,7 +307,7 @@ def locate_configs(path: Path, root: Optional[Path] = None) -> List[Path]:
return results


def read_configs(paths: List[Path]) -> List[RawConfig]:
def read_configs(paths: List[Path]) -> list[RawConfig]:
"""
Read config data for each path given, and return their raw toml config values.
Expand Down Expand Up @@ -400,7 +400,7 @@ def parse_rule(


def merge_configs(
path: Path, raw_configs: List[RawConfig], root: Optional[Path] = None
path: Path, raw_configs: list[RawConfig], root: Optional[Path] = None
) -> Config:
"""
Given multiple raw configs, merge them in priority order.
Expand Down Expand Up @@ -594,7 +594,7 @@ def generate_config(
return config


def validate_config(path: Path) -> List[str]:
def validate_config(path: Path) -> list[str]:
"""
Validate the config provided. The provided path is expected to be a valid toml
config file. Any exception found while parsing or importing will be added to a list
Expand Down
2 changes: 1 addition & 1 deletion src/fixit/rules/chained_instance_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def unwrap(self, node: cst.BaseExpression) -> Iterator[cst.BaseExpression]:

def collect_targets(
self, stack: Tuple[cst.BaseExpression, ...]
) -> Tuple[
) -> tuple[
List[cst.BaseExpression], Dict[cst.BaseExpression, List[cst.BaseExpression]]
]:
targets: Dict[cst.BaseExpression, List[cst.BaseExpression]] = {}
Expand Down
2 changes: 1 addition & 1 deletion src/fixit/rules/cls_in_classmethod.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

class _RenameTransformer(cst.CSTTransformer):
def __init__(
self, names: List[Union[cst.Name, cst.BaseString, cst.Attribute]], new_name: str
self, names: list[Union[cst.Name, cst.BaseString, cst.Attribute]], new_name: str
) -> None:
self.names = names
self.new_name = new_name
Expand Down
2 changes: 1 addition & 1 deletion src/fixit/rules/no_namedtuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def leave_ClassDef(self, original_node: cst.ClassDef) -> None:

def partition_bases(
self, original_bases: Sequence[cst.Arg]
) -> Tuple[Optional[cst.Arg], List[cst.Arg]]:
) -> tuple[Optional[cst.Arg], List[cst.Arg]]:
# Returns a tuple of NamedTuple base object if it exists, and a list of non-NamedTuple bases
namedtuple_base: Optional[cst.Arg] = None
new_bases: List[cst.Arg] = []
Expand Down
161 changes: 161 additions & 0 deletions src/fixit/rules/use_primitive_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set

import libcst

from fixit import Invalid, LintRule, Valid


REPLACE_TYPING_TYPE_ANNOTATION: str = (
"Use lowercase primitive type {primitive_type}"
+ "instead of {typing_type} (See [PEP 585 – Type Hinting Generics In Standard Collections](https://peps.python.org/pep-0585/#forward-compatibility))"
)

CUSTOM_TYPES_TO_REPLACE: Set[str] = {"Dict", "List", "Set", "Tuple"}


class UsePrimitiveTypes(LintRule):
"""
Enforces the use of primitive types instead of those in the ``typing`` module ()
since they are available on and ahead of Python ``3.10``.
"""

PYTHON_VERSION = ">= 3.10"

VALID = [
Valid(
"""
def foo() -> list:
pass
""",
),
Valid(
"""
def bar(x: set) -> None:
pass
""",
),
Valid(
"""
def baz(y: tuple) -> None:
pass
""",
),
Valid(
"""
def qux(z: dict) -> None:
pass
""",
),
]

INVALID = [
Invalid(
"""
def foo() -> List[int]:
pass
""",
expected_replacement="""
def foo() -> list[int]:
pass
""",
),
Invalid(
"""
def bar(x: Set[str]) -> None:
pass
""",
expected_replacement="""
def bar(x: set[str]) -> None:
pass
""",
),
Invalid(
"""
def baz(y: Tuple[int, str]) -> None:
pass
""",
expected_replacement="""
def baz(y: tuple[int, str]) -> None:
pass
""",
),
Invalid(
"""
def qux(z: Dict[str, int]) -> None:
pass
""",
expected_replacement="""
def qux(z: dict[str, int]) -> None:
pass
""",
),
]

def __init__(self) -> None:
super().__init__()
self.annotation_counter: int = 0

def visit_Annotation(self, node: libcst.Annotation) -> None:
self.annotation_counter += 1

def leave_Annotation(self, original_node: libcst.Annotation) -> None:
self.annotation_counter -= 1

def visit_FunctionDef(self, node: libcst.FunctionDef) -> None:
# Check return type
if isinstance(node.returns, libcst.Annotation):
if isinstance(node.returns.annotation, libcst.Subscript):
base_type = node.returns.annotation.value
if (
isinstance(base_type, libcst.Name)
and base_type.value in CUSTOM_TYPES_TO_REPLACE
):
new_base_type = base_type.with_changes(
value=base_type.value.lower()
)
new_annotation = node.returns.annotation.with_changes(
value=new_base_type
)
new_returns = node.returns.with_changes(annotation=new_annotation)
new_node = node.with_changes(returns=new_returns)
self.report(
node,
REPLACE_TYPING_TYPE_ANNOTATION.format(
primitive_type=base_type.value.lower(),
typing_type=base_type.value,
),
replacement=new_node,
)

# Check parameter types
for param in node.params.params:
if isinstance(param.annotation, libcst.Annotation):
if isinstance(param.annotation.annotation, libcst.Subscript):
base_type = param.annotation.annotation.value
if (
isinstance(base_type, libcst.Name)
and base_type.value in CUSTOM_TYPES_TO_REPLACE
):
new_base_type = base_type.with_changes(
value=base_type.value.lower()
)
new_annotation = param.annotation.annotation.with_changes(
value=new_base_type
)
new_param_annotation = param.annotation.with_changes(
annotation=new_annotation
)
new_param = param.with_changes(annotation=new_param_annotation)
self.report(
param,
REPLACE_TYPING_TYPE_ANNOTATION.format(
primitive_type=base_type.value.lower(),
typing_type=base_type.value,
),
replacement=new_param,
)
4 changes: 2 additions & 2 deletions src/fixit/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def gen_all_test_methods(rules: Collection[LintRule]) -> Sequence[TestCasePrecur

def generate_lint_rule_test_cases(
rules: Collection[LintRule],
) -> List[Type[unittest.TestCase]]:
) -> list[Type[unittest.TestCase]]:
test_case_classes: List[Type[unittest.TestCase]] = []
for test_case in gen_all_test_methods(rules):
rule_name = type(test_case.rule).__name__
Expand All @@ -191,7 +191,7 @@ def test_method(


def add_lint_rule_tests_to_module(
module_attrs: Dict[str, Any], rules: Collection[LintRule]
module_attrs: dict[str, Any], rules: Collection[LintRule]
) -> None:
"""
Generates classes inheriting from `unittest.TestCase` from the data available in `rules` and adds these to module_attrs.
Expand Down
2 changes: 1 addition & 1 deletion src/fixit/tests/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def test_collect_rules(self) -> None:
UseTypesFromTyping.TAGS = {"typing"}
NoNamedTuple.TAGS = {"typing", "tuples"}

def collect_types(cfg: Config) -> List[Type[LintRule]]:
def collect_types(cfg: Config) -> list[Type[LintRule]]:
return sorted([type(rule) for rule in config.collect_rules(cfg)], key=str)

with self.subTest("everything"):
Expand Down

0 comments on commit c8b3bbd

Please sign in to comment.