Use random_attention_mask for TF tests (#16517)
* use random_attention_mask for TF tests
* Fix for TFCLIP test (for now).

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 823dbf8a41
commit 2199382dfd
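Every hunk below makes the same substitution: the TF model testers previously built attention masks with `ids_tensor([batch_size, seq_length], vocab_size=2)`, which can produce a row of all zeros, i.e. a batch entry that attends to nothing; `random_attention_mask` guarantees at least one attended position per row. A minimal sketch of the difference, using a simplified stand-in for `ids_tensor` (the real helper draws values from a Python `random.Random`):

    import tensorflow as tf

    def ids_tensor(shape, vocab_size):
        # simplified stand-in: uniform random integer ids in [0, vocab_size)
        return tf.random.uniform(shape, minval=0, maxval=vocab_size, dtype=tf.int32)

    # old pattern: nothing prevents a row from coming out all zeros
    mask_old = ids_tensor([4, 8], vocab_size=2)

    # new pattern: force one position to 1 so every row attends to at least one token
    mask_new = tf.concat([mask_old[:, :-1], tf.ones_like(mask_old[:, -1:])], axis=-1)
    assert bool(tf.reduce_all(tf.reduce_sum(mask_new, axis=-1) >= 1))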
@@ -21,7 +21,7 @@ from transformers import is_tf_available, {{cookiecutter.camelcase_modelname}}Co
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -92,7 +92,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -21,7 +21,7 @@ from transformers.models.auto import get_values
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -96,7 +96,7 @@ class TFAlbertModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -21,7 +21,7 @@ from transformers.models.auto import get_values
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
 from ..utils.test_modeling_tf_core import TFCoreModelTesterMixin


@@ -96,7 +96,7 @@ class TFBertModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -301,6 +301,12 @@ class TFCLIPTextModelTester:
         input_mask = None
         if self.use_input_mask:
             input_mask = random_attention_mask([self.batch_size, self.seq_length])
+            # make sure the first token has attention mask `1` to ensure that, after combining the causal mask, there
+            # is still at least one token being attended to for each batch.
+            # TODO: Change `random_attention_mask` in PT/TF/Flax common test file, after a discussion with the team.
+            input_mask = tf.concat(
+                [tf.ones_like(input_mask[:, :1], dtype=input_mask.dtype), input_mask[:, 1:]], axis=-1
+            )

         config = self.get_config()

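The extra CLIP fix is needed because `TFCLIPTextModel` combines this padding mask with a causal mask: under causal attention, position 0 can only attend to position 0, while the updated `random_attention_mask` (see the `test_modeling_tf_common.py` hunk near the end of this diff) only guarantees that the last position is attended. If the sampled mask happens to be 0 at position 0, the first query row would attend to nothing. A small illustration of the repair (the mask values are made up):

    import tensorflow as tf

    pad_mask = tf.constant([[0, 1, 0, 1]], dtype=tf.int32)  # position 0 masked out
    # under a causal mask, row 0 may only attend to position 0,
    # so the test needs pad_mask[:, 0] == 1 for every batch entry
    pad_mask = tf.concat([tf.ones_like(pad_mask[:, :1], dtype=pad_mask.dtype), pad_mask[:, 1:]], axis=-1)
    print(pad_mask.numpy())  # [[1 1 0 1]] -> row 0 now attends to itself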
@@ -20,7 +20,7 @@ from transformers import ConvBertConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -94,7 +94,7 @@ class TFConvBertModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers import CTRLConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -69,7 +69,7 @@ class TFCTRLModelTester(object):

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers import DebertaConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -92,7 +92,7 @@ class TFDebertaModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers import DebertaV2Config, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -95,7 +95,7 @@ class TFDebertaV2ModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers import DistilBertConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -70,7 +70,7 @@ class TFDistilBertModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         sequence_labels = None
         token_labels = None

@@ -19,7 +19,7 @@ from transformers import is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -94,9 +94,8 @@ class TFDPRModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor(
-                [self.batch_size, self.seq_length], vocab_size=2
-            )  # follow test_modeling_tf_ctrl.py
+            # follow test_modeling_tf_ctrl.py
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers import ElectraConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -71,7 +71,7 @@ class TFElectraModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -19,7 +19,7 @@ from transformers import is_tf_available
 from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -75,7 +75,7 @@ class TFFlaubertModelTester:

     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-        input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32)
+        input_mask = random_attention_mask([self.batch_size, self.seq_length], dtype=tf.float32)

         input_lengths = None
         if self.use_input_lengths:

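Note that the Flaubert tester (like XLM and XLNet further down) builds the mask as a float tensor, so these call sites pass `dtype=tf.float32`; `random_attention_mask` forwards `dtype` to `ids_tensor`, so only the generation changes, not the dtype. A quick sketch, assuming the helper is importable from the tests' common module (the `tests.` import path is an assumption about the repository layout):

    import tensorflow as tf
    from tests.test_modeling_tf_common import random_attention_mask  # assumed import path

    float_mask = random_attention_mask([2, 5], dtype=tf.float32)
    assert float_mask.dtype == tf.float32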
@@ -20,7 +20,7 @@ from transformers import FunnelConfig, is_tf_available
 from transformers.testing_utils import require_tf

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -111,7 +111,7 @@ class TFFunnelModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -19,7 +19,7 @@ from transformers import GPT2Config, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
 from ..utils.test_modeling_tf_core import TFCoreModelTesterMixin


@@ -74,7 +74,7 @@ class TFGPT2ModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers import AutoTokenizer, GPTJConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow, tooslow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
 from ..utils.test_modeling_tf_core import TFCoreModelTesterMixin


@@ -70,7 +70,7 @@ class TFGPTJModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -21,7 +21,7 @@ from transformers import LayoutLMConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -107,7 +107,7 @@ class TFLayoutLMModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers import is_tf_available
 from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -79,7 +79,7 @@ class TFLongformerModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -23,7 +23,7 @@ from transformers import LxmertConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -124,7 +124,7 @@ class TFLxmertModelTester(object):

         input_mask = None
         if self.use_lang_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])
         token_type_ids = None
         if self.use_token_type_ids:
             token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

@@ -20,7 +20,7 @@ from transformers import MobileBertConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -114,7 +114,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):

            input_mask = None
            if self.use_input_mask:
-               input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+               input_mask = random_attention_mask([self.batch_size, self.seq_length])

            token_type_ids = None
            if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers import MPNetConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -90,7 +90,7 @@ class TFMPNetModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         sequence_labels = None
         token_labels = None

@@ -20,7 +20,7 @@ from transformers import OpenAIGPTConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -70,7 +70,7 @@ class TFOpenAIGPTModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers import RemBertConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -95,7 +95,7 @@ class TFRemBertModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers import RobertaConfig, is_tf_available
 from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -72,7 +72,7 @@ class TFRobertaModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers import RoFormerConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -95,7 +95,7 @@ class TFRoFormerModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = None
         if self.use_token_type_ids:

@@ -20,7 +20,7 @@ from transformers.testing_utils import require_sentencepiece, require_tf, requir
 from transformers.utils import cached_property

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -58,7 +58,7 @@ class TFT5ModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_labels = None
         if self.use_labels:

@@ -38,7 +38,7 @@ from transformers.testing_utils import require_tensorflow_probability, require_t
 from transformers.utils import cached_property

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -158,7 +158,7 @@ class TFTapasModelTester:

         input_mask = None
         if self.use_input_mask:
-            input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+            input_mask = random_attention_mask([self.batch_size, self.seq_length])

         token_type_ids = []
         for type_vocab_size in self.type_vocab_sizes:

@@ -1440,7 +1440,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
 def random_attention_mask(shape, rng=None, name=None, dtype=None):
     attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None, dtype=dtype)
     # make sure that at least one token is attended to for each batch
-    attn_mask = tf.concat([tf.constant(value=1, shape=(shape[0], 1), dtype=dtype), attn_mask[:, 1:]], axis=1)
+    attn_mask = tf.concat([attn_mask[:, :-1], tf.ones_like(attn_mask[:, -1:], dtype=dtype)], axis=-1)
     return attn_mask


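This hunk is the behavioral core of the commit: the helper used to force the first column of the mask to 1 with an explicitly shaped `tf.constant`; it now forces the last column with `tf.ones_like`, which picks up the mask's shape and dtype automatically. A quick sanity check of the new guarantee (a sketch; as above, the `tests.` import path is an assumption):

    import tensorflow as tf
    from tests.test_modeling_tf_common import random_attention_mask  # assumed import path

    mask = random_attention_mask([3, 7])
    # the last column is forced to 1, so every batch row attends to at least one token
    assert bool(tf.reduce_all(mask[:, -1] == 1))
    assert bool(tf.reduce_all(tf.reduce_sum(mask, axis=-1) >= 1))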
@@ -20,7 +20,7 @@ from transformers import is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -75,7 +75,7 @@ class TFXLMModelTester:

     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-        input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32)
+        input_mask = random_attention_mask([self.batch_size, self.seq_length], dtype=tf.float32)

         input_lengths = None
         if self.use_input_lengths:

@@ -22,7 +22,7 @@ from transformers import XLNetConfig, is_tf_available
 from transformers.testing_utils import require_tf, slow

 from ..test_configuration_common import ConfigTester
-from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
+from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


 if is_tf_available():
@@ -75,7 +75,7 @@ class TFXLNetModelTester:
         input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
         input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
         segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-        input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32)
+        input_mask = random_attention_mask([self.batch_size, self.seq_length], dtype=tf.float32)

         input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
         perm_mask = tf.zeros((self.batch_size, self.seq_length + 1, self.seq_length), dtype=tf.float32)