Flax testing should not run the full torch test suite (#10725)
* make flax tests pytorch independent * fix typo * finish * improve circle ci * fix return tensors * correct flax test * re-add sentencepiece * last tokenizer fixes * finish maybe now
This commit is contained in:
parent
87d685b8a9
commit
9f8619c6aa
|
@ -91,6 +91,34 @@ jobs:
|
|||
- store_artifacts:
|
||||
path: ~/transformers/reports
|
||||
|
||||
run_tests_torch_and_flax:
|
||||
working_directory: ~/transformers
|
||||
docker:
|
||||
- image: circleci/python:3.6
|
||||
environment:
|
||||
OMP_NUM_THREADS: 1
|
||||
resource_class: xlarge
|
||||
parallelism: 1
|
||||
steps:
|
||||
- checkout
|
||||
- restore_cache:
|
||||
keys:
|
||||
- v0.4-torch_and_flax-{{ checksum "setup.py" }}
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech]
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
|
||||
- save_cache:
|
||||
key: v0.4-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
- '~/.cache/pip'
|
||||
- run: RUN_PT_FLAX_CROSS_TESTS=1 python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_torch_and_flax ./tests/ -m is_pt_flax_cross_test --durations=0 | tee tests_output.txt
|
||||
- store_artifacts:
|
||||
path: ~/transformers/tests_output.txt
|
||||
- store_artifacts:
|
||||
path: ~/transformers/reports
|
||||
|
||||
run_tests_torch:
|
||||
working_directory: ~/transformers
|
||||
docker:
|
||||
|
@ -159,9 +187,8 @@ jobs:
|
|||
keys:
|
||||
- v0.4-flax-{{ checksum "setup.py" }}
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
|
||||
- run: pip install --upgrade pip
|
||||
- run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece,speech]
|
||||
- run: sudo pip install .[flax,testing,sentencepiece]
|
||||
- save_cache:
|
||||
key: v0.4-flax-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
|
@ -418,6 +445,7 @@ workflows:
|
|||
- run_examples_torch
|
||||
- run_tests_custom_tokenizers
|
||||
- run_tests_torch_and_tf
|
||||
- run_tests_torch_and_flax
|
||||
- run_tests_torch
|
||||
- run_tests_tf
|
||||
- run_tests_flax
|
||||
|
|
2
setup.py
2
setup.py
|
@ -97,7 +97,7 @@ _deps = [
|
|||
"fastapi",
|
||||
"filelock",
|
||||
"flake8>=3.8.3",
|
||||
"flax>=0.2.2",
|
||||
"flax>=0.3.2",
|
||||
"fugashi>=1.0",
|
||||
"importlib_metadata",
|
||||
"ipadic>=1.0.0,<2.0",
|
||||
|
|
|
@ -10,7 +10,7 @@ deps = {
|
|||
"fastapi": "fastapi",
|
||||
"filelock": "filelock",
|
||||
"flake8": "flake8>=3.8.3",
|
||||
"flax": "flax>=0.2.2",
|
||||
"flax": "flax>=0.3.2",
|
||||
"fugashi": "fugashi>=1.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||
|
|
|
@ -80,6 +80,7 @@ def parse_int_from_env(key, default=None):
|
|||
|
||||
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
|
||||
_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=False)
|
||||
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=False)
|
||||
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
|
||||
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=False)
|
||||
_run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)
|
||||
|
@ -105,6 +106,25 @@ def is_pt_tf_cross_test(test_case):
|
|||
return pytest.mark.is_pt_tf_cross_test()(test_case)
|
||||
|
||||
|
||||
def is_pt_flax_cross_test(test_case):
|
||||
"""
|
||||
Decorator marking a test as a test that control interactions between PyTorch and Flax
|
||||
|
||||
PT+FLAX tests are skipped by default and we can run only them by setting RUN_PT_FLAX_CROSS_TESTS environment
|
||||
variable to a truthy value and selecting the is_pt_flax_cross_test pytest mark.
|
||||
|
||||
"""
|
||||
if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
|
||||
return unittest.skip("test is PT+FLAX test")(test_case)
|
||||
else:
|
||||
try:
|
||||
import pytest # We don't need a hard dependency on pytest in the main library
|
||||
except ImportError:
|
||||
return test_case
|
||||
else:
|
||||
return pytest.mark.is_pt_flax_cross_test()(test_case)
|
||||
|
||||
|
||||
def is_pipeline_test(test_case):
|
||||
"""
|
||||
Decorator marking a test as a pipeline test.
|
||||
|
|
|
@ -35,6 +35,9 @@ def pytest_configure(config):
|
|||
config.addinivalue_line(
|
||||
"markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
|
||||
)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
|
|
|
@ -19,7 +19,7 @@ import numpy as np
|
|||
|
||||
import transformers
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import require_flax, require_torch
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
|
@ -60,7 +60,6 @@ def random_attention_mask(shape, rng=None):
|
|||
return attn_mask
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxModelTesterMixin:
|
||||
model_tester = None
|
||||
all_model_classes = ()
|
||||
|
@ -69,7 +68,7 @@ class FlaxModelTesterMixin:
|
|||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_pytorch(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -104,6 +103,7 @@ class FlaxModelTesterMixin:
|
|||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
||||
|
||||
@require_flax
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -121,6 +121,7 @@ class FlaxModelTesterMixin:
|
|||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||
|
||||
@require_flax
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -143,6 +144,7 @@ class FlaxModelTesterMixin:
|
|||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
@require_flax
|
||||
def test_naming_convention(self):
|
||||
for model_class in self.all_model_classes:
|
||||
model_class_name = model_class.__name__
|
||||
|
|
|
@ -24,7 +24,13 @@ from collections import OrderedDict
|
|||
from itertools import takewhile
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, is_torch_available
|
||||
from transformers import (
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
is_pt_tf_cross_test,
|
||||
|
@ -2283,7 +2289,12 @@ class TokenizerTesterMixin:
|
|||
"{} ({}, {})".format(tokenizer.__class__.__name__, pretrained_name, tokenizer.__class__.__name__)
|
||||
):
|
||||
|
||||
returned_tensor = "pt" if is_torch_available() else "tf"
|
||||
if is_torch_available():
|
||||
returned_tensor = "pt"
|
||||
elif is_tf_available():
|
||||
returned_tensor = "tf"
|
||||
else:
|
||||
returned_tensor = "jax"
|
||||
|
||||
if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
|
||||
return
|
||||
|
|
|
@ -21,7 +21,7 @@ from pathlib import Path
|
|||
from shutil import copyfile
|
||||
|
||||
from transformers import BatchEncoding, MarianTokenizer
|
||||
from transformers.file_utils import is_sentencepiece_available, is_torch_available
|
||||
from transformers.file_utils import is_sentencepiece_available, is_tf_available, is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece
|
||||
|
||||
|
||||
|
@ -36,7 +36,13 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
|
|||
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
|
||||
zh_code = ">>zh<<"
|
||||
ORG_NAME = "Helsinki-NLP/"
|
||||
FRAMEWORK = "pt" if is_torch_available() else "tf"
|
||||
|
||||
if is_torch_available():
|
||||
FRAMEWORK = "pt"
|
||||
elif is_tf_available():
|
||||
FRAMEWORK = "tf"
|
||||
else:
|
||||
FRAMEWORK = "jax"
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import unittest
|
||||
|
||||
from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
|
||||
from transformers.file_utils import cached_property, is_torch_available
|
||||
from transformers.file_utils import cached_property, is_tf_available, is_torch_available
|
||||
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
@ -25,7 +25,12 @@ from .test_tokenization_common import TokenizerTesterMixin
|
|||
|
||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
|
||||
FRAMEWORK = "pt" if is_torch_available() else "tf"
|
||||
if is_torch_available():
|
||||
FRAMEWORK = "pt"
|
||||
elif is_tf_available():
|
||||
FRAMEWORK = "tf"
|
||||
else:
|
||||
FRAMEWORK = "jax"
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
|
@ -157,7 +162,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
|
||||
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
result = list(batch.input_ids.numpy()[0])
|
||||
|
||||
if FRAMEWORK != "jax":
|
||||
result = list(batch.input_ids.numpy()[0])
|
||||
else:
|
||||
result = list(batch.input_ids.tolist()[0])
|
||||
|
||||
self.assertListEqual(expected_src_tokens, result)
|
||||
|
||||
self.assertEqual((2, 9), batch.input_ids.shape)
|
||||
|
|
Loading…
Reference in New Issue