cleanup tf unittests: part 2 (#6260)

* cleanup torch unittests: part 2

* remove trailing comma added by isort, and which breaks flake

* one more comma

* revert odd balls

* part 3: odd cases

* more ["key"] -> .key refactoring

* .numpy() is not needed

* more unncessary .numpy() removed

* more simplification
This commit is contained in:
Stas Bekman 2020-08-13 01:29:06 -07:00 committed by GitHub
parent bc820476a5
commit e983da0e7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 159 additions and 239 deletions

View File

@ -148,10 +148,10 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
result = model(input_ids)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)
)
self.parent.assertListEqual(list(result["pooler_output"].shape), [self.batch_size, self.hidden_size])
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_xxx_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -159,9 +159,7 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFXxxForMaskedLM(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(
list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_xxx_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -170,7 +168,7 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFXxxForSequenceClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_bert_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -186,7 +184,7 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
"token_type_ids": multiple_choice_token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def create_and_check_xxx_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -195,9 +193,7 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFXxxForTokenClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(
list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_xxx_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -205,8 +201,8 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFXxxForQuestionAnswering(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

View File

@ -117,7 +117,7 @@ class CTRLModelTester:
model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)
self.parent.assertEqual(len(result.past_key_values), config.n_layer)
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = CTRLLMHeadModel(config)

View File

@ -152,7 +152,7 @@ class GPT2ModelTester:
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)
self.parent.assertEqual(len(result.past_key_values), config.n_layer)
def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = GPT2Model(config=config)

View File

@ -120,7 +120,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(result["logits"].shape, expected_shape)
self.assertEqual(result.logits.shape, expected_shape)
@require_torch

View File

@ -141,9 +141,9 @@ class T5ModelTester:
decoder_attention_mask=decoder_attention_mask,
)
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
decoder_output = result["last_hidden_state"]
decoder_past = result["decoder_past_key_values"]
encoder_output = result["encoder_last_hidden_state"]
decoder_output = result.last_hidden_state
decoder_past = result.decoder_past_key_values
encoder_output = result.encoder_last_hidden_state
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))

View File

@ -141,10 +141,8 @@ class TFAlbertModelTester:
result = model(input_ids)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooler_output"].shape), [self.batch_size, self.hidden_size])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_albert_for_pretraining(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -153,10 +151,8 @@ class TFAlbertModelTester:
model = TFAlbertForPreTraining(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(
list(result["prediction_logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(list(result["sop_logits"].shape), [self.batch_size, self.num_labels])
self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual(result.sop_logits.shape, (self.batch_size, self.num_labels))
def create_and_check_albert_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -164,7 +160,7 @@ class TFAlbertModelTester:
model = TFAlbertForMaskedLM(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_albert_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -173,7 +169,7 @@ class TFAlbertModelTester:
model = TFAlbertForSequenceClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_albert_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -181,8 +177,8 @@ class TFAlbertModelTester:
model = TFAlbertForQuestionAnswering(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_albert_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

View File

@ -135,10 +135,8 @@ class TFBertModelTester:
result = model(input_ids)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(list(result["pooler_output"].shape), [self.batch_size, self.hidden_size])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_bert_lm_head(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -165,7 +163,7 @@ class TFBertModelTester:
"token_type_ids": token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_bert_for_next_sequence_prediction(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -173,7 +171,7 @@ class TFBertModelTester:
model = TFBertForNextSentencePrediction(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, 2])
self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))
def create_and_check_bert_for_pretraining(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -181,10 +179,8 @@ class TFBertModelTester:
model = TFBertForPreTraining(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(
list(result["prediction_logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(list(result["seq_relationship_logits"].shape), [self.batch_size, 2])
self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))
def create_and_check_bert_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -198,7 +194,7 @@ class TFBertModelTester:
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_bert_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -214,7 +210,7 @@ class TFBertModelTester:
"token_type_ids": multiple_choice_token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def create_and_check_bert_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -227,7 +223,7 @@ class TFBertModelTester:
"token_type_ids": token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_bert_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -240,8 +236,8 @@ class TFBertModelTester:
}
result = model(inputs)
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

View File

@ -119,15 +119,13 @@ class TFCTRLModelTester(object):
result = model(input_ids)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_ctrl_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = TFCTRLLMHeadModel(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

View File

@ -106,9 +106,7 @@ class TFDistilBertModelTester:
result = model(inputs)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_distilbert_for_masked_lm(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -116,7 +114,7 @@ class TFDistilBertModelTester:
model = TFDistilBertForMaskedLM(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_distilbert_for_question_answering(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -127,8 +125,8 @@ class TFDistilBertModelTester:
"attention_mask": input_mask,
}
result = model(inputs)
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_distilbert_for_sequence_classification(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -137,7 +135,7 @@ class TFDistilBertModelTester:
model = TFDistilBertForSequenceClassification(config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_distilbert_for_multiple_choice(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -151,7 +149,7 @@ class TFDistilBertModelTester:
"attention_mask": multiple_choice_input_mask,
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def create_and_check_distilbert_for_token_classification(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -160,7 +158,7 @@ class TFDistilBertModelTester:
model = TFDistilBertForTokenClassification(config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

View File

@ -113,9 +113,7 @@ class TFElectraModelTester:
result = model(input_ids)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_electra_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -123,7 +121,7 @@ class TFElectraModelTester:
model = TFElectraForMaskedLM(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_electra_for_pretraining(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -131,7 +129,7 @@ class TFElectraModelTester:
model = TFElectraForPreTraining(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length))
def create_and_check_electra_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -140,7 +138,7 @@ class TFElectraModelTester:
model = TFElectraForSequenceClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_electra_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -156,7 +154,7 @@ class TFElectraModelTester:
"token_type_ids": multiple_choice_token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def create_and_check_electra_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -164,8 +162,8 @@ class TFElectraModelTester:
model = TFElectraForQuestionAnswering(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_electra_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -174,7 +172,7 @@ class TFElectraModelTester:
model = TFElectraForTokenClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

View File

@ -146,9 +146,7 @@ class TFFlaubertModelTester:
inputs = [input_ids, input_mask]
result = model(inputs)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_flaubert_lm_head(
self,
@ -167,7 +165,7 @@ class TFFlaubertModelTester:
inputs = {"input_ids": input_ids, "lengths": input_lengths, "langs": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_flaubert_qa(
self,
@ -187,8 +185,8 @@ class TFFlaubertModelTester:
result = model(inputs)
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_flaubert_sequence_classif(
self,
@ -208,7 +206,7 @@ class TFFlaubertModelTester:
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.type_sequence_label_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def create_and_check_flaubert_for_token_classification(
self,
@ -226,7 +224,7 @@ class TFFlaubertModelTester:
model = TFFlaubertForTokenClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_flaubert_for_multiple_choice(
self,
@ -251,7 +249,7 @@ class TFFlaubertModelTester:
"token_type_ids": multiple_choice_token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

View File

@ -133,9 +133,7 @@ class TFGPT2ModelTester:
result = model(input_ids)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size],
)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = TFGPT2Model(config=config)
@ -219,9 +217,7 @@ class TFGPT2ModelTester:
"token_type_ids": token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(
list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size],
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_gpt2_double_head(
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
@ -239,10 +235,10 @@ class TFGPT2ModelTester:
"token_type_ids": multiple_choice_token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(
list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size],
self.parent.assertEqual(
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
)
self.parent.assertListEqual(list(result["mc_logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

View File

@ -155,10 +155,10 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
result = model(input_ids)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)
)
self.parent.assertListEqual(list(result["pooler_output"].shape), [self.batch_size, self.hidden_size])
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_mobilebert_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -166,9 +166,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFMobileBertForMaskedLM(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(
list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_mobilebert_for_next_sequence_prediction(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -176,7 +174,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFMobileBertForNextSentencePrediction(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, 2])
self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))
def create_and_check_mobilebert_for_pretraining(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -184,10 +182,10 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFMobileBertForPreTraining(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(
list(result["prediction_logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
self.parent.assertEqual(
result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size)
)
self.parent.assertListEqual(list(result["seq_relationship_logits"].shape), [self.batch_size, 2])
self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))
def create_and_check_mobilebert_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -196,7 +194,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFMobileBertForSequenceClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_mobilebert_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -212,7 +210,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
"token_type_ids": multiple_choice_token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def create_and_check_mobilebert_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -221,9 +219,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFMobileBertForTokenClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(
list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_mobilebert_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -231,8 +227,8 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFMobileBertForQuestionAnswering(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

View File

@ -124,15 +124,13 @@ class TFOpenAIGPTModelTester:
result = model(input_ids)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_openai_gpt_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = TFOpenAIGPTLMHeadModel(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_openai_gpt_double_head(
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
@ -150,10 +148,10 @@ class TFOpenAIGPTModelTester:
"token_type_ids": multiple_choice_token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(
list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size]
self.parent.assertEqual(
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
)
self.parent.assertListEqual(list(result["mc_logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

View File

@ -112,16 +112,14 @@ class TFRobertaModelTester:
result = model(input_ids)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_roberta_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFRobertaForMaskedLM(config=config)
result = model([input_ids, input_mask, token_type_ids])
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_roberta_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -130,7 +128,7 @@ class TFRobertaModelTester:
model = TFRobertaForTokenClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_roberta_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -138,8 +136,8 @@ class TFRobertaModelTester:
model = TFRobertaForQuestionAnswering(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_roberta_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@ -155,7 +153,7 @@ class TFRobertaModelTester:
"token_type_ids": multiple_choice_token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

View File

@ -93,9 +93,9 @@ class TFT5ModelTester:
result = model(inputs)
result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids)
decoder_output = result["last_hidden_state"]
decoder_past = result["decoder_past_key_values"]
encoder_output = result["encoder_last_hidden_state"]
decoder_output = result.last_hidden_state
decoder_past = result.decoder_past_key_values
encoder_output = result.encoder_last_hidden_state
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertEqual(len(decoder_past), 2)
@ -116,7 +116,7 @@ class TFT5ModelTester:
result = model(inputs_dict)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask):
model = TFT5Model(config=config).get_decoder()

View File

@ -97,26 +97,15 @@ class TFTransfoXLModelTester:
hidden_states_2, mems_2 = model(inputs).to_tuple()
result = {
"hidden_states_1": hidden_states_1.numpy(),
"mems_1": [mem.numpy() for mem in mems_1],
"hidden_states_2": hidden_states_2.numpy(),
"mems_2": [mem.numpy() for mem in mems_2],
}
self.parent.assertEqual(hidden_states_1.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(hidden_states_2.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertListEqual(
list(result["hidden_states_1"].shape), [self.batch_size, self.seq_length, self.hidden_size]
[mem.shape for mem in mems_1],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
self.parent.assertListEqual(
list(result["hidden_states_2"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in mems_2],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
@ -133,27 +122,16 @@ class TFTransfoXLModelTester:
_, mems_2 = model(inputs).to_tuple()
result = {
"mems_1": [mem.numpy() for mem in mems_1],
"lm_logits_1": lm_logits_1.numpy(),
"mems_2": [mem.numpy() for mem in mems_2],
"lm_logits_2": lm_logits_2.numpy(),
}
self.parent.assertEqual(lm_logits_1.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual(
list(result["lm_logits_1"].shape), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in mems_1],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
self.parent.assertEqual(lm_logits_2.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual(
list(result["lm_logits_2"].shape), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in mems_2],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def prepare_config_and_inputs_for_common(self):

View File

@ -145,9 +145,7 @@ class TFXLMModelTester:
inputs = [input_ids, input_mask]
result = model(inputs)
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_xlm_lm_head(
self,
@ -168,7 +166,7 @@ class TFXLMModelTester:
result = outputs
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_xlm_qa(
self,
@ -188,8 +186,8 @@ class TFXLMModelTester:
result = model(inputs)
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_xlm_sequence_classif(
self,
@ -209,7 +207,7 @@ class TFXLMModelTester:
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.type_sequence_label_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def create_and_check_xlm_for_token_classification(
self,
@ -227,7 +225,7 @@ class TFXLMModelTester:
model = TFXLMForTokenClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_xlm_for_multiple_choice(
self,
@ -252,7 +250,7 @@ class TFXLMModelTester:
"token_type_ids": multiple_choice_token_type_ids,
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()

View File

@ -158,12 +158,10 @@ class TFXLNetModelTester:
no_mems_outputs = model(inputs)
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertListEqual(
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result.mems],
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_xlnet_lm_head(
@ -191,27 +189,15 @@ class TFXLNetModelTester:
inputs_3 = {"input_ids": input_ids_q, "perm_mask": perm_mask, "target_mapping": target_mapping}
logits, _ = model(inputs_3).to_tuple()
result = {
"mems_1": [mem.numpy() for mem in mems_1],
"all_logits_1": all_logits_1.numpy(),
"mems_2": [mem.numpy() for mem in mems_2],
"all_logits_2": all_logits_2.numpy(),
}
self.parent.assertEqual(all_logits_1.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual(
list(result["all_logits_1"].shape), [self.batch_size, self.seq_length, self.vocab_size]
[mem.shape for mem in mems_1],
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
self.parent.assertEqual(all_logits_2.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
self.parent.assertListEqual(
list(result["all_logits_2"].shape), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in mems_2],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_xlnet_qa(
@ -233,11 +219,11 @@ class TFXLNetModelTester:
inputs = {"input_ids": input_ids_1, "attention_mask": input_mask, "token_type_ids": segment_ids}
result = model(inputs)
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result.mems],
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_xlnet_sequence_classif(
@ -258,10 +244,10 @@ class TFXLNetModelTester:
result = model(input_ids_1)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.type_sequence_label_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result.mems],
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_xlnet_for_token_classification(
@ -286,12 +272,10 @@ class TFXLNetModelTester:
# 'token_type_ids': token_type_ids
}
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, config.num_labels))
self.parent.assertListEqual(
list(result["logits"].shape), [self.batch_size, self.seq_length, config.num_labels]
)
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result.mems],
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_xlnet_for_multiple_choice(
@ -320,10 +304,10 @@ class TFXLNetModelTester:
}
result = model(inputs)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
self.parent.assertListEqual(
list(list(mem.shape) for mem in result["mems"]),
[[self.seq_length, self.batch_size * self.num_choices, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result.mems],
[(self.seq_length, self.batch_size * self.num_choices, self.hidden_size)] * self.num_hidden_layers,
)
def prepare_config_and_inputs_for_common(self):

View File

@ -100,19 +100,15 @@ class TransfoXLModelTester:
return outputs
def check_transfo_xl_model_output(self, result):
self.parent.assertEqual(result["hidden_states_1"].shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result["hidden_states_2"].shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertListEqual(
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size],
[mem.shape for mem in result["mems_1"]],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
self.parent.assertListEqual(
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result["mems_2"]],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
@ -136,22 +132,18 @@ class TransfoXLModelTester:
return outputs
def check_transfo_xl_lm_head_output(self, result):
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length - 1])
self.parent.assertEqual(result["loss_1"].shape, (self.batch_size, self.seq_length - 1))
self.parent.assertEqual(result["lm_logits_1"].shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual(
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result["mems_1"]],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length - 1])
self.parent.assertEqual(result["loss_2"].shape, (self.batch_size, self.seq_length - 1))
self.parent.assertEqual(result["lm_logits_2"].shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual(
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result["mems_2"]],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def prepare_config_and_inputs_for_common(self):

View File

@ -192,8 +192,8 @@ class XLNetModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result.mems],
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_xlnet_model_use_cache(
@ -305,22 +305,22 @@ class XLNetModelTester:
result1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1["mems"])
result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1.mems)
_ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
self.parent.assertEqual(result1.loss.shape, ())
self.parent.assertEqual(result1.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual(
list(list(mem.size()) for mem in result1["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result1.mems],
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
self.parent.assertEqual(result2.loss.shape, ())
self.parent.assertEqual(result2.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual(
list(list(mem.size()) for mem in result2["mems"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result2.mems],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_xlnet_qa(
@ -378,8 +378,8 @@ class XLNetModelTester:
)
self.parent.assertEqual(result.cls_logits.shape, (self.batch_size,))
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result.mems],
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_xlnet_token_classif(
@ -407,8 +407,8 @@ class XLNetModelTester:
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.type_sequence_label_size))
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result.mems],
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def create_and_check_xlnet_sequence_classif(
@ -436,8 +436,8 @@ class XLNetModelTester:
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems"]),
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
[mem.shape for mem in result.mems],
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)
def prepare_config_and_inputs_for_common(self):