Fix tf.concatenate + test past_key_values for TF models (#15774)

* fix wrong method name tf.concatenate

* add tests related to causal LM / decoder

* make style and quality

* clean-up

* Fix TFBertModel's extended_attention_mask when past_key_values is provided

* Fix tests

* fix copies

* More tf.int8 -> tf.int32 in TF test template

* clean-up

* Update TF test template

* revert the previous commit + update the TF test template

* Fix TF template extended_attention_mask when past_key_values is provided

* Fix some styles manually

* clean-up

* Fix ValueError: too many values to unpack in the test

* Fix more: too many values to unpack in the test

* Add a comment for extended_attention_mask when there is past_key_values

* Fix TFElectra extended_attention_mask when past_key_values is provided

* Add tests to other TF models

* Fix for TF Electra test: add prepare_config_and_inputs_for_decoder

* Fix not passing training arg to lm_head in TFRobertaForCausalLM

* Fix tests (with past) for TF Roberta

* add testing for pask_key_values for TFElectra model

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-02-25 17:11:46 +01:00 committed by GitHub
parent 4818bf7aed
commit 8635407bc7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1851 additions and 84 deletions

View File

@ -274,8 +274,8 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
else:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
@ -817,6 +817,9 @@ class TFBertMainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
)
if inputs["past_key_values"][0] is not None:
# attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else:
extended_attention_mask = tf.reshape(
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])

View File

@ -140,8 +140,8 @@ class TFElectraSelfAttention(tf.keras.layers.Layer):
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
else:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
@ -673,6 +673,8 @@ class TFElectraMainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
)
if past_key_values_length > 0:
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else:
extended_attention_mask = tf.reshape(
attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])

View File

@ -252,8 +252,8 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
else:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)

View File

@ -212,8 +212,8 @@ class TFRemBertSelfAttention(tf.keras.layers.Layer):
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
else:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
@ -740,6 +740,9 @@ class TFRemBertMainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
)
if inputs["past_key_values"][0] is not None:
# attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else:
extended_attention_mask = tf.reshape(
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])

View File

@ -261,8 +261,8 @@ class TFRobertaSelfAttention(tf.keras.layers.Layer):
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
else:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
@ -704,6 +704,9 @@ class TFRobertaMainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
)
if inputs["past_key_values"][0] is not None:
# attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else:
extended_attention_mask = tf.reshape(
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
@ -1305,7 +1308,7 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
)
sequence_output = outputs[0]
logits = self.lm_head(hidden_states=sequence_output)
logits = self.lm_head(hidden_states=sequence_output, training=inputs["training"])
loss = None
if inputs["labels"] is not None:

View File

@ -317,8 +317,8 @@ class TFTapasSelfAttention(tf.keras.layers.Layer):
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
else:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)

View File

@ -32,7 +32,6 @@ from ...file_utils import (
)
from ...modeling_tf_outputs import (
TFBaseModelOutputWithPastAndCrossAttentions,
TFBaseModelOutputWithPoolingAndCrossAttentions,
TFCausalLMOutputWithCrossAttentions,
TFMaskedLMOutput,
TFMultipleChoiceModelOutput,
@ -216,8 +215,8 @@ class TF{{cookiecutter.camelcase_modelname}}SelfAttention(tf.keras.layers.Layer)
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
key_layer = tf.concatenate([past_key_value[0], key_layer], dim=2)
value_layer = tf.concatenate([past_key_value[1], value_layer], dim=2)
key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
else:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
@ -654,7 +653,7 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
inputs = input_processing(
func=self.call,
config=self.config,
@ -734,6 +733,9 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
extended_attention_mask = tf.reshape(
extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
)
if inputs["past_key_values"][0] is not None:
# attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else:
extended_attention_mask = tf.reshape(
inputs["attention_mask"], (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
@ -945,7 +947,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
output_type=TFBaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def call(
@ -965,7 +967,7 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
r"""
encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@ -1024,10 +1026,9 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(
self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
self, output: TFBaseModelOutputWithPastAndCrossAttentions
) -> TFBaseModelOutputWithPastAndCrossAttentions:
output_cache = self.config.use_cache and self.config.is_decoder
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
@ -1036,9 +1037,8 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
if not (self.config.output_attentions and self.config.add_cross_attention):
cross_attns = None
return TFBaseModelOutputWithPoolingAndCrossAttentions(
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
past_key_values=pkv,
hidden_states=hs,
attentions=attns,

View File

@ -163,10 +163,59 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_lm_head(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_causal_lm_base_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.is_decoder = True
model = TF{{cookiecutter.camelcase_modelname}}Model(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TF{{cookiecutter.camelcase_modelname}}Model(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
# Also check the case where encoder outputs are not passed
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_causal_lm_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.is_decoder = True
model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config)
inputs = {
"input_ids": input_ids,
@ -178,6 +227,260 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_causal_lm_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
prediction_scores = result["logits"]
self.parent.assertListEqual(
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_causal_lm_model_past(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config)
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and attn_mask
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids, output_hidden_states=True).hidden_states[0]
output_from_past = model(
next_tokens, past_key_values=past_key_values, output_hidden_states=True
).hidden_states[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_causal_lm_model_past_with_attn_mask(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config)
# create attention mask
half_seq_length = self.seq_length // 2
attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
# first forward pass
outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
past_key_values = outputs.past_key_values
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
condition = tf.transpose(
tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size))
)
input_ids = tf.where(condition, random_other_next_tokens, input_ids)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
attn_mask = tf.concat(
[attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)],
axis=1,
)
output_from_no_past = model(
next_input_ids,
attention_mask=attn_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_hidden_states=True
).hidden_states[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_causal_lm_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config)
input_ids = input_ids[:1, :]
input_mask = input_mask[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
).hidden_states[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TF{{cookiecutter.camelcase_modelname}}ForCausalLM(config=config)
input_ids = input_ids[:1, :]
input_mask = input_mask[:1, :]
encoder_hidden_states = encoder_hidden_states[:1, :, :]
encoder_attention_mask = encoder_attention_mask[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
).hidden_states[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@ -290,16 +593,59 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTest(TFModelTesterMixin, unitte
self.config_tester.run_common_tests()
def test_model(self):
"""Test the base model"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_causal_lm_base_model(self):
"""Test the base model of the causal LM model
is_deocder=True, no cross_attention, no encoder outputs
"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_base_model(*config_and_inputs)
def test_model_as_decoder(self):
"""Test the base model as a decoder (of an encoder-decoder architecture)
is_deocder=True + cross_attention + pass encoder outputs
"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_causal_lm(self):
"""Test the causal LM model"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_lm_head(*config_and_inputs)
self.model_tester.create_and_check_causal_lm_model(*config_and_inputs)
def test_causal_lm_model_as_decoder(self):
"""Test the causal LM model as a decoder"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_causal_lm_model_as_decoder(*config_and_inputs)
def test_causal_lm_model_past(self):
"""Test causal LM model with `past_key_values`"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past(*config_and_inputs)
def test_causal_lm_model_past_with_attn_mask(self):
"""Test the causal LM model with `past_key_values` and `attention_mask`"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past_with_attn_mask(*config_and_inputs)
def test_causal_lm_model_past_with_large_inputs(self):
"""Test the causal LM model with `past_key_values` and a longer decoder sequence length"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past_large_inputs(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
"""Similar to `test_causal_lm_model_past_with_large_inputs` but with cross-attention"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
@ -460,7 +806,7 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
@ -488,9 +834,9 @@ def prepare_{{cookiecutter.lowercase_modelname}}_inputs_dict(
decoder_attention_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int32)
if decoder_attention_mask is None:
decoder_attention_mask = tf.concat([tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8), tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8)], axis=-1)
decoder_attention_mask = tf.concat([tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int32), tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int32)], axis=-1)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,

View File

@ -153,12 +153,12 @@ class TFBertModelTester:
encoder_attention_mask,
)
def create_and_check_bert_model(
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFBertModel(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
sequence_output, pooled_output = model(inputs)
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs)
@ -168,10 +168,61 @@ class TFBertModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_bert_lm_head(
def create_and_check_causal_lm_base_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.is_decoder = True
model = TFBertModel(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TFBertModel(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
# Also check the case where encoder outputs are not passed
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_causal_lm_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.is_decoder = True
model = TFBertLMHeadModel(config=config)
inputs = {
"input_ids": input_ids,
@ -183,7 +234,261 @@ class TFBertModelTester:
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_bert_for_masked_lm(
def create_and_check_causal_lm_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TFBertLMHeadModel(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
prediction_scores = result["logits"]
self.parent.assertListEqual(
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_causal_lm_model_past(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFBertLMHeadModel(config=config)
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and attn_mask
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids, output_hidden_states=True).hidden_states[0]
output_from_past = model(
next_tokens, past_key_values=past_key_values, output_hidden_states=True
).hidden_states[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_causal_lm_model_past_with_attn_mask(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFBertLMHeadModel(config=config)
# create attention mask
half_seq_length = self.seq_length // 2
attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
# first forward pass
outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
past_key_values = outputs.past_key_values
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
condition = tf.transpose(
tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size))
)
input_ids = tf.where(condition, random_other_next_tokens, input_ids)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
attn_mask = tf.concat(
[attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)],
axis=1,
)
output_from_no_past = model(
next_input_ids,
attention_mask=attn_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_hidden_states=True
).hidden_states[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_causal_lm_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFBertLMHeadModel(config=config)
input_ids = input_ids[:1, :]
input_mask = input_mask[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
).hidden_states[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TFBertLMHeadModel(config=config)
input_ids = input_ids[:1, :]
input_mask = input_mask[:1, :]
encoder_hidden_states = encoder_hidden_states[:1, :, :]
encoder_attention_mask = encoder_attention_mask[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
).hidden_states[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFBertForMaskedLM(config=config)
@ -195,7 +500,7 @@ class TFBertModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_bert_for_next_sequence_prediction(
def create_and_check_for_next_sequence_prediction(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFBertForNextSentencePrediction(config=config)
@ -203,7 +508,7 @@ class TFBertModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))
def create_and_check_bert_for_pretraining(
def create_and_check_for_pretraining(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFBertForPreTraining(config=config)
@ -212,7 +517,7 @@ class TFBertModelTester:
self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))
def create_and_check_bert_for_sequence_classification(
def create_and_check_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
@ -226,7 +531,7 @@ class TFBertModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_bert_for_multiple_choice(
def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
@ -242,7 +547,7 @@ class TFBertModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def create_and_check_bert_for_token_classification(
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
@ -255,7 +560,7 @@ class TFBertModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_bert_for_question_answering(
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFBertForQuestionAnswering(config=config)
@ -323,41 +628,84 @@ class TFBertModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
def test_config(self):
self.config_tester.run_common_tests()
def test_bert_model(self):
def test_model(self):
"""Test the base model"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_causal_lm_base_model(self):
"""Test the base model of the causal LM model
is_deocder=True, no cross_attention, no encoder outputs
"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_base_model(*config_and_inputs)
def test_model_as_decoder(self):
"""Test the base model as a decoder (of an encoder-decoder architecture)
is_deocder=True + cross_attention + pass encoder outputs
"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_causal_lm(self):
"""Test the causal LM model"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_lm_head(*config_and_inputs)
self.model_tester.create_and_check_causal_lm_model(*config_and_inputs)
def test_causal_lm_model_as_decoder(self):
"""Test the causal LM model as a decoder"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_causal_lm_model_as_decoder(*config_and_inputs)
def test_causal_lm_model_past(self):
"""Test causal LM model with `past_key_values`"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past(*config_and_inputs)
def test_causal_lm_model_past_with_attn_mask(self):
"""Test the causal LM model with `past_key_values` and `attention_mask`"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past_with_attn_mask(*config_and_inputs)
def test_causal_lm_model_past_with_large_inputs(self):
"""Test the causal LM model with `past_key_values` and a longer decoder sequence length"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past_large_inputs(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
"""Similar to `test_causal_lm_model_past_with_large_inputs` but with cross-attention"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
def test_for_next_sequence_prediction(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
self.model_tester.create_and_check_for_next_sequence_prediction(*config_and_inputs)
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_model_from_pretrained(self):
model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random")

View File

@ -20,7 +20,7 @@ from transformers import ElectraConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from ..test_configuration_common import ConfigTester
from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
@ -101,7 +101,34 @@ class TFElectraModelTester:
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_electra_model(
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
)
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFElectraModel(config=config)
@ -115,7 +142,277 @@ class TFElectraModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_electra_for_masked_lm(
def create_and_check_causal_lm_base_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.is_decoder = True
model = TFElectraModel(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TFElectraModel(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
# Also check the case where encoder outputs are not passed
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_causal_lm_base_model_past(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFElectraModel(config=config)
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and attn_mask
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids, output_hidden_states=True).hidden_states[0]
output_from_past = model(
next_tokens, past_key_values=past_key_values, output_hidden_states=True
).hidden_states[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_causal_lm_base_model_past_with_attn_mask(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFElectraModel(config=config)
# create attention mask
half_seq_length = self.seq_length // 2
attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
# first forward pass
outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
past_key_values = outputs.past_key_values
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
condition = tf.transpose(
tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size))
)
input_ids = tf.where(condition, random_other_next_tokens, input_ids)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
attn_mask = tf.concat(
[attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)],
axis=1,
)
output_from_no_past = model(
next_input_ids,
attention_mask=attn_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_hidden_states=True
).hidden_states[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_causal_lm_base_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFElectraModel(config=config)
input_ids = input_ids[:1, :]
input_mask = input_mask[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
).hidden_states[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TFElectraModel(config=config)
input_ids = input_ids[:1, :]
input_mask = input_mask[:1, :]
encoder_hidden_states = encoder_hidden_states[:1, :, :]
encoder_attention_mask = encoder_attention_mask[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
).hidden_states[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFElectraForMaskedLM(config=config)
@ -123,7 +420,7 @@ class TFElectraModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_electra_for_pretraining(
def create_and_check_for_pretraining(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFElectraForPreTraining(config=config)
@ -131,7 +428,7 @@ class TFElectraModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length))
def create_and_check_electra_for_sequence_classification(
def create_and_check_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
@ -140,7 +437,7 @@ class TFElectraModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_electra_for_multiple_choice(
def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
@ -156,7 +453,7 @@ class TFElectraModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def create_and_check_electra_for_question_answering(
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFElectraForQuestionAnswering(config=config)
@ -165,7 +462,7 @@ class TFElectraModelTester:
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_electra_for_token_classification(
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
@ -215,33 +512,70 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
def test_electra_model(self):
def test_model(self):
"""Test the base model"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_causal_lm_base_model(self):
"""Test the base model of the causal LM model
is_deocder=True, no cross_attention, no encoder outputs
"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_base_model(*config_and_inputs)
def test_model_as_decoder(self):
"""Test the base model as a decoder (of an encoder-decoder architecture)
is_deocder=True + cross_attention + pass encoder outputs
"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
def test_causal_lm_base_model_past(self):
"""Test causal LM base model with `past_key_values`"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_base_model_past(*config_and_inputs)
def test_causal_lm_base_model_past_with_attn_mask(self):
"""Test the causal LM base model with `past_key_values` and `attention_mask`"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_base_model_past_with_attn_mask(*config_and_inputs)
def test_causal_lm_base_model_past_with_large_inputs(self):
"""Test the causal LM base model with `past_key_values` and a longer decoder sequence length"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_base_model_past_large_inputs(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
"""Similar to `test_causal_lm_base_model_past_with_large_inputs` but with cross-attention"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_pretraining(*config_and_inputs)
self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_sequence_classification(*config_and_inputs)
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_multiple_choice(*config_and_inputs)
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_token_classification(*config_and_inputs)
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):

View File

@ -168,7 +168,55 @@ class TFRemBertModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_lm_head(
def create_and_check_causal_lm_base_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.is_decoder = True
model = TFRemBertModel(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TFRemBertModel(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
# Also check the case where encoder outputs are not passed
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_causal_lm_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.is_decoder = True
@ -183,6 +231,260 @@ class TFRemBertModelTester:
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_causal_lm_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TFRemBertForCausalLM(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
prediction_scores = result["logits"]
self.parent.assertListEqual(
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_causal_lm_model_past(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFRemBertForCausalLM(config=config)
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and attn_mask
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids, output_hidden_states=True).hidden_states[0]
output_from_past = model(
next_tokens, past_key_values=past_key_values, output_hidden_states=True
).hidden_states[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_causal_lm_model_past_with_attn_mask(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFRemBertForCausalLM(config=config)
# create attention mask
half_seq_length = self.seq_length // 2
attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
# first forward pass
outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
past_key_values = outputs.past_key_values
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
condition = tf.transpose(
tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size))
)
input_ids = tf.where(condition, random_other_next_tokens, input_ids)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
attn_mask = tf.concat(
[attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)],
axis=1,
)
output_from_no_past = model(
next_input_ids,
attention_mask=attn_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_hidden_states=True
).hidden_states[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_causal_lm_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFRemBertForCausalLM(config=config)
input_ids = input_ids[:1, :]
input_mask = input_mask[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
).hidden_states[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TFRemBertForCausalLM(config=config)
input_ids = input_ids[:1, :]
input_mask = input_mask[:1, :]
encoder_hidden_states = encoder_hidden_states[:1, :, :]
encoder_attention_mask = encoder_attention_mask[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
).hidden_states[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@ -295,16 +597,59 @@ class TFRemBertModelTest(TFModelTesterMixin, unittest.TestCase):
self.config_tester.run_common_tests()
def test_model(self):
"""Test the base model"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_causal_lm_base_model(self):
"""Test the base model of the causal LM model
is_deocder=True, no cross_attention, no encoder outputs
"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_base_model(*config_and_inputs)
def test_model_as_decoder(self):
"""Test the base model as a decoder (of an encoder-decoder architecture)
is_deocder=True + cross_attention + pass encoder outputs
"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_causal_lm(self):
"""Test the causal LM model"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_lm_head(*config_and_inputs)
self.model_tester.create_and_check_causal_lm_model(*config_and_inputs)
def test_causal_lm_model_as_decoder(self):
"""Test the causal LM model as a decoder"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_causal_lm_model_as_decoder(*config_and_inputs)
def test_causal_lm_model_past(self):
"""Test causal LM model with `past_key_values`"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past(*config_and_inputs)
def test_causal_lm_model_past_with_attn_mask(self):
"""Test the causal LM model with `past_key_values` and `attention_mask`"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past_with_attn_mask(*config_and_inputs)
def test_causal_lm_model_past_with_large_inputs(self):
"""Test the causal LM model with `past_key_values` and a longer decoder sequence length"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past_large_inputs(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
"""Similar to `test_causal_lm_model_past_with_large_inputs` but with cross-attention"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()

View File

@ -129,7 +129,7 @@ class TFRobertaModelTester:
encoder_attention_mask,
)
def create_and_check_roberta_model(
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFRobertaModel(config=config)
@ -143,21 +143,361 @@ class TFRobertaModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_roberta_for_causal_lm(
def create_and_check_causal_lm_base_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFRobertaForCausalLM(config=config)
result = model([input_ids, input_mask, token_type_ids])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
config.is_decoder = True
def create_and_check_roberta_for_masked_lm(
model = TFRobertaModel(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TFRobertaModel(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
# Also check the case where encoder outputs are not passed
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_causal_lm_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.is_decoder = True
model = TFRobertaForCausalLM(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
}
prediction_scores = model(inputs)["logits"]
self.parent.assertListEqual(
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_causal_lm_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TFRobertaForCausalLM(config=config)
inputs = {
"input_ids": input_ids,
"attention_mask": input_mask,
"token_type_ids": token_type_ids,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs, token_type_ids=token_type_ids, encoder_hidden_states=encoder_hidden_states)
prediction_scores = result["logits"]
self.parent.assertListEqual(
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_causal_lm_model_past(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFRobertaForCausalLM(config=config)
# special to `RobertaEmbeddings` in `Roberta`:
# - its `padding_idx` and its effect on `position_ids`
# (TFRobertaEmbeddings.create_position_ids_from_input_ids)
# - `1` here is `TFRobertaEmbeddings.padding_idx`
input_ids = tf.where(input_ids == 1, 2, input_ids)
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and attn_mask
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids, output_hidden_states=True).hidden_states[0]
output_from_past = model(
next_tokens, past_key_values=past_key_values, output_hidden_states=True
).hidden_states[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_causal_lm_model_past_with_attn_mask(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFRobertaForCausalLM(config=config)
# special to `RobertaEmbeddings` in `Roberta`:
# - its `padding_idx` and its effect on `position_ids`
# (TFRobertaEmbeddings.create_position_ids_from_input_ids)
# - `1` here is `TFRobertaEmbeddings.padding_idx`
# avoid `padding_idx` in the past
input_ids = tf.where(input_ids == 1, 2, input_ids)
# create attention mask
half_seq_length = self.seq_length // 2
attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
# first forward pass
outputs = model(input_ids, attention_mask=attn_mask, use_cache=True)
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
past_key_values = outputs.past_key_values
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
condition = tf.transpose(
tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size))
)
input_ids = tf.where(condition, random_other_next_tokens, input_ids)
# avoid `padding_idx` in the past
input_ids = tf.where(input_ids == 1, 2, input_ids)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
attn_mask = tf.concat(
[attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)],
axis=1,
)
output_from_no_past = model(
next_input_ids,
attention_mask=attn_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, output_hidden_states=True
).hidden_states[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_causal_lm_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
config.is_decoder = True
model = TFRobertaForCausalLM(config=config)
# special to `RobertaEmbeddings` in `Roberta`:
# - its `padding_idx` and its effect on `position_ids`
# (TFRobertaEmbeddings.create_position_ids_from_input_ids)
# - `1` here is `TFRobertaEmbeddings.padding_idx`
# avoid `padding_idx` in the past
input_ids = tf.where(input_ids == 1, 2, input_ids)
input_ids = input_ids[:1, :]
input_mask = input_mask[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(input_ids, attention_mask=input_mask, use_cache=True)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
).hidden_states[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = TFRobertaForCausalLM(config=config)
# special to `RobertaEmbeddings` in `Roberta`:
# - its `padding_idx` and its effect on `position_ids`
# (TFRobertaEmbeddings.create_position_ids_from_input_ids)
# - `1` here is `TFRobertaEmbeddings.padding_idx`
# avoid `padding_idx` in the past
input_ids = tf.where(input_ids == 1, 2, input_ids)
input_ids = input_ids[:1, :]
input_mask = input_mask[:1, :]
encoder_hidden_states = encoder_hidden_states[:1, :, :]
encoder_attention_mask = encoder_attention_mask[:1, :]
self.batch_size = 1
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
).hidden_states[0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
).hidden_states[0]
self.parent.assertEqual(next_tokens.shape[1], output_from_past.shape[1])
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
output_from_past_slice = output_from_past[:, :, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFRobertaForMaskedLM(config=config)
result = model([input_ids, input_mask, token_type_ids])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_roberta_for_token_classification(
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
@ -166,7 +506,7 @@ class TFRobertaModelTester:
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_roberta_for_question_answering(
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFRobertaForQuestionAnswering(config=config)
@ -175,7 +515,7 @@ class TFRobertaModelTester:
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_roberta_for_multiple_choice(
def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
@ -231,29 +571,72 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
def test_roberta_model(self):
def test_model(self):
"""Test the base model"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_causal_lm_base_model(self):
"""Test the base model of the causal LM model
is_deocder=True, no cross_attention, no encoder outputs
"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_base_model(*config_and_inputs)
def test_model_as_decoder(self):
"""Test the base model as a decoder (of an encoder-decoder architecture)
is_deocder=True + cross_attention + pass encoder outputs
"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_causal_lm(self):
"""Test the causal LM model"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_causal_lm(*config_and_inputs)
self.model_tester.create_and_check_causal_lm_model(*config_and_inputs)
def test_causal_lm_model_as_decoder(self):
"""Test the causal LM model as a decoder"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_causal_lm_model_as_decoder(*config_and_inputs)
def test_causal_lm_model_past(self):
"""Test causal LM model with `past_key_values`"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past(*config_and_inputs)
def test_causal_lm_model_past_with_attn_mask(self):
"""Test the causal LM model with `past_key_values` and `attention_mask`"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past_with_attn_mask(*config_and_inputs)
def test_causal_lm_model_past_with_large_inputs(self):
"""Test the causal LM model with `past_key_values` and a longer decoder sequence length"""
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causal_lm_model_past_large_inputs(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
"""Similar to `test_causal_lm_model_past_with_large_inputs` but with cross-attention"""
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_token_classification(*config_and_inputs)
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_multiple_choice(*config_and_inputs)
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
@slow
def test_model_from_pretrained(self):