transformers/tests/test_modeling_tf_pegasus.py

375 lines
17 KiB
Python

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import AutoTokenizer, PegasusConfig, is_tf_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, TFPegasusForConditionalGeneration, TFPegasusModel
@require_tf
class TFPegasusModelTester:
    """Produce a small Pegasus configuration and matching inputs for the TF model tests.

    ``parent`` is the test case whose assertion helpers are used by the check
    methods. The remaining constructor arguments are stored as attributes and
    size the generated configuration and tensors.
    """

    config_cls = PegasusConfig
    config_updates = {}
    hidden_act = "gelu"

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=40,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
    ):
        self.parent = parent

        # Shapes of the generated inputs and test-mode flags.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels

        # Model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

        # Regularisation and positional limits.
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings

        # Special token ids.
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` built from the tester's attributes.

        The encoder ids are ``seq_length - 1`` random tokens followed by one
        EOS column; the decoder ids are ``seq_length`` random tokens.
        """
        random_prefix = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
        eos_column = tf.expand_dims(tf.constant([self.eos_token_id] * self.batch_size), 1)
        encoder_ids = tf.concat([random_prefix, eos_column], axis=1)
        decoder_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        # Encoder and decoder share the same depth, head count and FFN width.
        layer_count = self.num_hidden_layers
        head_count = self.num_attention_heads
        ffn_width = self.intermediate_size

        config = self.config_cls(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=layer_count,
            decoder_layers=layer_count,
            encoder_attention_heads=head_count,
            decoder_attention_heads=head_count,
            encoder_ffn_dim=ffn_width,
            decoder_ffn_dim=ffn_width,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            # NOTE(review): this literal matches the default ``eos_token_id`` (2) but
            # does not follow ``self.eos_token_id`` -- confirm whether it should.
            eos_token_ids=[2],
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.pad_token_id,
            **self.config_updates,
        )
        return config, prepare_pegasus_inputs_dict(config, encoder_ids, decoder_ids)

    def check_decoder_model_past_large_inputs(self, config, inputs_dict):
        """Check that decoding with cached key/values matches decoding from scratch.

        Runs the decoder on the first example only, appends three new tokens,
        and compares one randomly chosen feature of the new positions between
        the cached and the uncached forward pass.
        """
        decoder = TFPegasusModel(config=config).get_decoder()

        # Restrict the check to the first example of the batch.
        first_ids = inputs_dict["input_ids"][:1, :]
        first_mask = inputs_dict["attention_mask"][:1, :]
        head_mask = inputs_dict["head_mask"]
        self.batch_size = 1

        # First forward pass, asking the decoder to return its cache.
        cached_run = decoder(first_ids, attention_mask=first_mask, head_mask=head_mask, use_cache=True)
        _, past_key_values = cached_run.to_tuple()
        # NOTE(review): only element 1 of the returned cache is forwarded; what
        # element 0 holds is not visible here -- confirm against the decoder.
        past_key_values = past_key_values[1]

        # Create hypothetical next tokens and a random 0/1 mask for them.
        new_token_count = 3
        new_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        new_mask = tf.cast(ids_tensor((self.batch_size, 3), 2), tf.int8)

        # Extend the original ids and attention mask with the new positions.
        full_ids = tf.concat([first_ids, new_tokens], axis=-1)
        full_mask = tf.concat([first_mask, new_mask], axis=-1)

        uncached_output = decoder(full_ids, attention_mask=full_mask)[0]
        cached_output = decoder(new_tokens, attention_mask=full_mask, past_key_values=past_key_values)[0]

        self.parent.assertEqual(new_tokens.shape[1], cached_output.shape[1])

        # Compare a single, randomly selected index along the last axis.
        feature_idx = int(ids_tensor((1,), cached_output.shape[-1]))
        tf.debugging.assert_near(
            cached_output[:, :, feature_idx],
            uncached_output[:, -new_token_count:, feature_idx],
            rtol=1e-3,
        )
def prepare_pegasus_inputs_dict(
    config,
    input_ids,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    """Bundle model inputs into a dict, deriving every mask that was left as ``None``.

    Defaults:
        * ``attention_mask``: int8, 1 where ``input_ids`` differs from
          ``config.pad_token_id``.
        * ``decoder_attention_mask``: int8, a leading column of ones followed
          by the same non-pad test on the remaining decoder positions.
        * ``head_mask``: ones of shape ``(encoder_layers, encoder_attention_heads)``.
        * ``decoder_head_mask`` / ``cross_attn_head_mask``: ones of shape
          ``(decoder_layers, decoder_attention_heads)``.

    Returns:
        A dict keyed by the seven input names, in the order of the parameters.
    """

    def _non_pad_mask(token_ids):
        # 1 for every position that is not the padding token, as int8.
        return tf.cast(tf.math.not_equal(token_ids, config.pad_token_id), tf.int8)

    if attention_mask is None:
        attention_mask = _non_pad_mask(input_ids)

    if decoder_attention_mask is None:
        # The first decoder position is always attended to, whatever its id.
        leading_ones = tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8)
        remaining = _non_pad_mask(decoder_input_ids[:, 1:])
        decoder_attention_mask = tf.concat([leading_ones, remaining], axis=-1)

    if head_mask is None:
        head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))

    if decoder_head_mask is None:
        decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))

    if cross_attn_head_mask is None:
        cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))

    return dict(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=decoder_attention_mask,
        head_mask=head_mask,
        decoder_head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
    )
@require_tf
class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
    """Model tests for the TF Pegasus classes.

    Combines ``TFModelTesterMixin`` with the Pegasus-specific tests defined
    below. Checks use ``self.assert*`` rather than bare ``assert`` so they keep
    running under ``python -O`` and report useful failure messages.
    """

    all_model_classes = (TFPegasusForConditionalGeneration, TFPegasusModel) if is_tf_available() else ()
    all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFPegasusModelTester(self)
        self.config_tester = ConfigTester(self, config_class=PegasusConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_decoder_model_past_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_compile_tf_model(self):
        """Save, reload and compile the generative model inside a larger Keras model."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")

        model_class = self.all_generative_model_classes[0]
        symbolic_inputs = {
            "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
            "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
        }

        # Prepare our model
        model = model_class(config)
        model(self._prepare_for_class(inputs_dict, model_class))  # Model must be called before saving.

        # Let's load it from the disk to be sure we can use pretrained weights
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model = model_class.from_pretrained(tmpdirname)

        outputs_dict = model(symbolic_inputs)
        hidden_states = outputs_dict[0]

        # Add a dense layer on top to test integration with other keras modules
        outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)

        # Compile extended model
        extended_model = tf.keras.Model(inputs=[symbolic_inputs], outputs=[outputs])
        extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

    def test_model_common_attributes(self):
        """Check the embedding and bias accessors of every model class.

        Generative classes must expose a Keras layer as output embeddings and a
        dict of ``tf.Variable`` as bias; other classes must return ``None`` for both.
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), tf.keras.layers.Layer)

            if model_class in self.all_generative_model_classes:
                self.assertIsInstance(model.get_output_embeddings(), tf.keras.layers.Layer)
                bias = model.get_bias()
                self.assertIsInstance(bias, dict)
                for variable in bias.values():
                    self.assertIsInstance(variable, tf.Variable)
            else:
                self.assertIsNone(model.get_output_embeddings())
                self.assertIsNone(model.get_bias())

    def test_saved_model_creation(self):
        # Intentionally a no-op: this test is too long (>30 sec) and makes the CI fail.
        pass

    def test_resize_token_embeddings(self):
        """Resize the vocabulary (smaller, larger, unchanged) and verify the weights.

        For each size the input embeddings, the output embeddings (when present)
        and ``final_logits_bias`` (when present) must have the requested size
        and keep their pre-resize values on the entries both versions share.
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        def _get_word_embedding_weight(model, embedding_layer):
            if hasattr(embedding_layer, "weight"):
                return embedding_layer.weight
            # Build the word-embedding weights by running the model once on its
            # dummy inputs, then look for the attribute again.
            model(model.dummy_inputs)
            return getattr(embedding_layer, "weight", None)

        def _shared_entries_unchanged(old_values, new_values):
            # zip() stops at the shorter operand, so only entries present both
            # before and after the resize are compared.
            return all(
                not (tf.math.reduce_sum(tf.math.abs(old - new)) > 0)
                for old, new in zip(old_values, new_values)
            )

        for model_class in self.all_model_classes:
            for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
                # build the embeddings
                model = model_class(config=config)
                old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
                old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
                old_final_logits_bias = model.get_bias()

                # reshape the embeddings
                model.resize_token_embeddings(size)
                new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
                new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
                new_final_logits_bias = model.get_bias()

                # check that the resized embeddings size matches the desired size.
                assert_size = size if size is not None else config.vocab_size
                self.assertEqual(new_input_embeddings.shape[0], assert_size)

                # check that weights remain the same after resizing
                self.assertTrue(
                    _shared_entries_unchanged(old_input_embeddings.value(), new_input_embeddings.value())
                )

                if old_output_embeddings is not None and new_output_embeddings is not None:
                    self.assertEqual(new_output_embeddings.shape[0], assert_size)
                    self.assertTrue(
                        _shared_entries_unchanged(old_output_embeddings.value(), new_output_embeddings.value())
                    )

                if old_final_logits_bias is not None and new_final_logits_bias is not None:
                    old_final_logits_bias = old_final_logits_bias["final_logits_bias"]
                    new_final_logits_bias = new_final_logits_bias["final_logits_bias"]
                    self.assertEqual(new_final_logits_bias.shape[0], 1)
                    self.assertEqual(new_final_logits_bias.shape[1], assert_size)

                    # Compare element by element within each row so that rows of
                    # different lengths are only compared on their common prefix.
                    self.assertTrue(
                        all(
                            _shared_entries_unchanged(old_row, new_row)
                            for old_row, new_row in zip(
                                old_final_logits_bias.value(), new_final_logits_bias.value()
                            )
                        )
                    )
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
if a is None and b is None:
return True
try:
if tf.debugging.assert_near(a, b, atol=atol):
return True
raise
except Exception:
if len(prefix) > 0:
prefix = f"{prefix}: "
raise AssertionError(f"{prefix}{a} != {b}")
def _long_tensor(tok_lst):
    """Wrap ``tok_lst`` in a constant tensor of dtype ``tf.int32``."""
    int32_tensor = tf.constant(tok_lst, dtype=tf.int32)
    return int32_tensor
@require_sentencepiece
@require_tokenizers
@require_tf
class TFPegasusIntegrationTests(unittest.TestCase):
    """Integration test: generate from the ``google/pegasus-xsum`` checkpoint.

    ``src_text`` is tokenized, passed through ``generate`` with two beams, and
    the decoded output is compared against ``expected_text``.
    """

    src_text = [
        """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.""",
        """ The London trio are up for best UK act and best album, as well as getting two nominations in the best song category."We got told like this morning 'Oh I think you're nominated'", said Dappy."And I was like 'Oh yeah, which one?' And now we've got nominated for four awards. I mean, wow!"Bandmate Fazer added: "We thought it's best of us to come down and mingle with everyone and say hello to the cameras. And now we find we've got four nominations."The band have two shots at the best song prize, getting the nod for their Tynchy Stryder collaboration Number One, and single Strong Again.Their album Uncle B will also go up against records by the likes of Beyonce and Kanye West.N-Dubz picked up the best newcomer Mobo in 2007, but female member Tulisa said they wouldn't be too disappointed if they didn't win this time around."At the end of the day we're grateful to be where we are in our careers."If it don't happen then it don't happen - live to fight another day and keep on making albums and hits for the fans."Dappy also revealed they could be performing live several times on the night.The group will be doing Number One and also a possible rendition of the War Child single, I Got Soul.The charity song is a re-working of The Killers' All These Things That I've Done and is set to feature artists like Chipmunk, Ironik and Pixie Lott.This year's Mobos will be held outside of London for the first time, in Glasgow on 30 September.N-Dubz said they were looking forward to performing for their Scottish fans and boasted about their recent shows north of the border."We just done Edinburgh the other day," said Dappy."We smashed up an N-Dubz show over there. We done Aberdeen about three or four months ago - we smashed up that show over there! Everywhere we go we smash it up!" """,
    ]
    expected_text = [
        "California's largest electricity provider has cut power to hundreds of thousands of customers in an effort to reduce the risk of wildfires.",
        'N-Dubz have revealed they\'re "grateful" to have been nominated for four Mobo Awards.',
    ]  # differs slightly from pytorch, likely due to numerical differences in linear layers
    model_name = "google/pegasus-xsum"

    @cached_property
    def tokenizer(self):
        """Tokenizer loaded from ``model_name``; cached after the first access."""
        return AutoTokenizer.from_pretrained(self.model_name)

    @cached_property
    def model(self):
        """Seq2seq model loaded from ``model_name``; cached after the first access."""
        return TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name)

    def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
        """Generate from ``src_text`` and require an exact match with ``expected_text``."""
        generated_words = self.translate_src_text(**tokenizer_kwargs)
        # assertEqual is not stripped under ``python -O`` and prints a diff of
        # the two lists on failure, unlike a bare ``assert``.
        self.assertEqual(self.expected_text, generated_words)

    def translate_src_text(self, **tokenizer_kwargs):
        """Tokenize ``src_text``, generate with two beams and return the decoded strings.

        ``tokenizer_kwargs`` are forwarded to the tokenizer in addition to
        ``padding=True`` and ``return_tensors="tf"``.
        """
        model_inputs = self.tokenizer(self.src_text, **tokenizer_kwargs, padding=True, return_tensors="tf")
        generated_ids = self.model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            num_beams=2,
            use_cache=True,
        )
        return self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)

    @slow
    def test_batch_generation(self):
        self._assert_generated_batch_equal_expected()