Skip to content

Commit

Permalink
Add codemod to fix variadic callable annotations (#1269)
Browse files Browse the repository at this point in the history
* add fix variadic callable codemod

* format
  • Loading branch information
yangdanny97 authored Jan 3, 2025
1 parent d269872 commit 776452f
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 0 deletions.
40 changes: 40 additions & 0 deletions libcst/codemod/commands/fix_variadic_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-strict

import libcst as cst
import libcst.matchers as m
from libcst.codemod import VisitorBasedCodemodCommand
from libcst.metadata import QualifiedName, QualifiedNameProvider, QualifiedNameSource


class FixVariadicCallableCommmand(VisitorBasedCodemodCommand):
DESCRIPTION: str = (
"Fix incorrect variadic callable type annotations from `Callable[[...], T]` to `Callable[..., T]``"
)

METADATA_DEPENDENCIES = (QualifiedNameProvider,)

def leave_Subscript(
self, original_node: cst.Subscript, updated_node: cst.Subscript
) -> cst.BaseExpression:
if QualifiedNameProvider.has_name(
self,
original_node,
QualifiedName(name="typing.Callable", source=QualifiedNameSource.IMPORT),
):
node_matches = len(updated_node.slice) == 2 and m.matches(
updated_node.slice[0],
m.SubscriptElement(
slice=m.Index(value=m.List(elements=[m.Element(m.Ellipsis())]))
),
)

if node_matches:
slices = list(updated_node.slice)
slices[0] = cst.SubscriptElement(cst.Index(cst.Ellipsis()))
return updated_node.with_changes(slice=slices)
return updated_node
92 changes: 92 additions & 0 deletions libcst/codemod/commands/tests/test_fix_variadic_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-strict

from libcst.codemod import CodemodTest
from libcst.codemod.commands.fix_variadic_callable import FixVariadicCallableCommmand


class TestFixVariadicCallableCommmand(CodemodTest):
TRANSFORM = FixVariadicCallableCommmand

def test_callable_typing(self) -> None:
before = """
from typing import Callable
x: Callable[[...], int] = ...
"""
after = """
from typing import Callable
x: Callable[..., int] = ...
"""
self.assertCodemod(before, after)

def test_callable_typing_alias(self) -> None:
before = """
import typing as t
x: t.Callable[[...], int] = ...
"""
after = """
import typing as t
x: t.Callable[..., int] = ...
"""
self.assertCodemod(before, after)

def test_callable_import_alias(self) -> None:
before = """
from typing import Callable as C
x: C[[...], int] = ...
"""
after = """
from typing import Callable as C
x: C[..., int] = ...
"""
self.assertCodemod(before, after)

def test_callable_with_optional(self) -> None:
before = """
from typing import Callable
def foo(bar: Optional[Callable[[...], int]]) -> Callable[[...], int]:
...
"""
after = """
from typing import Callable
def foo(bar: Optional[Callable[..., int]]) -> Callable[..., int]:
...
"""
self.assertCodemod(before, after)

def test_callable_with_arguments(self) -> None:
before = """
from typing import Callable
x: Callable[[int], int]
"""
after = """
from typing import Callable
x: Callable[[int], int]
"""
self.assertCodemod(before, after)

def test_callable_with_variadic_arguments(self) -> None:
before = """
from typing import Callable
x: Callable[[int, int, ...], int]
"""
after = """
from typing import Callable
x: Callable[[int, int, ...], int]
"""
self.assertCodemod(before, after)

def test_callable_no_arguments(self) -> None:
before = """
from typing import Callable
x: Callable
"""
after = """
from typing import Callable
x: Callable
"""
self.assertCodemod(before, after)

0 comments on commit 776452f

Please sign in to comment.