Flax testing should not run the full torch test suite (#10725)

* make flax tests pytorch independent

* fix typo

* finish

* improve circle ci

* fix return tensors

* correct flax test

* re-add sentencepiece

* last tokenizer fixes

* finish maybe now
This commit is contained in:
Patrick von Platen 2021-03-16 08:05:37 +03:00 committed by GitHub
parent 87d685b8a9
commit 9f8619c6aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 94 additions and 14 deletions

View File

@ -91,6 +91,34 @@ jobs:
- store_artifacts:
path: ~/transformers/reports
run_tests_torch_and_flax:
working_directory: ~/transformers
docker:
- image: circleci/python:3.6
environment:
OMP_NUM_THREADS: 1
resource_class: xlarge
parallelism: 1
steps:
- checkout
- restore_cache:
keys:
- v0.4-torch_and_flax-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
- '~/.cache/pip'
- run: RUN_PT_FLAX_CROSS_TESTS=1 python -m pytest -n 8 --dist=loadfile -rA -s --make-reports=tests_torch_and_flax ./tests/ -m is_pt_flax_cross_test --durations=0 | tee tests_output.txt
- store_artifacts:
path: ~/transformers/tests_output.txt
- store_artifacts:
path: ~/transformers/reports
run_tests_torch:
working_directory: ~/transformers
docker:
@ -159,9 +187,8 @@ jobs:
keys:
- v0.4-flax-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece,speech]
- run: sudo pip install .[flax,testing,sentencepiece]
- save_cache:
key: v0.4-flax-{{ checksum "setup.py" }}
paths:
@ -418,6 +445,7 @@ workflows:
- run_examples_torch
- run_tests_custom_tokenizers
- run_tests_torch_and_tf
- run_tests_torch_and_flax
- run_tests_torch
- run_tests_tf
- run_tests_flax

View File

@ -97,7 +97,7 @@ _deps = [
"fastapi",
"filelock",
"flake8>=3.8.3",
"flax>=0.2.2",
"flax>=0.3.2",
"fugashi>=1.0",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",

View File

@ -10,7 +10,7 @@ deps = {
"fastapi": "fastapi",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"flax": "flax>=0.2.2",
"flax": "flax>=0.3.2",
"fugashi": "fugashi>=1.0",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",

View File

@ -80,6 +80,7 @@ def parse_int_from_env(key, default=None):
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=False)
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=False)
_run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)
@ -105,6 +106,25 @@ def is_pt_tf_cross_test(test_case):
return pytest.mark.is_pt_tf_cross_test()(test_case)
def is_pt_flax_cross_test(test_case):
"""
Decorator marking a test as a test that control interactions between PyTorch and Flax
PT+FLAX tests are skipped by default and we can run only them by setting RUN_PT_FLAX_CROSS_TESTS environment
variable to a truthy value and selecting the is_pt_flax_cross_test pytest mark.
"""
if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
return unittest.skip("test is PT+FLAX test")(test_case)
else:
try:
import pytest # We don't need a hard dependency on pytest in the main library
except ImportError:
return test_case
else:
return pytest.mark.is_pt_flax_cross_test()(test_case)
def is_pipeline_test(test_case):
"""
Decorator marking a test as a pipeline test.

View File

@ -35,6 +35,9 @@ def pytest_configure(config):
config.addinivalue_line(
"markers", "is_pt_tf_cross_test: mark test to run only when PT and TF interactions are tested"
)
config.addinivalue_line(
"markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
)
def pytest_addoption(parser):

View File

@ -19,7 +19,7 @@ import numpy as np
import transformers
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import require_flax, require_torch
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
if is_flax_available():
@ -60,7 +60,6 @@ def random_attention_mask(shape, rng=None):
return attn_mask
@require_flax
class FlaxModelTesterMixin:
model_tester = None
all_model_classes = ()
@ -69,7 +68,7 @@ class FlaxModelTesterMixin:
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@require_torch
@is_pt_flax_cross_test
def test_equivalence_flax_pytorch(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -104,6 +103,7 @@ class FlaxModelTesterMixin:
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
@require_flax
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -121,6 +121,7 @@ class FlaxModelTesterMixin:
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 5e-3)
@require_flax
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -143,6 +144,7 @@ class FlaxModelTesterMixin:
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@require_flax
def test_naming_convention(self):
for model_class in self.all_model_classes:
model_class_name = model_class.__name__

View File

@ -24,7 +24,13 @@ from collections import OrderedDict
from itertools import takewhile
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast, is_torch_available
from transformers import (
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
is_tf_available,
is_torch_available,
)
from transformers.testing_utils import (
get_tests_dir,
is_pt_tf_cross_test,
@ -2283,7 +2289,12 @@ class TokenizerTesterMixin:
"{} ({}, {})".format(tokenizer.__class__.__name__, pretrained_name, tokenizer.__class__.__name__)
):
returned_tensor = "pt" if is_torch_available() else "tf"
if is_torch_available():
returned_tensor = "pt"
elif is_tf_available():
returned_tensor = "tf"
else:
returned_tensor = "jax"
if not tokenizer.pad_token or tokenizer.pad_token_id < 0:
return

View File

@ -21,7 +21,7 @@ from pathlib import Path
from shutil import copyfile
from transformers import BatchEncoding, MarianTokenizer
from transformers.file_utils import is_sentencepiece_available, is_torch_available
from transformers.file_utils import is_sentencepiece_available, is_tf_available, is_torch_available
from transformers.testing_utils import require_sentencepiece
@ -36,7 +36,13 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
zh_code = ">>zh<<"
ORG_NAME = "Helsinki-NLP/"
FRAMEWORK = "pt" if is_torch_available() else "tf"
if is_torch_available():
FRAMEWORK = "pt"
elif is_tf_available():
FRAMEWORK = "tf"
else:
FRAMEWORK = "jax"
@require_sentencepiece

View File

@ -17,7 +17,7 @@
import unittest
from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
from transformers.file_utils import cached_property, is_torch_available
from transformers.file_utils import cached_property, is_tf_available, is_torch_available
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
from .test_tokenization_common import TokenizerTesterMixin
@ -25,7 +25,12 @@ from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
FRAMEWORK = "pt" if is_torch_available() else "tf"
if is_torch_available():
FRAMEWORK = "pt"
elif is_tf_available():
FRAMEWORK = "tf"
else:
FRAMEWORK = "jax"
@require_sentencepiece
@ -157,7 +162,12 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, tokenizer.eos_token_id]
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
self.assertIsInstance(batch, BatchEncoding)
result = list(batch.input_ids.numpy()[0])
if FRAMEWORK != "jax":
result = list(batch.input_ids.numpy()[0])
else:
result = list(batch.input_ids.tolist()[0])
self.assertListEqual(expected_src_tokens, result)
self.assertEqual((2, 9), batch.input_ids.shape)