485 lines
20 KiB
Python
485 lines
20 KiB
Python
# coding=utf-8
|
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Testing suite for the PyTorch ConvBERT model."""
|
|
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
|
|
from transformers import ConvBertConfig, is_torch_available
|
|
from transformers.models.auto import get_values
|
|
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
|
|
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import (
|
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
|
ConvBertForMaskedLM,
|
|
ConvBertForMultipleChoice,
|
|
ConvBertForQuestionAnswering,
|
|
ConvBertForSequenceClassification,
|
|
ConvBertForTokenClassification,
|
|
ConvBertModel,
|
|
)
|
|
|
|
|
|
class ConvBertModelTester:
    """Builds small ConvBERT configs and inputs and checks output shapes per model head.

    Every constructor argument is stored unchanged on the instance under the
    same name. ``parent`` is the test case whose ``assertEqual`` is used by the
    ``create_and_check_*`` methods.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        # Test harness wiring and input geometry.
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training

        # Switches selecting which optional inputs/labels get generated.
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels

        # Values forwarded to ConvBertConfig by get_config().
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range

        # Label-space sizes used when generating labels and checking logits.
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    @staticmethod
    def _build_eval_model(model_class, config):
        """Instantiate ``model_class`` from ``config``, move it to ``torch_device`` and call ``eval()``."""
        model = model_class(config=config)
        model.to(torch_device)
        model.eval()
        return model

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels)``.

        Optional entries are ``None`` when the matching ``use_*`` flag is off.
        """
        # The helper calls stay in the original order: ids, mask, token types, labels.
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_mask = random_attention_mask([self.batch_size, self.seq_length]) if self.use_input_mask else None
        token_type_ids = (
            ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) if self.use_token_type_ids else None
        )

        sequence_labels = token_labels = choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        return (
            self.get_config(),
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(self):
        """Create a ``ConvBertConfig`` from the stored hyper-parameters, with ``is_decoder=False``."""
        config_kwargs = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "num_hidden_layers": self.num_hidden_layers,
            "num_attention_heads": self.num_attention_heads,
            "intermediate_size": self.intermediate_size,
            "hidden_act": self.hidden_act,
            "hidden_dropout_prob": self.hidden_dropout_prob,
            "attention_probs_dropout_prob": self.attention_probs_dropout_prob,
            "max_position_embeddings": self.max_position_embeddings,
            "type_vocab_size": self.type_vocab_size,
            "is_decoder": False,
            "initializer_range": self.initializer_range,
        }
        return ConvBertConfig(**config_kwargs)

    def prepare_config_and_inputs_for_decoder(self):
        """Return the standard 7-tuple (config flagged as decoder) followed by encoder states and mask."""
        config, *encoder_free_inputs = self.prepare_config_and_inputs()
        config.is_decoder = True

        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (config, *encoder_free_inputs, encoder_hidden_states, encoder_attention_mask)

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run ``ConvBertModel`` with three input combinations and check the last call's hidden-state shape."""
        model = self._build_eval_model(ConvBertModel, config)

        # The first two calls only need to run without error; their outputs are discarded.
        model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)

        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that ``ConvBertForMaskedLM`` logits have shape (batch, seq, vocab)."""
        model = self._build_eval_model(ConvBertForMaskedLM, config)
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)

        expected_shape = (self.batch_size, self.seq_length, self.vocab_size)
        self.parent.assertEqual(result.logits.shape, expected_shape)

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that ``ConvBertForQuestionAnswering`` start/end logits have shape (batch, seq)."""
        model = self._build_eval_model(ConvBertForQuestionAnswering, config)
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )

        expected_shape = (self.batch_size, self.seq_length)
        self.parent.assertEqual(result.start_logits.shape, expected_shape)
        self.parent.assertEqual(result.end_logits.shape, expected_shape)

    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that ``ConvBertForSequenceClassification`` logits have shape (batch, num_labels)."""
        config.num_labels = self.num_labels
        model = self._build_eval_model(ConvBertForSequenceClassification, config)
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)

        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that ``ConvBertForTokenClassification`` logits have shape (batch, seq, num_labels)."""
        config.num_labels = self.num_labels
        model = self._build_eval_model(ConvBertForTokenClassification, config)
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)

        expected_shape = (self.batch_size, self.seq_length, self.num_labels)
        self.parent.assertEqual(result.logits.shape, expected_shape)

    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that ``ConvBertForMultipleChoice`` logits have shape (batch, num_choices)."""
        config.num_choices = self.num_choices
        model = self._build_eval_model(ConvBertForMultipleChoice, config)

        def repeat_per_choice(tensor):
            # Insert a choice axis at position 1 and repeat the tensor num_choices times along it.
            return tensor.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()

        multiple_choice_inputs_ids = repeat_per_choice(input_ids)
        multiple_choice_token_type_ids = repeat_per_choice(token_type_ids)
        multiple_choice_input_mask = repeat_per_choice(input_mask)

        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` with input ids, token type ids and attention mask only."""
        config, input_ids, token_type_ids, input_mask, *_unused_labels = self.prepare_config_and_inputs()
        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": input_mask,
        }
        return config, inputs_dict
|
|
|
|
|
|
@require_torch
class ConvBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Runs the shared model/pipeline test mixins plus ConvBERT-specific checks on every ConvBERT head."""

    # Both collections are empty when torch is unavailable, because the model
    # classes are only imported under `is_torch_available()`.
    all_model_classes = (
        (
            ConvBertModel,
            ConvBertForMaskedLM,
            ConvBertForMultipleChoice,
            ConvBertForQuestionAnswering,
            ConvBertForSequenceClassification,
            ConvBertForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    pipeline_model_mapping = (
        {
            "feature-extraction": ConvBertModel,
            "fill-mask": ConvBertForMaskedLM,
            "question-answering": ConvBertForQuestionAnswering,
            "text-classification": ConvBertForSequenceClassification,
            "token-classification": ConvBertForTokenClassification,
            "zero-shot": ConvBertForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    # NOTE(review): presumably read by the shared mixins to disable the pruning
    # and head-masking tests -- confirm against ModelTesterMixin.
    test_pruning = False
    test_head_masking = False

    def setUp(self):
        """Create the model tester and the config tester used by the tests below."""
        self.model_tester = ConvBertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ConvBertConfig, hidden_size=37)

    def test_config(self):
        """Run the common configuration tests."""
        self.config_tester.run_common_tests()

    def test_model(self):
        """Check the base model's output shape."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_masked_lm(self):
        """Check the masked-LM head's output shape."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_multiple_choice(self):
        """Check the multiple-choice head's output shape."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

    def test_for_question_answering(self):
        """Check the question-answering head's output shapes."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        """Check the sequence-classification head's output shape."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        """Check the token-classification head's output shape."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        """Check that the pretrained base checkpoint loads."""
        model_name = "YituTech/conv-bert-base"
        model = ConvBertModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    def test_attention_outputs(self):
        """Check the number, shape and output position of attentions for every model class.

        Attentions are requested first through the forward kwargs, then through
        the config, and finally together with hidden states.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
        chunk_length = getattr(self.model_tester, "chunk_length", None)
        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes

        num_layers = self.model_tester.num_hidden_layers
        # The expected head dimension is half the configured head count. Floor
        # division keeps it an int like the tensor dimensions it is compared
        # with (the previous `/ 2` produced a float).
        # NOTE(review): the factor 2 presumably mirrors the config's default
        # head_ratio -- confirm against ConvBertConfig.
        expected_heads = self.model_tester.num_attention_heads // 2

        def run_forward(model_class):
            # Fresh model per call so config changes made between calls take effect.
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                return model(**self._prepare_for_class(inputs_dict, model_class))

        def pick_self_attentions(outputs):
            return outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

        def check_self_attention_shape(attentions):
            if chunk_length is not None:
                self.assertListEqual(
                    list(attentions[0].shape[-4:]),
                    [expected_heads, encoder_seq_length, chunk_length, encoder_key_length],
                )
            else:
                self.assertListEqual(
                    list(attentions[0].shape[-3:]),
                    [expected_heads, encoder_seq_length, encoder_key_length],
                )

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            outputs = run_forward(model_class)
            attentions = pick_self_attentions(outputs)
            self.assertEqual(len(attentions), num_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            outputs = run_forward(model_class)
            attentions = pick_self_attentions(outputs)
            self.assertEqual(len(attentions), num_layers)
            check_self_attention_shape(attentions)
            out_len = len(outputs)

            if self.is_encoder_decoder:
                correct_outlen = 5

                # loss is at first position
                if "labels" in inputs_dict:
                    correct_outlen += 1  # loss is added to beginning
                # Question Answering model returns start_logits and end_logits
                if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
                    correct_outlen += 1  # start_logits and end_logits instead of only 1 output
                if "past_key_values" in outputs:
                    correct_outlen += 1  # past_key_values have been returned

                self.assertEqual(out_len, correct_outlen)

                # decoder attentions
                decoder_attentions = outputs.decoder_attentions
                self.assertIsInstance(decoder_attentions, (list, tuple))
                self.assertEqual(len(decoder_attentions), num_layers)
                self.assertListEqual(
                    list(decoder_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
                )

                # cross attentions
                cross_attentions = outputs.cross_attentions
                self.assertIsInstance(cross_attentions, (list, tuple))
                self.assertEqual(len(cross_attentions), num_layers)
                self.assertListEqual(
                    list(cross_attentions[0].shape[-3:]),
                    [
                        self.model_tester.num_attention_heads,
                        decoder_seq_length,
                        encoder_key_length,
                    ],
                )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            outputs = run_forward(model_class)

            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            elif self.is_encoder_decoder:
                added_hidden_states = 2
            else:
                added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = pick_self_attentions(outputs)
            self.assertEqual(len(self_attentions), num_layers)
            check_self_attention_shape(self_attentions)

    @slow
    @require_torch_accelerator
    def test_torchscript_device_change(self):
        """Trace each model on CPU, save it, reload it on ``torch_device`` and run it there."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.torchscript = True

        for model_class in self.all_model_classes:
            # ConvBertForMultipleChoice behaves incorrectly in JIT environments.
            # Skip only this class: the previous `return` ended the whole test
            # here, so the classes listed after it were never traced.
            if model_class == ConvBertForMultipleChoice:
                continue

            model = model_class(config=config)

            # Use a per-class name so the shared inputs_dict is not rebound between iterations.
            class_inputs = self._prepare_for_class(inputs_dict, model_class)
            traced_model = torch.jit.trace(
                model, (class_inputs["input_ids"].to("cpu"), class_inputs["attention_mask"].to("cpu"))
            )

            with tempfile.TemporaryDirectory() as tmp:
                traced_path = os.path.join(tmp, "traced_model.pt")
                torch.jit.save(traced_model, traced_path)
                loaded = torch.jit.load(traced_path, map_location=torch_device)
                loaded(class_inputs["input_ids"].to(torch_device), class_inputs["attention_mask"].to(torch_device))

    def test_model_for_input_embeds(self):
        """Check the base model's output shape when fed ``inputs_embeds`` instead of ``input_ids``."""
        batch_size = 2
        seq_length = 10
        config = self.model_tester.get_config()
        # Take the embedding width from the config rather than hard-coding 768,
        # so the test follows the config if the tester ever sets it explicitly.
        inputs_embeds = torch.rand([batch_size, seq_length, config.embedding_size], device=torch_device)
        model = ConvBertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(inputs_embeds=inputs_embeds)
        self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))

    def test_reducing_attention_heads(self):
        """Run the masked-LM shape check with ``config.head_ratio`` set to 4."""
        config, *inputs = self.model_tester.prepare_config_and_inputs()
        config.head_ratio = 4
        self.model_tester.create_and_check_for_masked_lm(config, *inputs)
|
|
|
|
|
|
@require_torch
class ConvBertModelIntegrationTest(unittest.TestCase):
    """Integration check of the pretrained base ConvBERT checkpoint against stored reference values."""

    @slow
    def test_inference_no_head(self):
        """Compare the output shape and a 3x3 slice of the first output with reference values."""
        model = ConvBertModel.from_pretrained("YituTech/conv-bert-base")
        token_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])

        with torch.no_grad():
            # Index 0 selects the first element of the model output.
            hidden_states = model(token_ids)[0]

        self.assertEqual(hidden_states.shape, torch.Size((1, 6, 768)))

        reference_slice = torch.tensor(
            [[[-0.0864, -0.4898, -0.3677], [0.1434, -0.2952, -0.7640], [-0.0112, -0.4432, -0.5432]]]
        )
        observed_slice = hidden_states[:, :3, :3]
        matches_reference = torch.allclose(observed_slice, reference_slice, atol=1e-4)
        self.assertTrue(matches_reference)
|