diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index c0781a2947..9f58fabe57 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -218,6 +218,11 @@ class TFGPT2ModelTester: ): model = TFGPT2Model(config=config) + input_ids = input_ids[:1, :] + input_mask = input_mask[:1, :] + token_type_ids = token_type_ids[:1, :] + self.batch_size = 1 + # first forward pass outputs = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, use_cache=True) @@ -225,13 +230,13 @@ class TFGPT2ModelTester: # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) - next_token_types = ids_tensor((self.batch_size, 3), self.type_vocab_size) next_attn_mask = ids_tensor((self.batch_size, 3), 2) + next_token_types = ids_tensor((self.batch_size, 3), self.type_vocab_size) # append to next input_ids and token_type_ids next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) - next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1) next_attention_mask = tf.concat([input_mask, next_attn_mask], axis=-1) + next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1) output_from_no_past = model( next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask