From 1e402b957d96597e5e47c06da5671ccec09621cc Mon Sep 17 00:00:00 2001 From: Titus <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 15 Feb 2024 21:53:09 -0300 Subject: [PATCH] add test marker to run all tests with @require_bitsandbytes (#28278) --- pyproject.toml | 1 + src/transformers/testing_utils.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a7e1720022..d66b89769c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,4 +32,5 @@ doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS" doctest_glob="**/*.md" markers = [ "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')", + "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests", ] \ No newline at end of file diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 0ff7e718af..50e178fbea 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -966,9 +966,17 @@ def require_aqlm(test_case): def require_bitsandbytes(test_case): """ - Decorator for bits and bytes (bnb) dependency + Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed. """ - return unittest.skipUnless(is_bitsandbytes_available(), "test requires bnb")(test_case) + if is_bitsandbytes_available() and is_torch_available(): + try: + import pytest + + return pytest.mark.bitsandbytes(test_case) + except ImportError: + return test_case + else: + return unittest.skip("test requires bitsandbytes and torch")(test_case) def require_optimum(test_case):