Skip to content

Commit 15c33fb

Browse files
committed
feat: support keyword arguments in marker expressions
Fixes #12281
1 parent e8fa8dd commit 15c33fb

4 files changed

Lines changed: 304 additions & 17 deletions

File tree

src/_pytest/mark/__init__.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import collections
56
import dataclasses
67
from typing import AbstractSet
78
from typing import Collection
@@ -181,7 +182,9 @@ def from_item(cls, item: Item) -> KeywordMatcher:
181182

182183
return cls(mapped_names)
183184

184-
def __call__(self, subname: str) -> bool:
185+
def __call__(self, subname: str, /, **kwargs: object) -> bool:
186+
if kwargs:
187+
raise UsageError("Keyword expressions do not support call parameters.")
185188
subname = subname.lower()
186189
names = (name.lower() for name in self._names)
187190

@@ -211,24 +214,41 @@ def deselect_by_keyword(items: list[Item], config: Config) -> None:
211214
items[:] = remaining
212215

213216

217+
NOT_NONE_SENTINEL = object()
218+
219+
214220
@dataclasses.dataclass
215221
class MarkMatcher:
216222
"""A matcher for markers which are present.
217223
218224
Tries to match on any marker names, attached to the given colitem.
219225
"""
220226

221-
__slots__ = ("own_mark_names",)
227+
__slots__ = ("own_mark_name_mapping",)
222228

223-
own_mark_names: AbstractSet[str]
229+
own_mark_name_mapping: dict[str, list[Mark]]
224230

225231
@classmethod
226232
def from_item(cls, item: Item) -> MarkMatcher:
227-
mark_names = {mark.name for mark in item.iter_markers()}
228-
return cls(mark_names)
233+
mark_name_mapping = collections.defaultdict(list)
234+
for mark in item.iter_markers():
235+
mark_name_mapping[mark.name].append(mark)
236+
return cls(mark_name_mapping)
237+
238+
def __call__(self, name: str, /, **kwargs: object) -> bool:
239+
if not (matches := self.own_mark_name_mapping.get(name, [])):
240+
return False
241+
242+
if not kwargs:
243+
return True
229244

230-
def __call__(self, name: str) -> bool:
231-
return name in self.own_mark_names
245+
for mark in matches:
246+
if all(
247+
mark.kwargs.get(k, NOT_NONE_SENTINEL) == v for k, v in kwargs.items()
248+
):
249+
return True
250+
251+
return False
232252

233253

234254
def deselect_by_mark(items: list[Item], config: Config) -> None:

src/_pytest/mark/expression.py

Lines changed: 101 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
expression: expr? EOF
66
expr: and_expr ('or' and_expr)*
77
and_expr: not_expr ('and' not_expr)*
8-
not_expr: 'not' not_expr | '(' expr ')' | ident
8+
not_expr: 'not' not_expr | '(' expr ')' | ident ( '(' name '=' value ( ', ' name '=' value )* ')')*
9+
910
ident: (\w|:|\+|-|\.|\[|\]|\\|/)+
1011
1112
The semantics are:
@@ -20,12 +21,13 @@
2021
import ast
2122
import dataclasses
2223
import enum
24+
import keyword
2325
import re
2426
import types
25-
from typing import Callable
2627
from typing import Iterator
2728
from typing import Mapping
2829
from typing import NoReturn
30+
from typing import Protocol
2931
from typing import Sequence
3032

3133

@@ -43,6 +45,9 @@ class TokenType(enum.Enum):
4345
NOT = "not"
4446
IDENT = "identifier"
4547
EOF = "end of input"
48+
EQUAL = "="
49+
STRING = "str"
50+
COMMA = ","
4651

4752

4853
@dataclasses.dataclass(frozen=True)
@@ -86,6 +91,27 @@ def lex(self, input: str) -> Iterator[Token]:
8691
elif input[pos] == ")":
8792
yield Token(TokenType.RPAREN, ")", pos)
8893
pos += 1
94+
elif input[pos] == "=":
95+
yield Token(TokenType.EQUAL, "=", pos)
96+
pos += 1
97+
elif input[pos] == ",":
98+
yield Token(TokenType.COMMA, ",", pos)
99+
pos += 1
100+
elif (quote_char := input[pos]) == "'" or input[pos] == '"':
101+
quote_position = input[pos + 1 :].find(quote_char)
102+
if quote_position == -1:
103+
raise ParseError(
104+
pos + 1,
105+
f'closing quote "{quote_char}" is missing',
106+
)
107+
value = input[pos : pos + 2 + quote_position]
108+
if "\\" in value:
109+
raise ParseError(
110+
pos + 1,
111+
"escaping not supported in marker expression",
112+
)
113+
yield Token(TokenType.STRING, value, pos)
114+
pos += len(value)
89115
else:
90116
match = re.match(r"(:?\w|:|\+|-|\.|\[|\]|\\|/)+", input[pos:])
91117
if match:
@@ -166,18 +192,84 @@ def not_expr(s: Scanner) -> ast.expr:
166192
return ret
167193
ident = s.accept(TokenType.IDENT)
168194
if ident:
169-
return ast.Name(IDENT_PREFIX + ident.value, ast.Load())
195+
name = ast.Name(IDENT_PREFIX + ident.value, ast.Load())
196+
if s.accept(TokenType.LPAREN):
197+
ret = ast.Call(func=name, args=[], keywords=all_kwargs(s))
198+
s.accept(TokenType.RPAREN, reject=True)
199+
else:
200+
ret = name
201+
return ret
202+
170203
s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
171204

172205

173-
class MatcherAdapter(Mapping[str, bool]):
206+
BUILTIN_MATCHERS = {"True": True, "False": False, "None": None}
207+
208+
209+
def single_kwarg(s: Scanner) -> ast.keyword:
210+
keyword_name = s.accept(TokenType.IDENT, reject=True)
211+
assert keyword_name is not None # for mypy
212+
if not keyword_name.value.isidentifier() or keyword.iskeyword(keyword_name.value):
213+
raise ParseError(
214+
keyword_name.pos + 1,
215+
f'unexpected character/s "{keyword_name.value}"',
216+
)
217+
s.accept(TokenType.EQUAL, reject=True)
218+
219+
if value_token := s.accept(TokenType.STRING):
220+
value: str | int | bool | None = value_token.value[1:-1] # strip quotes
221+
else:
222+
value_token = s.accept(TokenType.IDENT, reject=True)
223+
assert value_token is not None # for mypy
224+
if (
225+
(number := value_token.value).isdigit()
226+
or number.startswith("-")
227+
and number[1:].isdigit()
228+
):
229+
value = int(number)
230+
elif value_token.value in BUILTIN_MATCHERS:
231+
value = BUILTIN_MATCHERS[value_token.value]
232+
else:
233+
raise ParseError(
234+
value_token.pos + 1,
235+
f'unexpected character/s "{value_token.value}"',
236+
)
237+
238+
ret = ast.keyword(keyword_name.value, ast.Constant(value))
239+
return ret
240+
241+
242+
def all_kwargs(s: Scanner) -> list[ast.keyword]:
243+
ret = [single_kwarg(s)]
244+
while s.accept(TokenType.COMMA):
245+
ret.append(single_kwarg(s))
246+
return ret
247+
248+
249+
class MatcherCall(Protocol):
250+
def __call__(self, name: str, /, **kwargs: object) -> bool: ...
251+
252+
253+
@dataclasses.dataclass
254+
class MatcherNameAdapter:
255+
matcher: MatcherCall
256+
name: str
257+
258+
def __bool__(self) -> bool:
259+
return self.matcher(self.name)
260+
261+
def __call__(self, **kwargs: object) -> bool:
262+
return self.matcher(self.name, **kwargs)
263+
264+
265+
class MatcherAdapter(Mapping[str, MatcherNameAdapter]):
174266
"""Adapts a matcher function to a locals mapping as required by eval()."""
175267

176-
def __init__(self, matcher: Callable[[str], bool]) -> None:
268+
def __init__(self, matcher: MatcherCall) -> None:
177269
self.matcher = matcher
178270

179-
def __getitem__(self, key: str) -> bool:
180-
return self.matcher(key[len(IDENT_PREFIX) :])
271+
def __getitem__(self, key: str) -> MatcherNameAdapter:
272+
return MatcherNameAdapter(matcher=self.matcher, name=key[len(IDENT_PREFIX) :])
181273

182274
def __iter__(self) -> Iterator[str]:
183275
raise NotImplementedError()
@@ -211,7 +303,7 @@ def compile(self, input: str) -> Expression:
211303
)
212304
return Expression(code)
213305

214-
def evaluate(self, matcher: Callable[[str], bool]) -> bool:
306+
def evaluate(self, matcher: MatcherCall) -> bool:
215307
"""Evaluate the match expression.
216308
217309
:param matcher:
@@ -220,5 +312,5 @@ def evaluate(self, matcher: Callable[[str], bool]) -> bool:
220312
221313
:returns: Whether the expression matches or not.
222314
"""
223-
ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher))
315+
ret: bool = bool(eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher)))
224316
return ret

testing/test_mark.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,54 @@ def test_two():
233233
assert passed_str == expected_passed
234234

235235

236+
@pytest.mark.parametrize(
237+
("expr", "expected_passed"),
238+
[ # TODO: improve/sort out
239+
("car(color='red')", ["test_one"]),
240+
("car(color='red') or car(color='blue')", ["test_one", "test_two"]),
241+
("car and not car(temp=5)", ["test_one", "test_three"]),
242+
("car(temp=4)", ["test_one"]),
243+
("car(temp=4) or car(temp=5)", ["test_one", "test_two"]),
244+
("car(temp=4) and car(temp=5)", []),
245+
("car(temp=-5)", ["test_three"]),
246+
("car(ac=True)", ["test_one"]),
247+
("car(ac=False)", ["test_two"]),
248+
("car(ac=None)", ["test_three"]), # test NOT_NONE_SENTINEL
249+
],
250+
ids=str,
251+
)
252+
def test_mark_option_with_kwargs(
253+
expr: str, expected_passed: list[str | None], pytester: Pytester
254+
) -> None:
255+
pytester.makepyfile(
256+
"""
257+
import pytest
258+
@pytest.mark.car
259+
@pytest.mark.car(ac=True)
260+
@pytest.mark.car(temp=4)
261+
@pytest.mark.car(color="red")
262+
def test_one():
263+
pass
264+
@pytest.mark.car
265+
@pytest.mark.car(ac=False)
266+
@pytest.mark.car(temp=5)
267+
@pytest.mark.car(color="blue")
268+
def test_two():
269+
pass
270+
@pytest.mark.car
271+
@pytest.mark.car(ac=None)
272+
@pytest.mark.car(temp=-5)
273+
def test_three():
274+
pass
275+
276+
"""
277+
)
278+
rec = pytester.inline_run("-m", expr)
279+
passed, skipped, fail = rec.listoutcomes()
280+
passed_str = [x.nodeid.split("::")[-1] for x in passed]
281+
assert passed_str == expected_passed
282+
283+
236284
@pytest.mark.parametrize(
237285
("expr", "expected_passed"),
238286
[("interface", ["test_interface"]), ("not interface", ["test_nointer"])],

0 commit comments

Comments
 (0)