From 4508b985364a685c34eebb691374a2c4a3e96908 Mon Sep 17 00:00:00 2001 From: JP-Ellis Date: Thu, 2 Nov 2023 08:29:10 +1100 Subject: [PATCH 1/2] fix(typing): improve decorator type hinting The type hinting for the most commonly used decorators were incomplete, resulting in decorated functions being obscured. This makes use of the special type variable `ParamSpec` which allows the type hinting a view into the parameters of a function. As ``ParamSpec` was introduced in Python 3.10, `ParamSpec` is imported from the `typing_extensions` module instead of the standard library. I have also taken the opportunity to fix other instances of `Callable` type hints missing their arguments. Signed-off-by: JP-Ellis --- src/pytest_bdd/plugin.py | 23 +++++++++++++++-------- src/pytest_bdd/reporting.py | 10 ++++++++-- src/pytest_bdd/scenario.py | 20 ++++++++++++-------- src/pytest_bdd/steps.py | 22 ++++++++++++---------- src/pytest_bdd/utils.py | 2 +- 5 files changed, 48 insertions(+), 29 deletions(-) diff --git a/src/pytest_bdd/plugin.py b/src/pytest_bdd/plugin.py index 486cdf8..ccee011 100644 --- a/src/pytest_bdd/plugin.py +++ b/src/pytest_bdd/plugin.py @@ -1,16 +1,15 @@ """Pytest plugin entry point. Used for any fixtures needed.""" from __future__ import annotations -from typing import TYPE_CHECKING, Callable, cast +from typing import TYPE_CHECKING, Any, Callable, Generator, TypeVar, cast import pytest +from typing_extensions import ParamSpec from . import cucumber_json, generation, gherkin_terminal_reporter, given, reporting, then, when from .utils import CONFIG_STACK if TYPE_CHECKING: - from typing import Any, Generator - from _pytest.config import Config, PytestPluginManager from _pytest.config.argparsing import Parser from _pytest.fixtures import FixtureRequest @@ -21,6 +20,10 @@ if TYPE_CHECKING: from .parser import Feature, Scenario, Step +P = ParamSpec("P") +T = TypeVar("T") + + def pytest_addhooks(pluginmanager: PytestPluginManager) -> None: """Register plugin hooks.""" from pytest_bdd import hooks @@ -93,7 +96,7 @@ def pytest_bdd_step_error( feature: Feature, scenario: Scenario, step: Step, - step_func: Callable, + step_func: Callable[..., Any], step_func_args: dict, exception: Exception, ) -> None: @@ -102,7 +105,11 @@ def pytest_bdd_step_error( @pytest.hookimpl(tryfirst=True) def pytest_bdd_before_step( - request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable + request: FixtureRequest, + feature: Feature, + scenario: Scenario, + step: Step, + step_func: Callable[..., Any], ) -> None: reporting.before_step(request, feature, scenario, step, step_func) @@ -113,7 +120,7 @@ def pytest_bdd_after_step( feature: Feature, scenario: Scenario, step: Step, - step_func: Callable, + step_func: Callable[..., Any], step_func_args: dict[str, Any], ) -> None: reporting.after_step(request, feature, scenario, step, step_func, step_func_args) @@ -123,7 +130,7 @@ def pytest_cmdline_main(config: Config) -> int | None: return generation.cmdline_main(config) -def pytest_bdd_apply_tag(tag: str, function: Callable) -> Callable: +def pytest_bdd_apply_tag(tag: str, function: Callable[P, T]) -> Callable[P, T]: mark = getattr(pytest.mark, tag) marked = mark(function) - return cast(Callable, marked) + return cast(Callable[P, T], marked) diff --git a/src/pytest_bdd/reporting.py b/src/pytest_bdd/reporting.py index 26e1cb0..95254f6 100644 --- a/src/pytest_bdd/reporting.py +++ b/src/pytest_bdd/reporting.py @@ -155,7 +155,7 @@ def step_error( feature: Feature, scenario: Scenario, step: Step, - step_func: Callable, + step_func: Callable[..., Any], step_func_args: dict, exception: Exception, ) -> None: @@ -163,7 +163,13 @@ def step_error( request.node.__scenario_report__.fail() -def before_step(request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable) -> None: +def before_step( + request: FixtureRequest, + feature: Feature, + scenario: Scenario, + step: Step, + step_func: Callable[..., Any], +) -> None: """Store step start time.""" request.node.__scenario_report__.add_step_report(StepReport(step=step)) diff --git a/src/pytest_bdd/scenario.py b/src/pytest_bdd/scenario.py index df7c029..7a231ef 100644 --- a/src/pytest_bdd/scenario.py +++ b/src/pytest_bdd/scenario.py @@ -16,11 +16,12 @@ import contextlib import logging import os import re -from typing import TYPE_CHECKING, Callable, Iterator, cast +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, cast import pytest from _pytest.fixtures import FixtureDef, FixtureManager, FixtureRequest, call_fixture_func from _pytest.nodes import iterparentnodeids +from typing_extensions import ParamSpec from . import exceptions from .feature import get_feature, get_features @@ -28,12 +29,12 @@ from .steps import StepFunctionContext, get_step_fixture_name, inject_fixture from .utils import CONFIG_STACK, get_args, get_caller_module_locals, get_caller_module_path if TYPE_CHECKING: - from typing import Any, Iterable - from _pytest.mark.structures import ParameterSet from .parser import Feature, Scenario, ScenarioTemplate, Step +P = ParamSpec("P") +T = TypeVar("T") logger = logging.getLogger(__name__) @@ -197,14 +198,14 @@ def _execute_scenario(feature: Feature, scenario: Scenario, request: FixtureRequ def _get_scenario_decorator( feature: Feature, feature_name: str, templated_scenario: ScenarioTemplate, scenario_name: str -) -> Callable[[Callable], Callable]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: # HACK: Ideally we would use `def decorator(fn)`, but we want to return a custom exception # when the decorator is misused. # Pytest inspect the signature to determine the required fixtures, and in that case it would look # for a fixture called "fn" that doesn't exist (if it exists then it's even worse). # It will error with a "fixture 'fn' not found" message instead. # We can avoid this hack by using a pytest hook and check for misuse instead. - def decorator(*args: Callable) -> Callable: + def decorator(*args: Callable[P, T]) -> Callable[P, T]: if not args: raise exceptions.ScenarioIsDecoratorOnly( "scenario function can only be used as a decorator. Refer to the documentation." @@ -236,7 +237,7 @@ def _get_scenario_decorator( scenario_wrapper.__doc__ = f"{feature_name}: {scenario_name}" scenario_wrapper.__scenario__ = templated_scenario - return cast(Callable, scenario_wrapper) + return cast(Callable[P, T], scenario_wrapper) return decorator @@ -254,8 +255,11 @@ def collect_example_parametrizations( def scenario( - feature_name: str, scenario_name: str, encoding: str = "utf-8", features_base_dir=None -) -> Callable[[Callable], Callable]: + feature_name: str, + scenario_name: str, + encoding: str = "utf-8", + features_base_dir: str | None = None, +) -> Callable[[Callable[P, T]], Callable[P, T]]: """Scenario decorator. :param str feature_name: Feature file name. Absolute or relative to the configured feature base path. diff --git a/src/pytest_bdd/steps.py b/src/pytest_bdd/steps.py index b3d8be6..83c54e4 100644 --- a/src/pytest_bdd/steps.py +++ b/src/pytest_bdd/steps.py @@ -43,13 +43,15 @@ from typing import Any, Callable, Iterable, Literal, TypeVar import pytest from _pytest.fixtures import FixtureDef, FixtureRequest +from typing_extensions import ParamSpec from .parser import Step from .parsers import StepParser, get_parser from .types import GIVEN, THEN, WHEN from .utils import get_caller_module_locals -TCallable = TypeVar("TCallable", bound=Callable[..., Any]) +P = ParamSpec("P") +T = TypeVar("T") @enum.unique @@ -74,10 +76,10 @@ def get_step_fixture_name(step: Step) -> str: def given( name: str | StepParser, - converters: dict[str, Callable] | None = None, + converters: dict[str, Callable[[Any], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, -) -> Callable: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """Given step decorator. :param name: Step name or a parser object. @@ -93,10 +95,10 @@ def given( def when( name: str | StepParser, - converters: dict[str, Callable] | None = None, + converters: dict[str, Callable[[Any], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, -) -> Callable: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """When step decorator. :param name: Step name or a parser object. @@ -112,10 +114,10 @@ def when( def then( name: str | StepParser, - converters: dict[str, Callable] | None = None, + converters: dict[str, Callable[[Any], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, -) -> Callable: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """Then step decorator. :param name: Step name or a parser object. @@ -132,10 +134,10 @@ def then( def step( name: str | StepParser, type_: Literal["given", "when", "then"] | None = None, - converters: dict[str, Callable] | None = None, + converters: dict[str, Callable[[Any], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, -) -> Callable[[TCallable], TCallable]: +) -> Callable[[Callable[P, T]], Callable[P, T]]: """Generic step decorator. :param name: Step name as in the feature file. @@ -155,7 +157,7 @@ def step( if converters is None: converters = {} - def decorator(func: TCallable) -> TCallable: + def decorator(func: Callable[P, T]) -> Callable[P, T]: parser = get_parser(name) context = StepFunctionContext( diff --git a/src/pytest_bdd/utils.py b/src/pytest_bdd/utils.py index 3554078..eb243e5 100644 --- a/src/pytest_bdd/utils.py +++ b/src/pytest_bdd/utils.py @@ -19,7 +19,7 @@ T = TypeVar("T") CONFIG_STACK: list[Config] = [] -def get_args(func: Callable) -> list[str]: +def get_args(func: Callable[..., Any]) -> list[str]: """Get a list of argument names for a function. :param func: The function to inspect. From 9c60589fb9324ef79a8416ae5cc3cbc84f78a34a Mon Sep 17 00:00:00 2001 From: Alessio Bogon <778703+youtux@users.noreply.github.com> Date: Sat, 2 Dec 2023 21:57:07 +0100 Subject: [PATCH 2/2] fix type for `StepFunctionContext.converters` --- src/pytest_bdd/steps.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pytest_bdd/steps.py b/src/pytest_bdd/steps.py index 83c54e4..98f116d 100644 --- a/src/pytest_bdd/steps.py +++ b/src/pytest_bdd/steps.py @@ -65,7 +65,7 @@ class StepFunctionContext: type: Literal["given", "when", "then"] | None step_func: Callable[..., Any] parser: StepParser - converters: dict[str, Callable[..., Any]] = field(default_factory=dict) + converters: dict[str, Callable[[str], Any]] = field(default_factory=dict) target_fixture: str | None = None @@ -76,7 +76,7 @@ def get_step_fixture_name(step: Step) -> str: def given( name: str | StepParser, - converters: dict[str, Callable[[Any], Any]] | None = None, + converters: dict[str, Callable[[str], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, ) -> Callable[[Callable[P, T]], Callable[P, T]]: @@ -95,7 +95,7 @@ def given( def when( name: str | StepParser, - converters: dict[str, Callable[[Any], Any]] | None = None, + converters: dict[str, Callable[[str], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, ) -> Callable[[Callable[P, T]], Callable[P, T]]: @@ -114,7 +114,7 @@ def when( def then( name: str | StepParser, - converters: dict[str, Callable[[Any], Any]] | None = None, + converters: dict[str, Callable[[str], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, ) -> Callable[[Callable[P, T]], Callable[P, T]]: @@ -134,7 +134,7 @@ def then( def step( name: str | StepParser, type_: Literal["given", "when", "then"] | None = None, - converters: dict[str, Callable[[Any], Any]] | None = None, + converters: dict[str, Callable[[str], Any]] | None = None, target_fixture: str | None = None, stacklevel: int = 1, ) -> Callable[[Callable[P, T]], Callable[P, T]]: