cleanup tf unittests: part 2 (#6260)
* cleanup torch unittests: part 2 * remove trailing comma added by isort, and which breaks flake * one more comma * revert odd balls * part 3: odd cases * more ["key"] -> .key refactoring * .numpy() is not needed * more unncessary .numpy() removed * more simplification
This commit is contained in:
parent
bc820476a5
commit
e983da0e7d
|
@ -148,10 +148,10 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooler_output"].shape), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_xxx_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -159,9 +159,7 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
model = TFXxxForMaskedLM(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_xxx_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -170,7 +168,7 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
model = TFXxxForSequenceClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_bert_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -186,7 +184,7 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def create_and_check_xxx_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -195,9 +193,7 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
model = TFXxxForTokenClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_xxx_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -205,8 +201,8 @@ class TFXxxModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
model = TFXxxForQuestionAnswering(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -117,7 +117,7 @@ class CTRLModelTester:
|
|||
model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)
|
||||
self.parent.assertEqual(len(result.past_key_values), config.n_layer)
|
||||
|
||||
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = CTRLLMHeadModel(config)
|
||||
|
|
|
@ -152,7 +152,7 @@ class GPT2ModelTester:
|
|||
result = model(input_ids)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(len(result["past_key_values"]), config.n_layer)
|
||||
self.parent.assertEqual(len(result.past_key_values), config.n_layer)
|
||||
|
||||
def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = GPT2Model(config=config)
|
||||
|
|
|
@ -120,7 +120,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
|||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
|
||||
result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(result["logits"].shape, expected_shape)
|
||||
self.assertEqual(result.logits.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
|
|
|
@ -141,9 +141,9 @@ class T5ModelTester:
|
|||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
decoder_output = result["last_hidden_state"]
|
||||
decoder_past = result["decoder_past_key_values"]
|
||||
encoder_output = result["encoder_last_hidden_state"]
|
||||
decoder_output = result.last_hidden_state
|
||||
decoder_past = result.decoder_past_key_values
|
||||
encoder_output = result.encoder_last_hidden_state
|
||||
|
||||
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
|
||||
|
|
|
@ -141,10 +141,8 @@ class TFAlbertModelTester:
|
|||
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooler_output"].shape), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_albert_for_pretraining(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -153,10 +151,8 @@ class TFAlbertModelTester:
|
|||
model = TFAlbertForPreTraining(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["sop_logits"].shape), [self.batch_size, self.num_labels])
|
||||
self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertEqual(result.sop_logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_albert_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -164,7 +160,7 @@ class TFAlbertModelTester:
|
|||
model = TFAlbertForMaskedLM(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_albert_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -173,7 +169,7 @@ class TFAlbertModelTester:
|
|||
model = TFAlbertForSequenceClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_albert_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -181,8 +177,8 @@ class TFAlbertModelTester:
|
|||
model = TFAlbertForQuestionAnswering(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_albert_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
|
|
@ -135,10 +135,8 @@ class TFBertModelTester:
|
|||
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooler_output"].shape), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_bert_lm_head(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -165,7 +163,7 @@ class TFBertModelTester:
|
|||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_bert_for_next_sequence_prediction(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -173,7 +171,7 @@ class TFBertModelTester:
|
|||
model = TFBertForNextSentencePrediction(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, 2])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))
|
||||
|
||||
def create_and_check_bert_for_pretraining(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -181,10 +179,8 @@ class TFBertModelTester:
|
|||
model = TFBertForPreTraining(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["seq_relationship_logits"].shape), [self.batch_size, 2])
|
||||
self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))
|
||||
|
||||
def create_and_check_bert_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -198,7 +194,7 @@ class TFBertModelTester:
|
|||
}
|
||||
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_bert_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -214,7 +210,7 @@ class TFBertModelTester:
|
|||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def create_and_check_bert_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -227,7 +223,7 @@ class TFBertModelTester:
|
|||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_bert_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -240,8 +236,8 @@ class TFBertModelTester:
|
|||
}
|
||||
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -119,15 +119,13 @@ class TFCTRLModelTester(object):
|
|||
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_ctrl_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = TFCTRLLMHeadModel(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -106,9 +106,7 @@ class TFDistilBertModelTester:
|
|||
|
||||
result = model(inputs)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_distilbert_for_masked_lm(
|
||||
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -116,7 +114,7 @@ class TFDistilBertModelTester:
|
|||
model = TFDistilBertForMaskedLM(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_distilbert_for_question_answering(
|
||||
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -127,8 +125,8 @@ class TFDistilBertModelTester:
|
|||
"attention_mask": input_mask,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_distilbert_for_sequence_classification(
|
||||
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -137,7 +135,7 @@ class TFDistilBertModelTester:
|
|||
model = TFDistilBertForSequenceClassification(config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_distilbert_for_multiple_choice(
|
||||
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -151,7 +149,7 @@ class TFDistilBertModelTester:
|
|||
"attention_mask": multiple_choice_input_mask,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def create_and_check_distilbert_for_token_classification(
|
||||
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -160,7 +158,7 @@ class TFDistilBertModelTester:
|
|||
model = TFDistilBertForTokenClassification(config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -113,9 +113,7 @@ class TFElectraModelTester:
|
|||
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_electra_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -123,7 +121,7 @@ class TFElectraModelTester:
|
|||
model = TFElectraForMaskedLM(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_electra_for_pretraining(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -131,7 +129,7 @@ class TFElectraModelTester:
|
|||
model = TFElectraForPreTraining(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_electra_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -140,7 +138,7 @@ class TFElectraModelTester:
|
|||
model = TFElectraForSequenceClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_electra_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -156,7 +154,7 @@ class TFElectraModelTester:
|
|||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def create_and_check_electra_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -164,8 +162,8 @@ class TFElectraModelTester:
|
|||
model = TFElectraForQuestionAnswering(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_electra_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -174,7 +172,7 @@ class TFElectraModelTester:
|
|||
model = TFElectraForTokenClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -146,9 +146,7 @@ class TFFlaubertModelTester:
|
|||
|
||||
inputs = [input_ids, input_mask]
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_flaubert_lm_head(
|
||||
self,
|
||||
|
@ -167,7 +165,7 @@ class TFFlaubertModelTester:
|
|||
inputs = {"input_ids": input_ids, "lengths": input_lengths, "langs": token_type_ids}
|
||||
result = model(inputs)
|
||||
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_flaubert_qa(
|
||||
self,
|
||||
|
@ -187,8 +185,8 @@ class TFFlaubertModelTester:
|
|||
|
||||
result = model(inputs)
|
||||
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_flaubert_sequence_classif(
|
||||
self,
|
||||
|
@ -208,7 +206,7 @@ class TFFlaubertModelTester:
|
|||
|
||||
result = model(inputs)
|
||||
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.type_sequence_label_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
def create_and_check_flaubert_for_token_classification(
|
||||
self,
|
||||
|
@ -226,7 +224,7 @@ class TFFlaubertModelTester:
|
|||
model = TFFlaubertForTokenClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_flaubert_for_multiple_choice(
|
||||
self,
|
||||
|
@ -251,7 +249,7 @@ class TFFlaubertModelTester:
|
|||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -133,9 +133,7 @@ class TFGPT2ModelTester:
|
|||
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = TFGPT2Model(config=config)
|
||||
|
@ -219,9 +217,7 @@ class TFGPT2ModelTester:
|
|||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_gpt2_double_head(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
|
||||
|
@ -239,10 +235,10 @@ class TFGPT2ModelTester:
|
|||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size],
|
||||
self.parent.assertEqual(
|
||||
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
|
||||
)
|
||||
self.parent.assertListEqual(list(result["mc_logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -155,10 +155,10 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)
|
||||
)
|
||||
self.parent.assertListEqual(list(result["pooler_output"].shape), [self.batch_size, self.hidden_size])
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_mobilebert_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -166,9 +166,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
model = TFMobileBertForMaskedLM(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_mobilebert_for_next_sequence_prediction(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -176,7 +174,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
model = TFMobileBertForNextSentencePrediction(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, 2])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))
|
||||
|
||||
def create_and_check_mobilebert_for_pretraining(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -184,10 +182,10 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
model = TFMobileBertForPreTraining(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_logits"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
self.parent.assertEqual(
|
||||
result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size)
|
||||
)
|
||||
self.parent.assertListEqual(list(result["seq_relationship_logits"].shape), [self.batch_size, 2])
|
||||
self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))
|
||||
|
||||
def create_and_check_mobilebert_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -196,7 +194,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
model = TFMobileBertForSequenceClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_mobilebert_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -212,7 +210,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def create_and_check_mobilebert_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -221,9 +219,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
model = TFMobileBertForTokenClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
|
||||
)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_mobilebert_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -231,8 +227,8 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
|||
model = TFMobileBertForQuestionAnswering(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -124,15 +124,13 @@ class TFOpenAIGPTModelTester:
|
|||
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_openai_gpt_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = TFOpenAIGPTLMHeadModel(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_openai_gpt_double_head(
|
||||
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args
|
||||
|
@ -150,10 +148,10 @@ class TFOpenAIGPTModelTester:
|
|||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size]
|
||||
self.parent.assertEqual(
|
||||
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
|
||||
)
|
||||
self.parent.assertListEqual(list(result["mc_logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -112,16 +112,14 @@ class TFRobertaModelTester:
|
|||
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_roberta_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = TFRobertaForMaskedLM(config=config)
|
||||
result = model([input_ids, input_mask, token_type_ids])
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_roberta_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -130,7 +128,7 @@ class TFRobertaModelTester:
|
|||
model = TFRobertaForTokenClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_roberta_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -138,8 +136,8 @@ class TFRobertaModelTester:
|
|||
model = TFRobertaForQuestionAnswering(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_roberta_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
@ -155,7 +153,7 @@ class TFRobertaModelTester:
|
|||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -93,9 +93,9 @@ class TFT5ModelTester:
|
|||
result = model(inputs)
|
||||
|
||||
result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids)
|
||||
decoder_output = result["last_hidden_state"]
|
||||
decoder_past = result["decoder_past_key_values"]
|
||||
encoder_output = result["encoder_last_hidden_state"]
|
||||
decoder_output = result.last_hidden_state
|
||||
decoder_past = result.decoder_past_key_values
|
||||
encoder_output = result.encoder_last_hidden_state
|
||||
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertEqual(len(decoder_past), 2)
|
||||
|
@ -116,7 +116,7 @@ class TFT5ModelTester:
|
|||
|
||||
result = model(inputs_dict)
|
||||
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask):
|
||||
model = TFT5Model(config=config).get_decoder()
|
||||
|
|
|
@ -97,26 +97,15 @@ class TFTransfoXLModelTester:
|
|||
|
||||
hidden_states_2, mems_2 = model(inputs).to_tuple()
|
||||
|
||||
result = {
|
||||
"hidden_states_1": hidden_states_1.numpy(),
|
||||
"mems_1": [mem.numpy() for mem in mems_1],
|
||||
"hidden_states_2": hidden_states_2.numpy(),
|
||||
"mems_2": [mem.numpy() for mem in mems_2],
|
||||
}
|
||||
|
||||
self.parent.assertEqual(hidden_states_1.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(hidden_states_2.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_1"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
[mem.shape for mem in mems_1],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_2"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems_1"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in mems_2],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||
|
@ -133,27 +122,16 @@ class TFTransfoXLModelTester:
|
|||
|
||||
_, mems_2 = model(inputs).to_tuple()
|
||||
|
||||
result = {
|
||||
"mems_1": [mem.numpy() for mem in mems_1],
|
||||
"lm_logits_1": lm_logits_1.numpy(),
|
||||
"mems_2": [mem.numpy() for mem in mems_2],
|
||||
"lm_logits_2": lm_logits_2.numpy(),
|
||||
}
|
||||
|
||||
self.parent.assertEqual(lm_logits_1.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_1"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems_1"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in mems_1],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
self.parent.assertEqual(lm_logits_2.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_2"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in mems_2],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
|
|
@ -145,9 +145,7 @@ class TFXLMModelTester:
|
|||
|
||||
inputs = [input_ids, input_mask]
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_xlm_lm_head(
|
||||
self,
|
||||
|
@ -168,7 +166,7 @@ class TFXLMModelTester:
|
|||
|
||||
result = outputs
|
||||
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_xlm_qa(
|
||||
self,
|
||||
|
@ -188,8 +186,8 @@ class TFXLMModelTester:
|
|||
|
||||
result = model(inputs)
|
||||
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_xlm_sequence_classif(
|
||||
self,
|
||||
|
@ -209,7 +207,7 @@ class TFXLMModelTester:
|
|||
|
||||
result = model(inputs)
|
||||
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.type_sequence_label_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
def create_and_check_xlm_for_token_classification(
|
||||
self,
|
||||
|
@ -227,7 +225,7 @@ class TFXLMModelTester:
|
|||
model = TFXLMForTokenClassification(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_xlm_for_multiple_choice(
|
||||
self,
|
||||
|
@ -252,7 +250,7 @@ class TFXLMModelTester:
|
|||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
|
|
@ -158,12 +158,10 @@ class TFXLNetModelTester:
|
|||
no_mems_outputs = model(inputs)
|
||||
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertListEqual(
|
||||
list(result["last_hidden_state"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_lm_head(
|
||||
|
@ -191,27 +189,15 @@ class TFXLNetModelTester:
|
|||
inputs_3 = {"input_ids": input_ids_q, "perm_mask": perm_mask, "target_mapping": target_mapping}
|
||||
logits, _ = model(inputs_3).to_tuple()
|
||||
|
||||
result = {
|
||||
"mems_1": [mem.numpy() for mem in mems_1],
|
||||
"all_logits_1": all_logits_1.numpy(),
|
||||
"mems_2": [mem.numpy() for mem in mems_2],
|
||||
"all_logits_2": all_logits_2.numpy(),
|
||||
}
|
||||
|
||||
self.parent.assertEqual(all_logits_1.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(result["all_logits_1"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
[mem.shape for mem in mems_1],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
self.parent.assertEqual(all_logits_2.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["all_logits_2"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in mems_2],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_qa(
|
||||
|
@ -233,11 +219,11 @@ class TFXLNetModelTester:
|
|||
inputs = {"input_ids": input_ids_1, "attention_mask": input_mask, "token_type_ids": segment_ids}
|
||||
result = model(inputs)
|
||||
|
||||
self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_sequence_classif(
|
||||
|
@ -258,10 +244,10 @@ class TFXLNetModelTester:
|
|||
|
||||
result = model(input_ids_1)
|
||||
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.type_sequence_label_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_for_token_classification(
|
||||
|
@ -286,12 +272,10 @@ class TFXLNetModelTester:
|
|||
# 'token_type_ids': token_type_ids
|
||||
}
|
||||
result = model(inputs)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, config.num_labels))
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].shape), [self.batch_size, self.seq_length, config.num_labels]
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_for_multiple_choice(
|
||||
|
@ -320,10 +304,10 @@ class TFXLNetModelTester:
|
|||
}
|
||||
result = model(inputs)
|
||||
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size * self.num_choices, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size * self.num_choices, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
|
|
@ -100,19 +100,15 @@ class TransfoXLModelTester:
|
|||
return outputs
|
||||
|
||||
def check_transfo_xl_model_output(self, result):
|
||||
self.parent.assertEqual(result["hidden_states_1"].shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result["hidden_states_2"].shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
[mem.shape for mem in result["mems_1"]],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result["mems_2"]],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
|
||||
|
@ -136,22 +132,18 @@ class TransfoXLModelTester:
|
|||
return outputs
|
||||
|
||||
def check_transfo_xl_lm_head_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length - 1])
|
||||
self.parent.assertEqual(result["loss_1"].shape, (self.batch_size, self.seq_length - 1))
|
||||
self.parent.assertEqual(result["lm_logits_1"].shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result["mems_1"]],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length - 1])
|
||||
self.parent.assertEqual(result["loss_2"].shape, (self.batch_size, self.seq_length - 1))
|
||||
self.parent.assertEqual(result["lm_logits_2"].shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result["mems_2"]],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
|
|
@ -192,8 +192,8 @@ class XLNetModelTester:
|
|||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_model_use_cache(
|
||||
|
@ -305,22 +305,22 @@ class XLNetModelTester:
|
|||
|
||||
result1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)
|
||||
|
||||
result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1["mems"])
|
||||
result2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=result1.mems)
|
||||
|
||||
_ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)
|
||||
|
||||
self.parent.assertEqual(result1.loss.shape, ())
|
||||
self.parent.assertEqual(result1.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result1["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result1.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
self.parent.assertEqual(result2.loss.shape, ())
|
||||
self.parent.assertEqual(result2.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result2["mems"]),
|
||||
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result2.mems],
|
||||
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_qa(
|
||||
|
@ -378,8 +378,8 @@ class XLNetModelTester:
|
|||
)
|
||||
self.parent.assertEqual(result.cls_logits.shape, (self.batch_size,))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_token_classif(
|
||||
|
@ -407,8 +407,8 @@ class XLNetModelTester:
|
|||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.type_sequence_label_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def create_and_check_xlnet_sequence_classif(
|
||||
|
@ -436,8 +436,8 @@ class XLNetModelTester:
|
|||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems"]),
|
||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
|
||||
[mem.shape for mem in result.mems],
|
||||
[(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
|
Loading…
Reference in New Issue