Merge pull request #646 from JP-Ellis/fix/decorator-type-hinting

fix(typing): improve decorator type hinting
This commit is contained in:
Alessio Bogon 2023-12-02 22:07:41 +01:00 committed by GitHub
commit 8a694ff1e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 49 additions and 30 deletions

View File

@ -1,16 +1,15 @@
"""Pytest plugin entry point. Used for any fixtures needed."""
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, cast
from typing import TYPE_CHECKING, Any, Callable, Generator, TypeVar, cast
import pytest
from typing_extensions import ParamSpec
from . import cucumber_json, generation, gherkin_terminal_reporter, given, reporting, then, when
from .utils import CONFIG_STACK
if TYPE_CHECKING:
from typing import Any, Generator
from _pytest.config import Config, PytestPluginManager
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
@ -21,6 +20,10 @@ if TYPE_CHECKING:
from .parser import Feature, Scenario, Step
P = ParamSpec("P")
T = TypeVar("T")
def pytest_addhooks(pluginmanager: PytestPluginManager) -> None:
"""Register plugin hooks."""
from pytest_bdd import hooks
@ -94,7 +97,7 @@ def pytest_bdd_step_error(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable,
step_func: Callable[..., Any],
step_func_args: dict,
exception: Exception,
) -> None:
@ -103,7 +106,11 @@ def pytest_bdd_step_error(
@pytest.hookimpl(tryfirst=True)
def pytest_bdd_before_step(
request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable
request: FixtureRequest,
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable[..., Any],
) -> None:
reporting.before_step(request, feature, scenario, step, step_func)
@ -114,7 +121,7 @@ def pytest_bdd_after_step(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable,
step_func: Callable[..., Any],
step_func_args: dict[str, Any],
) -> None:
reporting.after_step(request, feature, scenario, step, step_func, step_func_args)
@ -124,7 +131,7 @@ def pytest_cmdline_main(config: Config) -> int | None:
return generation.cmdline_main(config)
def pytest_bdd_apply_tag(tag: str, function: Callable) -> Callable:
def pytest_bdd_apply_tag(tag: str, function: Callable[P, T]) -> Callable[P, T]:
mark = getattr(pytest.mark, tag)
marked = mark(function)
return cast(Callable, marked)
return cast(Callable[P, T], marked)

View File

@ -155,7 +155,7 @@ def step_error(
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable,
step_func: Callable[..., Any],
step_func_args: dict,
exception: Exception,
) -> None:
@ -163,7 +163,13 @@ def step_error(
request.node.__scenario_report__.fail()
def before_step(request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable) -> None:
def before_step(
request: FixtureRequest,
feature: Feature,
scenario: Scenario,
step: Step,
step_func: Callable[..., Any],
) -> None:
"""Store step start time."""
request.node.__scenario_report__.add_step_report(StepReport(step=step))

View File

@ -16,11 +16,12 @@ import contextlib
import logging
import os
import re
from typing import TYPE_CHECKING, Callable, Iterator, cast
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, cast
import pytest
from _pytest.fixtures import FixtureDef, FixtureManager, FixtureRequest, call_fixture_func
from _pytest.nodes import iterparentnodeids
from typing_extensions import ParamSpec
from . import exceptions
from .feature import get_feature, get_features
@ -28,12 +29,12 @@ from .steps import StepFunctionContext, get_step_fixture_name, inject_fixture
from .utils import CONFIG_STACK, get_args, get_caller_module_locals, get_caller_module_path
if TYPE_CHECKING:
from typing import Any, Iterable
from _pytest.mark.structures import ParameterSet
from .parser import Feature, Scenario, ScenarioTemplate, Step
P = ParamSpec("P")
T = TypeVar("T")
logger = logging.getLogger(__name__)
@ -197,14 +198,14 @@ def _execute_scenario(feature: Feature, scenario: Scenario, request: FixtureRequ
def _get_scenario_decorator(
feature: Feature, feature_name: str, templated_scenario: ScenarioTemplate, scenario_name: str
) -> Callable[[Callable], Callable]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
# HACK: Ideally we would use `def decorator(fn)`, but we want to return a custom exception
# when the decorator is misused.
# Pytest inspect the signature to determine the required fixtures, and in that case it would look
# for a fixture called "fn" that doesn't exist (if it exists then it's even worse).
# It will error with a "fixture 'fn' not found" message instead.
# We can avoid this hack by using a pytest hook and check for misuse instead.
def decorator(*args: Callable) -> Callable:
def decorator(*args: Callable[P, T]) -> Callable[P, T]:
if not args:
raise exceptions.ScenarioIsDecoratorOnly(
"scenario function can only be used as a decorator. Refer to the documentation."
@ -236,7 +237,7 @@ def _get_scenario_decorator(
scenario_wrapper.__doc__ = f"{feature_name}: {scenario_name}"
scenario_wrapper.__scenario__ = templated_scenario
return cast(Callable, scenario_wrapper)
return cast(Callable[P, T], scenario_wrapper)
return decorator
@ -251,8 +252,11 @@ def collect_example_parametrizations(
def scenario(
feature_name: str, scenario_name: str, encoding: str = "utf-8", features_base_dir=None
) -> Callable[[Callable], Callable]:
feature_name: str,
scenario_name: str,
encoding: str = "utf-8",
features_base_dir: str | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Scenario decorator.
:param str feature_name: Feature file name. Absolute or relative to the configured feature base path.

View File

@ -43,13 +43,15 @@ from typing import Any, Callable, Iterable, Literal, TypeVar
import pytest
from _pytest.fixtures import FixtureDef, FixtureRequest
from typing_extensions import ParamSpec
from .parser import Step
from .parsers import StepParser, get_parser
from .types import GIVEN, THEN, WHEN
from .utils import get_caller_module_locals
TCallable = TypeVar("TCallable", bound=Callable[..., Any])
P = ParamSpec("P")
T = TypeVar("T")
@enum.unique
@ -63,7 +65,7 @@ class StepFunctionContext:
type: Literal["given", "when", "then"] | None
step_func: Callable[..., Any]
parser: StepParser
converters: dict[str, Callable[..., Any]] = field(default_factory=dict)
converters: dict[str, Callable[[str], Any]] = field(default_factory=dict)
target_fixture: str | None = None
@ -74,10 +76,10 @@ def get_step_fixture_name(step: Step) -> str:
def given(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[str], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Given step decorator.
:param name: Step name or a parser object.
@ -93,10 +95,10 @@ def given(
def when(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[str], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""When step decorator.
:param name: Step name or a parser object.
@ -112,10 +114,10 @@ def when(
def then(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[str], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Then step decorator.
:param name: Step name or a parser object.
@ -132,10 +134,10 @@ def then(
def step(
name: str | StepParser,
type_: Literal["given", "when", "then"] | None = None,
converters: dict[str, Callable] | None = None,
converters: dict[str, Callable[[str], Any]] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable[[TCallable], TCallable]:
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Generic step decorator.
:param name: Step name as in the feature file.
@ -155,7 +157,7 @@ def step(
if converters is None:
converters = {}
def decorator(func: TCallable) -> TCallable:
def decorator(func: Callable[P, T]) -> Callable[P, T]:
parser = get_parser(name)
context = StepFunctionContext(

View File

@ -19,7 +19,7 @@ T = TypeVar("T")
CONFIG_STACK: list[Config] = []
def get_args(func: Callable) -> list[str]:
def get_args(func: Callable[..., Any]) -> list[str]:
"""Get a list of argument names for a function.
:param func: The function to inspect.