diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index d047f72b6..39754af70 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -203,7 +203,7 @@ class ApplyTypeAnnotationsVisitor(ContextAwareTransformer): This is one of the transforms that is available automatically to you when running a codemod. To use it in this manner, import :class:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor` and then call the static - :meth:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor.add_stub_to_context` method, + :meth:`~libcst.codemod.visitors.ApplyTypeAnnotationsVisitor.store_stub_in_context` method, giving it the current context (found as ``self.context`` for all subclasses of :class:`~libcst.codemod.Codemod`), the stub module from which you wish to add annotations. @@ -211,7 +211,7 @@ class ApplyTypeAnnotationsVisitor(ContextAwareTransformer): stub_module = parse_module("x: int = ...") - ApplyTypeAnnotationsVisitor.add_stub_to_context(self.context, stub_module) + ApplyTypeAnnotationsVisitor.store_stub_in_context(self.context, stub_module) You can apply the type annotation using:: @@ -223,12 +223,19 @@ class ApplyTypeAnnotationsVisitor(ContextAwareTransformer): x: int = 1 If the function or attribute already has a type annotation, it will not be overwritten. + + To overwrite existing annotations when applying annotations from a stub, + use the keyword argument ``overwrite_existing_annotations=True`` when + constructing the codemod or when calling ``store_stub_in_context``. """ CONTEXT_KEY = "ApplyTypeAnnotationsVisitor" def __init__( - self, context: CodemodContext, annotations: Optional[Annotations] = None + self, + context: CodemodContext, + annotations: Optional[Annotations] = None, + overwrite_existing_annotations: bool = False, ) -> None: super().__init__(context) # Qualifier for storing the canonical name of the current function. @@ -236,20 +243,32 @@ def __init__( self.annotations: Annotations = annotations or Annotations() self.toplevel_annotations: Dict[str, cst.Annotation] = {} self.visited_classes: Set[str] = set() + self.overwrite_existing_annotations = overwrite_existing_annotations # We use this to determine the end of the import block so that we can # insert top-level annotations. self.import_statements: List[cst.ImportFrom] = [] @staticmethod - def add_stub_to_context(context: CodemodContext, stub: cst.Module) -> None: + def store_stub_in_context( + context: CodemodContext, + stub: cst.Module, + overwrite_existing_annotations: bool = False, + ) -> None: """ - Add a stub module to the :class:`~libcst.codemod.CodemodContext` so + Store a stub module in the :class:`~libcst.codemod.CodemodContext` so that type annotations from the stub can be applied in a later invocation of this class. + + If the ``overwrite_existing_annotations`` flag is ``True``, the + codemod will overwrite any existing annotations. + + If you call this function multiple times, only the last values of + ``stub`` and ``overwrite_existing_annotations`` will take effect. """ - context.scratch.setdefault(ApplyTypeAnnotationsVisitor.CONTEXT_KEY, []).append( - stub + context.scratch[ApplyTypeAnnotationsVisitor.CONTEXT_KEY] = ( + stub, + overwrite_existing_annotations, ) def transform_module_impl(self, tree: cst.Module) -> cst.Module: @@ -262,8 +281,14 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module: tree.visit(import_gatherer) existing_import_names = _get_import_names(import_gatherer.all_imports) - stubs = self.context.scratch.get(ApplyTypeAnnotationsVisitor.CONTEXT_KEY, []) - for stub in stubs: + context_contents = self.context.scratch.get( + ApplyTypeAnnotationsVisitor.CONTEXT_KEY + ) + if context_contents: + stub, overwrite_existing_annotations = context_contents + self.overwrite_existing_annotations = ( + self.overwrite_existing_annotations or overwrite_existing_annotations + ) visitor = TypeCollector(existing_import_names, self.context) stub.visit(visitor) self.annotations.function_annotations.update(visitor.function_annotations) @@ -339,7 +364,8 @@ def _update_parameters( self, annotations: FunctionAnnotation, updated_node: cst.FunctionDef ) -> cst.Parameters: # Update params and default params with annotations - # don't override existing annotations or default values + # Don't override existing annotations or default values unless asked + # to overwrite existing annotations. def update_annotation( parameters: Sequence[cst.Param], annotations: Sequence[cst.Param] ) -> List[cst.Param]: @@ -350,7 +376,9 @@ def update_annotation( parameter_annotations[parameter.name.value] = parameter.annotation for parameter in parameters: key = parameter.name.value - if key in parameter_annotations and not parameter.annotation: + if key in parameter_annotations and ( + self.overwrite_existing_annotations or not parameter.annotation + ): parameter = parameter.with_changes( annotation=parameter_annotations[key] ) @@ -409,8 +437,9 @@ def leave_FunctionDef( self.qualifier.pop() if key in self.annotations.function_annotations: function_annotation = self.annotations.function_annotations[key] - # Only add new annotation if one doesn't already exist - if not updated_node.returns: + # Only add new annotation if explicitly told to overwrite existing + # annotations or if one doesn't already exist. + if self.overwrite_existing_annotations or not updated_node.returns: updated_node = updated_node.with_changes( returns=function_annotation.returns ) diff --git a/libcst/codemod/visitors/tests/test_apply_type_annotations.py b/libcst/codemod/visitors/tests/test_apply_type_annotations.py index 7662b22df..e32348a72 100644 --- a/libcst/codemod/visitors/tests/test_apply_type_annotations.py +++ b/libcst/codemod/visitors/tests/test_apply_type_annotations.py @@ -608,7 +608,45 @@ def foo() -> typing.Sequence[int]: ) def test_annotate_functions(self, stub: str, before: str, after: str) -> None: context = CodemodContext() - ApplyTypeAnnotationsVisitor.add_stub_to_context( + ApplyTypeAnnotationsVisitor.store_stub_in_context( context, parse_module(textwrap.dedent(stub.rstrip())) ) self.assertCodemod(before, after, context_override=context) + + @data_provider( + ( + ( + """ + def fully_annotated_with_different_stub(a: bool, b: bool) -> str: ... + """, + """ + def fully_annotated_with_different_stub(a: int, b: str) -> bool: + return 'hello' + """, + """ + def fully_annotated_with_different_stub(a: bool, b: bool) -> str: + return 'hello' + """, + ), + ) + ) + def test_annotate_functions_with_existing_annotations( + self, stub: str, before: str, after: str + ) -> None: + context = CodemodContext() + ApplyTypeAnnotationsVisitor.store_stub_in_context( + context, parse_module(textwrap.dedent(stub.rstrip())) + ) + # Test setting the overwrite flag on the codemod instance. + self.assertCodemod( + before, after, context_override=context, overwrite_existing_annotations=True + ) + + # Test setting the flag when storing the stub in the context. + context = CodemodContext() + ApplyTypeAnnotationsVisitor.store_stub_in_context( + context, + parse_module(textwrap.dedent(stub.rstrip())), + overwrite_existing_annotations=True, + ) + self.assertCodemod(before, after, context_override=context)