Skip to content

Commit

Permalink
Add support for PEP-646 (#696)
Browse files Browse the repository at this point in the history
  • Loading branch information
zsol authored Jun 13, 2022
1 parent 380f045 commit ebe1851
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 9 deletions.
23 changes: 21 additions & 2 deletions libcst/_nodes/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,10 +1438,29 @@ class Index(BaseSlice):
#: The index value itself.
value: BaseExpression

#: An optional string with an asterisk appearing before the name. This is
#: expanded into variable number of positional arguments. See PEP-646
star: Optional[Literal["*"]] = None

#: Whitespace after the ``star`` (if it exists), but before the ``value``.
whitespace_after_star: Optional[BaseParenthesizableWhitespace] = None

def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Index":
return Index(value=visit_required(self, "value", self.value, visitor))
return Index(
star=self.star,
whitespace_after_star=visit_optional(
self, "whitespace_after_star", self.whitespace_after_star, visitor
),
value=visit_required(self, "value", self.value, visitor),
)

def _codegen_impl(self, state: CodegenState) -> None:
star = self.star
if star is not None:
state.add_token(star)
ws = self.whitespace_after_star
if ws is not None:
ws._codegen(state)
self.value._codegen(state)


Expand Down Expand Up @@ -2785,7 +2804,7 @@ def _codegen_impl(

@add_slots
@dataclass(frozen=True)
class StarredElement(BaseElement, _BaseParenthesizedNode):
class StarredElement(BaseElement, BaseExpression, _BaseParenthesizedNode):
"""
A starred ``*value`` element that expands to represent multiple values in a literal
:class:`List`, :class:`Tuple`, or :class:`Set`.
Expand Down
2 changes: 1 addition & 1 deletion libcst/_nodes/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def __assert_visit_returns_identity(self, node: cst.CSTNode) -> None:
def assert_parses(
self,
code: str,
parser: Callable[[str], cst.BaseExpression],
parser: Callable[[str], cst.CSTNode],
expect_success: bool,
) -> None:
if not expect_success:
Expand Down
94 changes: 94 additions & 0 deletions libcst/_nodes/tests/test_funcdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,81 @@ class FunctionDefCreationTest(CSTNodeTest):
)
)
def test_valid(self, **kwargs: Any) -> None:
if not is_native() and kwargs.get("native_only", False):
self.skipTest("Disabled for native parser")
if "native_only" in kwargs:
kwargs.pop("native_only")
self.validate_node(**kwargs)

@data_provider(
(
# PEP 646
{
"node": cst.FunctionDef(
name=cst.Name(value="foo"),
params=cst.Parameters(
params=[],
star_arg=cst.Param(
star="*",
name=cst.Name("a"),
annotation=cst.Annotation(
cst.StarredElement(value=cst.Name("b")),
whitespace_before_indicator=cst.SimpleWhitespace(""),
),
),
),
body=cst.SimpleStatementSuite((cst.Pass(),)),
),
"parser": parse_statement,
"code": "def foo(*a: *b): pass\n",
},
{
"node": cst.FunctionDef(
name=cst.Name(value="foo"),
params=cst.Parameters(
params=[],
star_arg=cst.Param(
star="*",
name=cst.Name("a"),
annotation=cst.Annotation(
cst.StarredElement(
value=cst.Subscript(
value=cst.Name("tuple"),
slice=[
cst.SubscriptElement(
cst.Index(cst.Name("int")),
comma=cst.Comma(),
),
cst.SubscriptElement(
cst.Index(
value=cst.Name("Ts"),
star="*",
whitespace_after_star=cst.SimpleWhitespace(
""
),
),
comma=cst.Comma(),
),
cst.SubscriptElement(
cst.Index(cst.Ellipsis())
),
],
)
),
whitespace_before_indicator=cst.SimpleWhitespace(""),
),
),
),
body=cst.SimpleStatementSuite((cst.Pass(),)),
),
"parser": parse_statement,
"code": "def foo(*a: *tuple[int,*Ts,...]): pass\n",
},
)
)
def test_valid_native(self, **kwargs: Any) -> None:
if not is_native():
self.skipTest("Disabled for native parser")
self.validate_node(**kwargs)

@data_provider(
Expand Down Expand Up @@ -2045,3 +2120,22 @@ def test_versions(self, **kwargs: Any) -> None:
if is_native() and not kwargs.get("expect_success", True):
self.skipTest("parse errors are disabled for native parser")
self.assert_parses(**kwargs)

@data_provider(
(
{"code": "A[:*b]"},
{"code": "A[*b:]"},
{"code": "A[*b:*b]"},
{"code": "A[*(1:2)]"},
{"code": "A[*:]"},
{"code": "A[:*]"},
{"code": "A[**b]"},
{"code": "def f(x: *b): pass"},
{"code": "def f(**x: *b): pass"},
{"code": "x: *b"},
)
)
def test_parse_error(self, **kwargs: Any) -> None:
if not is_native():
self.skipTest("Skipped for non-native parser")
self.assert_parses(**kwargs, expect_success=False, parser=parse_statement)
18 changes: 17 additions & 1 deletion libcst/_typed_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2807,6 +2807,22 @@ def visit_Index_value(self, node: "Index") -> None:
def leave_Index_value(self, node: "Index") -> None:
pass

@mark_no_op
def visit_Index_star(self, node: "Index") -> None:
pass

@mark_no_op
def leave_Index_star(self, node: "Index") -> None:
pass

@mark_no_op
def visit_Index_whitespace_after_star(self, node: "Index") -> None:
pass

@mark_no_op
def leave_Index_whitespace_after_star(self, node: "Index") -> None:
pass

@mark_no_op
def visit_Integer(self, node: "Integer") -> Optional[bool]:
pass
Expand Down Expand Up @@ -7056,7 +7072,7 @@ def leave_StarredDictElement(
@mark_no_op
def leave_StarredElement(
self, original_node: "StarredElement", updated_node: "StarredElement"
) -> Union["BaseElement", FlattenSentinel["BaseElement"], RemovalSentinel]:
) -> "BaseExpression":
return updated_node

@mark_no_op
Expand Down
42 changes: 41 additions & 1 deletion libcst/matchers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7364,6 +7364,46 @@ class Index(BaseSlice, BaseMatcherNode):
OneOf[BaseExpressionMatchType],
AllOf[BaseExpressionMatchType],
] = DoNotCare()
star: Union[
Optional[Literal["*"]],
MetadataMatchType,
MatchIfTrue[Optional[Literal["*"]]],
DoNotCareSentinel,
OneOf[
Union[
Optional[Literal["*"]],
MetadataMatchType,
MatchIfTrue[Optional[Literal["*"]]],
]
],
AllOf[
Union[
Optional[Literal["*"]],
MetadataMatchType,
MatchIfTrue[Optional[Literal["*"]]],
]
],
] = DoNotCare()
whitespace_after_star: Union[
Optional["BaseParenthesizableWhitespace"],
MetadataMatchType,
MatchIfTrue[Optional[cst.BaseParenthesizableWhitespace]],
DoNotCareSentinel,
OneOf[
Union[
Optional["BaseParenthesizableWhitespace"],
MetadataMatchType,
MatchIfTrue[Optional[cst.BaseParenthesizableWhitespace]],
]
],
AllOf[
Union[
Optional["BaseParenthesizableWhitespace"],
MetadataMatchType,
MatchIfTrue[Optional[cst.BaseParenthesizableWhitespace]],
]
],
] = DoNotCare()
metadata: Union[
MetadataMatchType,
DoNotCareSentinel,
Expand Down Expand Up @@ -13644,7 +13684,7 @@ class StarredDictElement(BaseDictElement, BaseMatcherNode):


@dataclass(frozen=True, eq=False, unsafe_hash=False)
class StarredElement(BaseElement, BaseMatcherNode):
class StarredElement(BaseElement, BaseExpression, BaseMatcherNode):
value: Union[
BaseExpressionMatchType,
DoNotCareSentinel,
Expand Down
2 changes: 1 addition & 1 deletion libcst/matchers/_return_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@
SimpleWhitespace: Union[BaseParenthesizableWhitespace, MaybeSentinel],
Slice: BaseSlice,
StarredDictElement: Union[BaseDictElement, RemovalSentinel],
StarredElement: Union[BaseElement, RemovalSentinel],
StarredElement: BaseExpression,
Subscript: BaseExpression,
SubscriptElement: Union[SubscriptElement, RemovalSentinel],
Subtract: BaseBinaryOp,
Expand Down
27 changes: 25 additions & 2 deletions native/libcst/src/nodes/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1787,18 +1787,41 @@ pub enum BaseSlice<'a> {
#[cst_node]
pub struct Index<'a> {
pub value: Expression<'a>,
pub star: Option<&'a str>,
pub whitespace_after_star: Option<ParenthesizableWhitespace<'a>>,

pub(crate) star_tok: Option<TokenRef<'a>>,
}

impl<'r, 'a> Inflate<'a> for DeflatedIndex<'r, 'a> {
type Inflated = Index<'a>;
fn inflate(self, config: &Config<'a>) -> Result<Self::Inflated> {
fn inflate(mut self, config: &Config<'a>) -> Result<Self::Inflated> {
let (star, whitespace_after_star) = if let Some(star_tok) = self.star_tok.as_mut() {
(
Some(star_tok.string),
Some(parse_parenthesizable_whitespace(
config,
&mut star_tok.whitespace_after.borrow_mut(),
)?),
)
} else {
(None, None)
};
let value = self.value.inflate(config)?;
Ok(Self::Inflated { value })
Ok(Self::Inflated {
value,
star,
whitespace_after_star,
})
}
}

impl<'a> Codegen<'a> for Index<'a> {
fn codegen(&self, state: &mut CodegenState<'a>) {
if let Some(star) = self.star {
state.add_token(star);
}
self.whitespace_after_star.codegen(state);
self.value.codegen(state);
}
}
Expand Down
33 changes: 32 additions & 1 deletion native/libcst/src/parser/grammar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,10 @@ parser! {
StarEtc(Some(StarArg::Param(Box::new(
add_param_star(a, star)))), b, kw)
}
/ star:lit("*") a:param_no_default_star_annotation() b:param_maybe_default()* kw:kwds()? {
StarEtc(Some(StarArg::Param(Box::new(
add_param_star(a, star)))), b, kw)
}
/ lit("*") c:comma() b:param_maybe_default()+ kw:kwds()? {
StarEtc(Some(StarArg::Star(Box::new(ParamStar {comma:c }))), b, kw)
}
Expand All @@ -401,6 +405,10 @@ parser! {
= a:param() c:lit(",") { add_param_default(a, None, Some(c)) }
/ a:param() &lit(")") {a}

rule param_no_default_star_annotation() -> Param<'input, 'a>
= a:param_star_annotation() c:lit(",") { add_param_default(a, None, Some(c))}
/ a:param_star_annotation() &lit(")") {a}

rule param_with_default() -> Param<'input, 'a>
= a:param() def:default() c:lit(",") {
add_param_default(a, Some(def), Some(c))
Expand All @@ -422,11 +430,21 @@ parser! {
Param {name: n, annotation: a, ..Default::default() }
}

rule param_star_annotation() -> Param<'input, 'a>
= n:name() a:star_annotation() {
Param {name: n, annotation: Some(a), ..Default::default() }
}

rule annotation() -> Annotation<'input, 'a>
= col:lit(":") e:expression() {
make_annotation(col, e)
}

rule star_annotation() -> Annotation<'input, 'a>
= col:lit(":") e:star_expression() {
make_annotation(col, e)
}

rule default() -> (AssignEqual<'input, 'a>, Expression<'input, 'a>)
= eq:lit("=") ex:expression() {
(make_assign_equal(eq), ex)
Expand Down Expand Up @@ -983,6 +1001,7 @@ parser! {
rest:(c:lit(":") s:expression()? {(c, s)})? {
make_slice(l, col, u, rest)
}
/ e:starred_expression() { make_index_from_arg(e) }
/ v:expression() { make_index(v) }

rule atom() -> Expression<'input, 'a>
Expand Down Expand Up @@ -2412,7 +2431,19 @@ fn make_double_starred_element<'input, 'a>(
}

fn make_index<'input, 'a>(value: Expression<'input, 'a>) -> BaseSlice<'input, 'a> {
BaseSlice::Index(Box::new(Index { value }))
BaseSlice::Index(Box::new(Index {
value,
star: None,
star_tok: None,
}))
}

fn make_index_from_arg<'input, 'a>(arg: Arg<'input, 'a>) -> BaseSlice<'input, 'a> {
BaseSlice::Index(Box::new(Index {
value: arg.value,
star: Some(arg.star),
star_tok: arg.star_tok,
}))
}

fn make_colon<'input, 'a>(tok: TokenRef<'input, 'a>) -> Colon<'input, 'a> {
Expand Down
Loading

0 comments on commit ebe1851

Please sign in to comment.