91 lines
2.4 KiB
Python
91 lines
2.4 KiB
Python
import os
|
|
import tempfile
|
|
import unittest
|
|
from distutils.util import strtobool
|
|
|
|
from transformers.file_utils import _tf_available, _torch_available
|
|
|
|
|
|
# Scratch cache directory for tests: "<system temp dir>/transformers_test".
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")

# Identifier of a small dummy model used by tests (presumably a model-hub
# repo id -- confirm against callers).
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"

# Used to test Auto{Config, Model, Tokenizer} model_type detection.
DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
# Legacy, misspelled name kept as an alias so existing imports keep working.
DUMMY_UNKWOWN_IDENTIFIER = DUMMY_UNKNOWN_IDENTIFIER
|
|
|
|
|
|
def _strtobool(value):
    """Convert a truth-value string to ``True`` or ``False``.

    Accepts the same spellings as ``distutils.util.strtobool`` (case-insensitive
    "y", "yes", "t", "true", "on", "1" and "n", "no", "f", "false", "off", "0"),
    but returns a real ``bool`` and does not depend on the deprecated
    ``distutils`` package.

    Raises:
        ValueError: if ``value`` is not one of the recognised spellings.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value {!r}".format(value))


def parse_flag_from_env(key, default=False):
    """Read a yes/no flag from the environment variable ``key``.

    Args:
        key: name of the environment variable to read.
        default: value returned unchanged when ``key`` is not set.

    Returns:
        ``True`` or ``False`` parsed from the variable when it is set,
        otherwise ``default``.

    Raises:
        ValueError: if the variable is set to a value that is not a
            recognised truth value (e.g. "yes", "no", "true", "0").
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default
    # KEY is set, convert it to True or False.
    try:
        return _strtobool(value)
    except ValueError as err:
        # More values are supported, but let's keep the message simple.
        raise ValueError("If set, {} must be yes or no.".format(key)) from err
|
|
|
|
|
|
# Opt-in switches, evaluated once when this module is imported.
_run_slow_tests = parse_flag_from_env("RUN_SLOW", False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", False)
|
|
|
|
|
|
def slow(test_case):
    """
    Mark ``test_case`` as a slow test.

    The test is wrapped with ``unittest.skip`` unless the module-level
    ``_run_slow_tests`` flag (taken from the RUN_SLOW environment variable)
    is truthy; otherwise it is returned unchanged.
    """
    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
|
|
|
|
|
def custom_tokenizers(test_case):
    """
    Mark ``test_case`` as exercising a custom tokenizer.

    Such tests are wrapped with ``unittest.skip`` unless the module-level
    ``_run_custom_tokenizers`` flag (taken from the RUN_CUSTOM_TOKENIZERS
    environment variable) is truthy; otherwise it is returned unchanged.
    """
    return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
|
|
|
|
|
|
def require_torch(test_case):
    """
    Mark ``test_case`` as needing PyTorch.

    When ``_torch_available`` is falsy the test is wrapped with
    ``unittest.skip``; otherwise it is returned unchanged.
    """
    return unittest.skipUnless(_torch_available, "test requires PyTorch")(test_case)
|
|
|
|
|
|
def require_tf(test_case):
    """
    Mark ``test_case`` as needing TensorFlow.

    When ``_tf_available`` is falsy the test is wrapped with
    ``unittest.skip``; otherwise it is returned unchanged.
    """
    return unittest.skipUnless(_tf_available, "test requires TensorFlow")(test_case)
|
|
|
|
|
|
# Device string for PyTorch tests; stays None when PyTorch is unavailable.
torch_device = None
if _torch_available:
    torch_device = "cpu"
    # Set the USE_CUDA environment variable to a truthy value to select a GPU.
    if parse_flag_from_env("USE_CUDA"):
        torch_device = "cuda"
|