742 lines
28 KiB
Python
742 lines
28 KiB
Python
# coding=utf-8
|
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import random
|
|
import unittest
|
|
|
|
from transformers import XLNetConfig, is_torch_available
|
|
from transformers.testing_utils import require_torch, slow, torch_device
|
|
|
|
from ...generation.test_utils import GenerationTesterMixin
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import (
|
|
XLNetForMultipleChoice,
|
|
XLNetForQuestionAnswering,
|
|
XLNetForQuestionAnsweringSimple,
|
|
XLNetForSequenceClassification,
|
|
XLNetForTokenClassification,
|
|
XLNetLMHeadModel,
|
|
XLNetModel,
|
|
)
|
|
|
|
|
|
class XLNetModelTester:
|
|
def __init__(
|
|
self,
|
|
parent,
|
|
batch_size=14,
|
|
seq_length=7,
|
|
mem_len=10,
|
|
clamp_len=-1,
|
|
reuse_len=15,
|
|
is_training=True,
|
|
use_labels=True,
|
|
vocab_size=99,
|
|
cutoffs=[10, 50, 80],
|
|
hidden_size=32,
|
|
num_attention_heads=4,
|
|
d_inner=128,
|
|
num_hidden_layers=2,
|
|
type_sequence_label_size=2,
|
|
untie_r=True,
|
|
bi_data=False,
|
|
same_length=False,
|
|
initializer_range=0.05,
|
|
seed=1,
|
|
type_vocab_size=2,
|
|
bos_token_id=1,
|
|
eos_token_id=2,
|
|
pad_token_id=5,
|
|
num_choices=4,
|
|
):
|
|
self.parent = parent
|
|
self.batch_size = 14
|
|
self.seq_length = 7
|
|
self.mem_len = 10
|
|
# self.key_len = seq_length + mem_len
|
|
self.clamp_len = -1
|
|
self.reuse_len = 15
|
|
self.is_training = True
|
|
self.use_labels = True
|
|
self.vocab_size = 99
|
|
self.cutoffs = [10, 50, 80]
|
|
self.hidden_size = 32
|
|
self.num_attention_heads = 4
|
|
self.d_inner = 128
|
|
self.num_hidden_layers = 5
|
|
self.type_sequence_label_size = 2
|
|
self.untie_r = True
|
|
self.bi_data = False
|
|
self.same_length = False
|
|
self.initializer_range = 0.05
|
|
self.seed = 1
|
|
self.type_vocab_size = 2
|
|
self.bos_token_id = 1
|
|
self.eos_token_id = 2
|
|
self.pad_token_id = 5
|
|
self.num_choices = 4
|
|
|
|
def prepare_config_and_inputs(self):
|
|
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
|
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
|
segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
|
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
|
|
|
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
|
|
perm_mask = torch.zeros(
|
|
self.batch_size,
|
|
self.seq_length + 1,
|
|
self.seq_length + 1,
|
|
dtype=torch.float,
|
|
device=torch_device,
|
|
)
|
|
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
|
target_mapping = torch.zeros(
|
|
self.batch_size,
|
|
1,
|
|
self.seq_length + 1,
|
|
dtype=torch.float,
|
|
device=torch_device,
|
|
)
|
|
target_mapping[:, 0, -1] = 1.0 # predict last token
|
|
|
|
sequence_labels = None
|
|
lm_labels = None
|
|
is_impossible_labels = None
|
|
token_labels = None
|
|
if self.use_labels:
|
|
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
|
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
|
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
|
|
token_labels = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
|
|
|
config = self.get_config()
|
|
|
|
return (
|
|
config,
|
|
input_ids_1,
|
|
input_ids_2,
|
|
input_ids_q,
|
|
perm_mask,
|
|
input_mask,
|
|
target_mapping,
|
|
segment_ids,
|
|
lm_labels,
|
|
sequence_labels,
|
|
is_impossible_labels,
|
|
token_labels,
|
|
)
|
|
|
|
def get_config(self):
|
|
return XLNetConfig(
|
|
vocab_size=self.vocab_size,
|
|
d_model=self.hidden_size,
|
|
n_head=self.num_attention_heads,
|
|
d_inner=self.d_inner,
|
|
n_layer=self.num_hidden_layers,
|
|
untie_r=self.untie_r,
|
|
mem_len=self.mem_len,
|
|
clamp_len=self.clamp_len,
|
|
same_length=self.same_length,
|
|
reuse_len=self.reuse_len,
|
|
bi_data=self.bi_data,
|
|
initializer_range=self.initializer_range,
|
|
num_labels=self.type_sequence_label_size,
|
|
bos_token_id=self.bos_token_id,
|
|
pad_token_id=self.pad_token_id,
|
|
eos_token_id=self.eos_token_id,
|
|
)
|
|
|
|
def set_seed(self):
|
|
random.seed(self.seed)
|
|
torch.manual_seed(self.seed)
|
|
|
|
def create_and_check_xlnet_base_model(
|
|
self,
|
|
config,
|
|
input_ids_1,
|
|
input_ids_2,
|
|
input_ids_q,
|
|
perm_mask,
|
|
input_mask,
|
|
target_mapping,
|
|
segment_ids,
|
|
lm_labels,
|
|
sequence_labels,
|
|
is_impossible_labels,
|
|
token_labels,
|
|
):
|
|
model = XLNetModel(config)
|
|
model.to(torch_device)
|
|
model.eval()
|
|
|
|
result = model(input_ids_1, input_mask=input_mask)
|
|
result = model(input_ids_1, attention_mask=input_mask)
|
|
result = model(input_ids_1, token_type_ids=segment_ids)
|
|
result = model(input_ids_1)
|
|
|
|
config.mem_len = 0
|
|
model = XLNetModel(config)
|
|
model.to(torch_device)
|
|
model.eval()
|
|
base_model_output = model(input_ids_1)
|
|
self.parent.assertEqual(len(base_model_output), 2)
|
|
|
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
|
self.parent.assertListEqual(
|
|
[mem.shape for mem in result.mems],
|
|
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
|
)
|
|
|
|
def create_and_check_use_mems_train(
|
|
self,
|
|
config,
|
|
input_ids_1,
|
|
input_ids_2,
|
|
input_ids_q,
|
|
perm_mask,
|
|
input_mask,
|
|
target_mapping,
|
|
segment_ids,
|
|
lm_labels,
|
|
sequence_labels,
|
|
is_impossible_labels,
|
|
token_labels,
|
|
):
|
|
model = XLNetForSequenceClassification(config)
|
|
model.to(torch_device)
|
|
model.train()
|
|
|
|
train_size = input_ids_1.shape[0]
|
|
|
|
batch_size = 4
|
|
for i in range(train_size // batch_size + 1):
|
|
input_ids = input_ids_1[i : (i + 1) * batch_size]
|
|
labels = sequence_labels[i : (i + 1) * batch_size]
|
|
outputs = model(input_ids=input_ids, labels=labels, return_dict=True)
|
|
self.parent.assertIsNone(outputs.mems)
|
|
self.parent.assertIsNotNone(outputs.loss)
|
|
|
|
def create_and_check_xlnet_model_use_mems(
|
|
self,
|
|
config,
|
|
input_ids_1,
|
|
input_ids_2,
|
|
input_ids_q,
|
|
perm_mask,
|
|
input_mask,
|
|
target_mapping,
|
|
segment_ids,
|
|
lm_labels,
|
|
sequence_labels,
|
|
is_impossible_labels,
|
|
token_labels,
|
|
):
|
|
model = XLNetModel(config=config)
|
|
model.to(torch_device)
|
|
model.eval()
|
|
|
|
# first forward pass
|
|
causal_mask = torch.ones(
|
|
input_ids_1.shape[0],
|
|
input_ids_1.shape[1],
|
|
input_ids_1.shape[1],
|
|
dtype=torch.float,
|
|
device=torch_device,
|
|
)
|
|
causal_mask = torch.triu(causal_mask, diagonal=0)
|
|
outputs_cache = model(input_ids_1, use_mems=True, perm_mask=causal_mask)
|
|
outputs_no_cache = model(input_ids_1, use_mems=False, perm_mask=causal_mask)
|
|
outputs_conf = model(input_ids_1)
|
|
|
|
self.parent.assertTrue(len(outputs_cache) == len(outputs_conf))
|
|
self.parent.assertTrue(len(outputs_cache) == len(outputs_no_cache) + 1)
|
|
|
|
output, mems = outputs_cache.to_tuple()
|
|
|
|
# create hypothetical next token and extent to next_input_ids
|
|
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
|
|
|
# append to next input_ids and token_type_ids
|
|
next_input_ids = torch.cat([input_ids_1, next_tokens], dim=-1)
|
|
|
|
# causal mask
|
|
causal_mask = torch.ones(
|
|
input_ids_1.shape[0],
|
|
input_ids_1.shape[1] + 1,
|
|
input_ids_1.shape[1] + 1,
|
|
dtype=torch.float,
|
|
device=torch_device,
|
|
)
|
|
causal_mask = torch.triu(causal_mask, diagonal=0)
|
|
single_mask = torch.ones(input_ids_1.shape[0], 1, 1, dtype=torch.float, device=torch_device)
|
|
|
|
# second forward pass
|
|
output_from_no_past = model(next_input_ids, perm_mask=causal_mask)["last_hidden_state"]
|
|
output_from_past = model(next_tokens, mems=mems, perm_mask=single_mask)["last_hidden_state"]
|
|
|
|
# select random slice
|
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
|
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
|
|
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
|
|
|
|
# test that outputs are equal for slice
|
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
|
|
|
def create_and_check_xlnet_base_model_with_att_output(
|
|
self,
|
|
config,
|
|
input_ids_1,
|
|
input_ids_2,
|
|
input_ids_q,
|
|
perm_mask,
|
|
input_mask,
|
|
target_mapping,
|
|
segment_ids,
|
|
lm_labels,
|
|
sequence_labels,
|
|
is_impossible_labels,
|
|
token_labels,
|
|
):
|
|
model = XLNetModel(config)
|
|
model.to(torch_device)
|
|
model.eval()
|
|
|
|
attentions = model(input_ids_1, target_mapping=target_mapping, output_attentions=True)["attentions"]
|
|
|
|
self.parent.assertEqual(len(attentions), config.n_layer)
|
|
self.parent.assertIsInstance(attentions[0], tuple)
|
|
self.parent.assertEqual(len(attentions[0]), 2)
|
|
self.parent.assertTrue(attentions[0][0].shape, attentions[0][0].shape)
|
|
|
|
def create_and_check_xlnet_lm_head(
|
|
self,
|
|
config,
|
|
input_ids_1,
|
|
input_ids_2,
|
|
input_ids_q,
|
|
perm_mask,
|
|
input_mask,
|
|
target_mapping,
|
|
segment_ids,
|
|
lm_labels,
|
|
sequence_labels,
|
|
is_impossible_labels,
|
|
token_labels,
|
|
):
|
|
model = XLNetLMHeadModel(config)
|
|
model.to(torch_device)
|
|
model.eval()
|
|
|
|
result1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
|
|
|
|
result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1.mems)
|
|
|
|
_ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
|
|
|
|
self.parent.assertEqual(result1.loss.shape, ())
|
|
self.parent.assertEqual(result1.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
|
self.parent.assertListEqual(
|
|
[mem.shape for mem in result1.mems],
|
|
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
|
)
|
|
|
|
self.parent.assertEqual(result2.loss.shape, ())
|
|
self.parent.assertEqual(result2.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
|
self.parent.assertListEqual(
|
|
[mem.shape for mem in result2.mems],
|
|
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
|
)
|
|
|
|
def create_and_check_xlnet_qa(
|
|
self,
|
|
config,
|
|
input_ids_1,
|
|
input_ids_2,
|
|
input_ids_q,
|
|
perm_mask,
|
|
input_mask,
|
|
target_mapping,
|
|
segment_ids,
|
|
lm_labels,
|
|
sequence_labels,
|
|
is_impossible_labels,
|
|
token_labels,
|
|
):
|
|
model = XLNetForQuestionAnswering(config)
|
|
model.to(torch_device)
|
|
model.eval()
|
|
|
|
result = model(input_ids_1)
|
|
|
|
result_with_labels = model(
|
|
input_ids_1,
|
|
start_positions=sequence_labels,
|
|
end_positions=sequence_labels,
|
|
cls_index=sequence_labels,
|
|
is_impossible=is_impossible_labels,
|
|
p_mask=input_mask,
|
|
)
|
|
|
|
result_with_labels = model(
|
|
input_ids_1,
|
|
start_positions=sequence_labels,
|
|
end_positions=sequence_labels,
|
|
cls_index=sequence_labels,
|
|
is_impossible=is_impossible_labels,
|
|
)
|
|
|
|
total_loss, mems = result_with_labels.to_tuple()
|
|
|
|
result_with_labels = model(
|
|
input_ids_1,
|
|
start_positions=sequence_labels,
|
|
end_positions=sequence_labels,
|
|
)
|
|
|
|
total_loss, mems = result_with_labels.to_tuple()
|
|
|
|
self.parent.assertEqual(result_with_labels.loss.shape, ())
|
|
self.parent.assertEqual(result.start_top_log_probs.shape, (self.batch_size, model.config.start_n_top))
|
|
self.parent.assertEqual(result.start_top_index.shape, (self.batch_size, model.config.start_n_top))
|
|
self.parent.assertEqual(
|
|
result.end_top_log_probs.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
|
|
)
|
|
self.parent.assertEqual(
|
|
result.end_top_index.shape, (self.batch_size, model.config.start_n_top * model.config.end_n_top)
|
|
)
|
|
self.parent.assertEqual(result.cls_logits.shape, (self.batch_size,))
|
|
self.parent.assertListEqual(
|
|
[mem.shape for mem in result.mems],
|
|
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
|
)
|
|
|
|
def create_and_check_xlnet_token_classif(
|
|
self,
|
|
config,
|
|
input_ids_1,
|
|
input_ids_2,
|
|
input_ids_q,
|
|
perm_mask,
|
|
input_mask,
|
|
target_mapping,
|
|
segment_ids,
|
|
lm_labels,
|
|
sequence_labels,
|
|
is_impossible_labels,
|
|
token_labels,
|
|
):
|
|
model = XLNetForTokenClassification(config)
|
|
model.to(torch_device)
|
|
model.eval()
|
|
|
|
result = model(input_ids_1)
|
|
result = model(input_ids_1, labels=token_labels)
|
|
|
|
self.parent.assertEqual(result.loss.shape, ())
|
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.type_sequence_label_size))
|
|
self.parent.assertListEqual(
|
|
[mem.shape for mem in result.mems],
|
|
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
|
)
|
|
|
|
def create_and_check_xlnet_sequence_classif(
|
|
self,
|
|
config,
|
|
input_ids_1,
|
|
input_ids_2,
|
|
input_ids_q,
|
|
perm_mask,
|
|
input_mask,
|
|
target_mapping,
|
|
segment_ids,
|
|
lm_labels,
|
|
sequence_labels,
|
|
is_impossible_labels,
|
|
token_labels,
|
|
):
|
|
model = XLNetForSequenceClassification(config)
|
|
model.to(torch_device)
|
|
model.eval()
|
|
|
|
result = model(input_ids_1)
|
|
result = model(input_ids_1, labels=sequence_labels)
|
|
|
|
self.parent.assertEqual(result.loss.shape, ())
|
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
|
self.parent.assertListEqual(
|
|
[mem.shape for mem in result.mems],
|
|
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
|
)
|
|
|
|
def prepare_config_and_inputs_for_common(self):
|
|
config_and_inputs = self.prepare_config_and_inputs()
|
|
(
|
|
config,
|
|
input_ids_1,
|
|
input_ids_2,
|
|
input_ids_q,
|
|
perm_mask,
|
|
input_mask,
|
|
target_mapping,
|
|
segment_ids,
|
|
lm_labels,
|
|
sequence_labels,
|
|
is_impossible_labels,
|
|
token_labels,
|
|
) = config_and_inputs
|
|
inputs_dict = {"input_ids": input_ids_1}
|
|
return config, inputs_dict
|
|
|
|
|
|
@require_torch
|
|
class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|
all_model_classes = (
|
|
(
|
|
XLNetModel,
|
|
XLNetLMHeadModel,
|
|
XLNetForTokenClassification,
|
|
XLNetForSequenceClassification,
|
|
XLNetForQuestionAnswering,
|
|
XLNetForQuestionAnsweringSimple,
|
|
XLNetForMultipleChoice,
|
|
)
|
|
if is_torch_available()
|
|
else ()
|
|
)
|
|
all_generative_model_classes = (
|
|
(XLNetLMHeadModel,) if is_torch_available() else ()
|
|
) # TODO (PVP): Check other models whether language generation is also applicable
|
|
pipeline_model_mapping = (
|
|
{
|
|
"feature-extraction": XLNetModel,
|
|
"question-answering": XLNetForQuestionAnsweringSimple,
|
|
"text-classification": XLNetForSequenceClassification,
|
|
"text-generation": XLNetLMHeadModel,
|
|
"token-classification": XLNetForTokenClassification,
|
|
"zero-shot": XLNetForSequenceClassification,
|
|
}
|
|
if is_torch_available()
|
|
else {}
|
|
)
|
|
fx_compatible = False
|
|
test_pruning = False
|
|
|
|
# TODO: Fix the failed tests
|
|
def is_pipeline_test_to_skip(
|
|
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
|
|
):
|
|
if pipeline_test_casse_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
|
|
return True
|
|
|
|
return False
|
|
|
|
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
|
|
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
|
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
|
|
|
if return_labels:
|
|
if model_class.__name__ == "XLNetForQuestionAnswering":
|
|
inputs_dict["start_positions"] = torch.zeros(
|
|
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
|
)
|
|
inputs_dict["end_positions"] = torch.zeros(
|
|
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
|
)
|
|
|
|
return inputs_dict
|
|
|
|
def setUp(self):
|
|
self.model_tester = XLNetModelTester(self)
|
|
self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
|
|
|
|
def test_config(self):
|
|
self.config_tester.run_common_tests()
|
|
|
|
def test_xlnet_base_model(self):
|
|
self.model_tester.set_seed()
|
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
|
|
|
|
def test_xlnet_base_model_use_mems(self):
|
|
# checking that in auto-regressive mode, `use_mems` gives the same results
|
|
self.model_tester.set_seed()
|
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
self.model_tester.create_and_check_xlnet_model_use_mems(*config_and_inputs)
|
|
|
|
def test_seq_classification_use_mems_train(self):
|
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
self.model_tester.create_and_check_use_mems_train(*config_and_inputs)
|
|
|
|
def test_xlnet_base_model_with_att_output(self):
|
|
self.model_tester.set_seed()
|
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
self.model_tester.create_and_check_xlnet_base_model_with_att_output(*config_and_inputs)
|
|
|
|
def test_xlnet_lm_head(self):
|
|
self.model_tester.set_seed()
|
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
self.model_tester.create_and_check_xlnet_lm_head(*config_and_inputs)
|
|
|
|
def test_xlnet_sequence_classif(self):
|
|
self.model_tester.set_seed()
|
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)
|
|
|
|
def test_xlnet_token_classif(self):
|
|
self.model_tester.set_seed()
|
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
self.model_tester.create_and_check_xlnet_token_classif(*config_and_inputs)
|
|
|
|
def test_xlnet_qa(self):
|
|
self.model_tester.set_seed()
|
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)
|
|
|
|
def test_retain_grad_hidden_states_attentions(self):
|
|
# xlnet cannot keep gradients in attentions or hidden states
|
|
return
|
|
|
|
# overwrite from test_modeling_common
|
|
def _mock_init_weights(self, module):
|
|
if hasattr(module, "weight") and module.weight is not None:
|
|
module.weight.data.fill_(3)
|
|
if hasattr(module, "bias") and module.bias is not None:
|
|
module.bias.data.fill_(3)
|
|
|
|
for param in ["q", "k", "v", "o", "r", "r_r_bias", "r_s_bias", "r_w_bias", "seg_embed", "mask_emb"]:
|
|
if hasattr(module, param) and getattr(module, param) is not None:
|
|
weight = getattr(module, param)
|
|
weight.data.fill_(3)
|
|
|
|
def _check_hidden_states_for_generate(
|
|
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
|
):
|
|
self.assertIsInstance(hidden_states, tuple)
|
|
self.assertListEqual(
|
|
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
|
[True] * len(hidden_states),
|
|
)
|
|
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
|
|
|
for idx, iter_hidden_states in enumerate(hidden_states):
|
|
# check hidden size
|
|
for i, layer_hidden_states in enumerate(iter_hidden_states):
|
|
# every 2nd tensor is from extra stream
|
|
if i % 2 != 0:
|
|
seq_len = 1
|
|
else:
|
|
# for first item dummy PAD token is appended so need one more
|
|
# else offset+dummy_token when using cache
|
|
seq_len = (min_length + 1) if idx == 0 else 3
|
|
|
|
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
|
self.assertEqual(layer_hidden_states.shape, expected_shape)
|
|
|
|
def _check_attentions_for_generate(
|
|
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
|
):
|
|
self.assertIsInstance(attentions, tuple)
|
|
self.assertListEqual(
|
|
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
|
)
|
|
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
|
|
|
for idx, attentions_item in enumerate(attentions):
|
|
for iter_attentions in attentions_item:
|
|
tgt_len = min_length
|
|
|
|
# for first item dummy PAD token is appended so need one more
|
|
# every token after consists of offset+dummy_token length when using cache
|
|
if idx == 0:
|
|
tgt_len += 1
|
|
else:
|
|
tgt_len = 3
|
|
|
|
src_len = min_length + idx + 1
|
|
|
|
expected_shape = (
|
|
batch_size * num_beam_groups,
|
|
config.num_attention_heads,
|
|
tgt_len,
|
|
src_len,
|
|
)
|
|
# check attn size
|
|
self.assertListEqual(
|
|
[layer_attention.shape for layer_attention in iter_attentions],
|
|
[expected_shape] * len(iter_attentions),
|
|
)
|
|
|
|
@slow
|
|
def test_model_from_pretrained(self):
|
|
model_name = "xlnet/xlnet-base-cased"
|
|
model = XLNetModel.from_pretrained(model_name)
|
|
self.assertIsNotNone(model)
|
|
|
|
|
|
@require_torch
|
|
class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
|
@slow
|
|
def test_lm_generate_xlnet_base_cased(self):
|
|
model = XLNetLMHeadModel.from_pretrained("xlnet/xlnet-base-cased")
|
|
model.to(torch_device)
|
|
# fmt: off
|
|
input_ids = torch.tensor(
|
|
[
|
|
[
|
|
67, 2840, 19, 18, 1484, 20, 965, 29077, 8719, 1273, 21, 45, 273, 17, 10, 15048, 28, 27511, 21, 4185, 11, 41, 2444, 9, 32, 1025, 20, 8719, 26, 23, 673, 966, 19, 29077, 20643, 27511, 20822, 20643, 19, 17, 6616, 17511, 18, 8978, 20, 18, 777, 9, 19233, 1527, 17669, 19, 24, 673, 17, 28756, 150, 12943, 4354, 153, 27, 442, 37, 45, 668, 21, 24, 256, 20, 416, 22, 2771, 4901, 9, 12943, 4354, 153, 51, 24, 3004, 21, 28142, 23, 65, 20, 18, 416, 34, 24, 2958, 22947, 9, 1177, 45, 668, 3097, 13768, 23, 103, 28, 441, 148, 48, 20522, 19, 12943, 4354, 153, 12860, 34, 18, 326, 27, 17492, 684, 21, 6709, 9, 8585, 123, 266, 19, 12943, 4354, 153, 6872, 24, 3004, 20, 18, 9225, 2198, 19, 12717, 103, 22, 401, 24, 6348, 9, 12943, 4354, 153, 1068, 2768, 2286, 19, 33, 104, 19, 176, 24, 9313, 19, 20086, 28, 45, 10292, 9, 4, 3,
|
|
]
|
|
],
|
|
dtype=torch.long,
|
|
device=torch_device,
|
|
)
|
|
# fmt: on
|
|
# In 1991, the remains of Russian Tsar Nicholas II and his family
|
|
# (except for Alexei and Maria) are discovered.
|
|
# The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
|
# remainder of the story. 1883 Western Siberia,
|
|
# a young Grigori Rasputin is asked by his father and a group of men to perform magic.
|
|
# Rasputin has a vision and denounces one of the men as a horse thief. Although his
|
|
# father initially slaps him for making such an accusation, Rasputin watches as the
|
|
# man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
|
|
# the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
|
|
# with people, even a bishop, begging for his blessing. """
|
|
|
|
# fmt: off
|
|
expected_output_ids = [
|
|
67, 2840, 19, 18, 1484, 20, 965, 29077, 8719, 1273, 21, 45, 273, 17, 10, 15048, 28, 27511, 21, 4185, 11, 41, 2444, 9, 32, 1025, 20, 8719, 26, 23, 673, 966, 19, 29077, 20643, 27511, 20822, 20643, 19, 17, 6616, 17511, 18, 8978, 20, 18, 777, 9, 19233, 1527, 17669, 19, 24, 673, 17, 28756, 150, 12943, 4354, 153, 27, 442, 37, 45, 668, 21, 24, 256, 20, 416, 22, 2771, 4901, 9, 12943, 4354, 153, 51, 24, 3004, 21, 28142, 23, 65, 20, 18, 416, 34, 24, 2958, 22947, 9, 1177, 45, 668, 3097, 13768, 23, 103, 28, 441, 148, 48, 20522, 19, 12943, 4354, 153, 12860, 34, 18, 326, 27, 17492, 684, 21, 6709, 9, 8585, 123, 266, 19, 12943, 4354, 153, 6872, 24, 3004, 20, 18, 9225, 2198, 19, 12717, 103, 22, 401, 24, 6348, 9, 12943, 4354, 153, 1068, 2768, 2286, 19, 33, 104, 19, 176, 24, 9313, 19, 20086, 28, 45, 10292, 9, 4, 3, 19, 12943, 4354, 153, 27, 442, 22, 2771, 4901, 9, 69, 27, 442, 22, 2771, 24, 11335, 20, 18, 9225, 2198, 9, 69, 27, 442, 22, 2771, 24, 11335, 20, 18, 9225, 2198, 9, 69, 27, 442, 22, 2771,
|
|
]
|
|
# fmt: on
|
|
# In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria)
|
|
# are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
|
|
# narrates the remainder of the story. 1883 Western Siberia, a young Grigori Rasputin
|
|
# is asked by his father and a group of men to perform magic. Rasputin has a vision and
|
|
# denounces one of the men as a horse thief. Although his father initially slaps
|
|
# him for making such an accusation, Rasputin watches as the man is chased outside and beaten.
|
|
# Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest.
|
|
# Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing.
|
|
# <sep><cls>, Rasputin is asked to perform magic. He is asked to perform a ritual of the Virgin Mary.
|
|
# He is asked to perform a ritual of the Virgin Mary. He is asked to perform
|
|
|
|
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
|
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|