This commit is contained in:
Lysandre 2020-02-03 18:28:32 -05:00 committed by Lysandre Debut
parent 950c6a4f09
commit 5f96ebc0be
2 changed files with 77 additions and 77 deletions

View File

@ -52,35 +52,35 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
class FlaubertModelTester(object):
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_lengths=True,
use_token_type_ids=True,
use_labels=True,
gelu_activation=True,
sinusoidal_embeddings=False,
causal=False,
asm=False,
n_langs=2,
vocab_size=99,
n_special=0,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
summary_type="last",
use_proj=True,
scope=None,
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_lengths=True,
use_token_type_ids=True,
use_labels=True,
gelu_activation=True,
sinusoidal_embeddings=False,
causal=False,
asm=False,
n_langs=2,
vocab_size=99,
n_special=0,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
summary_type="last",
use_proj=True,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
@ -119,7 +119,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
input_lengths = None
if self.use_input_lengths:
input_lengths = (
ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2
ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2
) # small variation of seq_length
token_type_ids = None
@ -168,15 +168,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_flaubert_model(
self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
):
model = FlaubertModel(config=config)
model.to(torch_device)
@ -193,15 +193,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
)
def create_and_check_flaubert_lm_head(
self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
):
model = FlaubertWithLMHeadModel(config)
model.to(torch_device)
@ -220,15 +220,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
)
def create_and_check_flaubert_simple_qa(
self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
):
model = FlaubertForQuestionAnsweringSimple(config)
model.to(torch_device)
@ -249,15 +249,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
self.check_loss_output(result)
def create_and_check_flaubert_qa(
self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
):
model = FlaubertForQuestionAnswering(config)
model.to(torch_device)
@ -316,15 +316,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
def create_and_check_flaubert_sequence_classif(
self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
self,
config,
input_ids,
token_type_ids,
input_lengths,
sequence_labels,
token_labels,
is_impossible_labels,
input_mask,
):
model = FlaubertForSequenceClassification(config)
model.to(torch_device)

View File

@ -185,7 +185,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
self.check_loss_output(result)
def create_and_check_roberta_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = RobertaForMultipleChoice(config=config)
@ -208,7 +208,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
self.check_loss_output(result)
def create_and_check_roberta_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = RobertaForQuestionAnswering(config=config)
model.to(torch_device)