From 83d803ba02fdd9a1a7f8ed31a5e2910e2835dad0 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Wed, 17 Feb 2021 17:48:56 +0100 Subject: [PATCH] Making TF BART-like models XLA and AMP compliant (#10191) * Update BART * Update Blenderbot * Update BlenderbotSmall * Update Marian * Update MBart * Update MBart * Update Pegasus * Update template * Fix Marian and Pegasus * Apply style * Default initializer * Default initializer * Default initializer * Remove int32 casts * Fix template * Remove more cast --- .../models/bart/modeling_tf_bart.py | 110 ++++++++------ .../blenderbot/modeling_tf_blenderbot.py | 110 ++++++++------ .../modeling_tf_blenderbot_small.py | 110 ++++++++------ .../models/marian/modeling_tf_marian.py | 140 +++++++++++------- .../models/mbart/modeling_tf_mbart.py | 107 +++++++------ .../models/pegasus/modeling_tf_pegasus.py | 140 +++++++++++------- ...tf_{{cookiecutter.lowercase_modelname}}.py | 86 ++++++----- tests/test_modeling_tf_bart.py | 8 - tests/test_modeling_tf_blenderbot.py | 8 - tests/test_modeling_tf_blenderbot_small.py | 8 - tests/test_modeling_tf_marian.py | 8 - tests/test_modeling_tf_mbart.py | 16 +- tests/test_modeling_tf_pegasus.py | 8 - 13 files changed, 492 insertions(+), 367 deletions(-) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index 6d00beafc0..cba5767bca 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -60,8 +60,7 @@ LARGE_NEGATIVE = -1e8 def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.cast(input_ids, tf.int32) - shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) + shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) # replace possible -100 values in labels by `pad_token_id` @@ -69,12 +68,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids ) - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32)) + if tf.executing_eagerly(): + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0)) - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) return shifted_input_ids @@ -84,14 +84,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - mask = tf.cast(mask, tf.float32) if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) @@ -102,9 +101,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values """ src_len = shape_list(mask)[1] tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32) + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - return (1.0 - expanded_mask) * LARGE_NEGATIVE + return (one_cst - expanded_mask) * LARGE_NEGATIVE class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings): @@ -123,9 +124,7 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings): """Input is expected to be of size [bsz x seqlen].""" bsz, seq_len = input_shape[:2] - positions = tf.range( - past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range" - ) + positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") return super().call(positions + self.offset) @@ -215,43 +214,57 @@ class TFBartAttention(tf.keras.layers.Layer): src_len = shape_list(key_states)[1] attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", + ) if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. 
+ if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.nn.softmax(attn_weights, axis=-1) if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights, (bsz, self.num_heads, tgt_len, src_len) ) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", + ) attn_output = tf.transpose( tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) @@ -292,11 +305,16 @@ class TFBartEncoderLayer(tf.keras.layers.Layer): hidden_states, self_attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask ) - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) + + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) + hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -717,12 +735,15 @@ class TFBartEncoder(tf.keras.layers.Layer): all_attentions = () if inputs["output_attentions"] else None # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None: + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. 
+ if inputs["head_mask"] is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(inputs["head_mask"])[0], len(self.layers), message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", ) + # encoder layers for idx, encoder_layer in enumerate(self.layers): @@ -933,12 +954,15 @@ class TFBartDecoder(tf.keras.layers.Layer): present_key_values = () if inputs["use_cache"] else None # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None: + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if inputs["head_mask"] is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(inputs["head_mask"])[0], len(self.layers), message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", ) + for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index c78392a4f3..7e517e44de 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -63,8 +63,7 @@ LARGE_NEGATIVE = -1e8 # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.cast(input_ids, tf.int32) - shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) + shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) # replace possible -100 values in labels by `pad_token_id` @@ -72,12 +71,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids ) - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32)) + if tf.executing_eagerly(): + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0)) - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) return shifted_input_ids @@ -88,14 +88,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - mask = tf.cast(mask, tf.float32) if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) @@ -107,9 +106,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values """ src_len = shape_list(mask)[1] tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32) + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - return (1.0 - expanded_mask) * LARGE_NEGATIVE + return (one_cst - expanded_mask) * LARGE_NEGATIVE class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings): @@ -125,9 +126,7 @@ class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings): """Input is expected to be of size [bsz x seqlen].""" bsz, seq_len = input_shape[:2] - positions = tf.range( - past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range" - ) + positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") return super().call(positions) @@ -218,43 +217,57 @@ class TFBlenderbotAttention(tf.keras.layers.Layer): src_len = shape_list(key_states)[1] attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", + ) if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. 
+ if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.nn.softmax(attn_weights, axis=-1) if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights, (bsz, self.num_heads, tgt_len, src_len) ) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", + ) attn_output = tf.transpose( tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) @@ -297,11 +310,16 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer): hidden_states, self_attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask ) - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) + + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) + hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states @@ -719,12 +737,15 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer): all_attentions = () if inputs["output_attentions"] else None # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None: + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. 
+ if inputs["head_mask"] is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(inputs["head_mask"])[0], len(self.layers), message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", ) + # encoder layers for idx, encoder_layer in enumerate(self.layers): @@ -943,12 +964,15 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer): present_key_values = () # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None: + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if inputs["head_mask"] is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(inputs["head_mask"])[0], len(self.layers), message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", ) + for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index 8500d177fa..6d92645dd5 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -61,8 +61,7 @@ LARGE_NEGATIVE = -1e8 # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.cast(input_ids, tf.int32) - shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) + shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) # replace possible -100 values in labels by `pad_token_id` @@ -70,12 +69,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids ) - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32)) + if tf.executing_eagerly(): + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0)) - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) return shifted_input_ids @@ -86,14 +86,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - mask = tf.cast(mask, tf.float32) if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) @@ -105,9 +104,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values """ src_len = shape_list(mask)[1] tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32) + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - return (1.0 - expanded_mask) * LARGE_NEGATIVE + return (one_cst - expanded_mask) * LARGE_NEGATIVE # Copied from transformers.models.blenderbot.modeling_tf_blenderbot.TFBlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall @@ -124,9 +125,7 @@ class TFBlenderbotSmallLearnedPositionalEmbedding(TFSharedEmbeddings): """Input is expected to be of size [bsz x seqlen].""" bsz, seq_len = input_shape[:2] - positions = tf.range( - past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range" - ) + positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") return super().call(positions) @@ -217,43 +216,57 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer): src_len = shape_list(key_states)[1] attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", + ) if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. 
+ if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.nn.softmax(attn_weights, axis=-1) if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights, (bsz, self.num_heads, tgt_len, src_len) ) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", + ) attn_output = tf.transpose( tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) @@ -295,11 +308,16 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer): hidden_states, self_attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask ) - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) + + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) + hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -725,12 +743,15 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer): all_attentions = () if inputs["output_attentions"] else None # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None: + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. 
+ if inputs["head_mask"] is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(inputs["head_mask"])[0], len(self.layers), message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", ) + # encoder layers for idx, encoder_layer in enumerate(self.layers): @@ -946,12 +967,15 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer): present_key_values = () # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None: + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if inputs["head_mask"] is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(inputs["head_mask"])[0], len(self.layers), message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", ) + for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index 27004bf955..c45189cb1c 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -62,8 +62,7 @@ LARGE_NEGATIVE = -1e8 # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.cast(input_ids, tf.int32) - shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) + shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) # replace possible -100 values in labels by `pad_token_id` @@ -71,12 +70,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids ) - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32)) + if tf.executing_eagerly(): + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0)) - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) return shifted_input_ids @@ -87,14 +87,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i Make causal mask used for bi-directional self-attention. 
""" bsz, tgt_len = input_ids_shape - mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - mask = tf.cast(mask, tf.float32) if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) @@ -106,32 +105,42 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values """ src_len = shape_list(mask)[1] tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32) + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - return (1.0 - expanded_mask) * LARGE_NEGATIVE + return (one_cst - expanded_mask) * LARGE_NEGATIVE -class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Embedding): +class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Layer): """This module produces sinusoidal positional embeddings of any length.""" def __init__(self, num_positions: int, embedding_dim: int, **kwargs): + super().__init__(**kwargs) if embedding_dim % 2 != 0: raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") - super().__init__( - num_positions, - embedding_dim, - **kwargs, - ) + + self.embedding_dim = embedding_dim + self.num_positions = num_positions def build(self, input_shape: tf.TensorShape): """ Build shared token embedding layer Shared weights logic adapted from https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 """ - super().build(input_shape) # Instantiates self.weight so it can be loaded - weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim) - self.set_weights([weight]) # overwrite self.weight to correct value + + weight = self._init_weight(self.num_positions, self.embedding_dim) + + self.weight = self.add_weight( + name="embeddings", + shape=[self.num_positions, self.embedding_dim], + ) + weight = tf.cast(weight, dtype=self.weight.dtype) + + self.weight.assign(weight) + + super().build(input_shape) @staticmethod def _init_weight(n_pos: int, dim: int): @@ -146,7 +155,7 @@ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Embedding): position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) # convert to tensor - table = tf.convert_to_tensor(position_enc, dtype=tf.float32) + table = tf.convert_to_tensor(position_enc) tf.stop_gradient(table) return table @@ -154,10 +163,8 @@ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Embedding): """Input is expected to be of size [bsz x seqlen].""" bsz, seq_len = input_shape[:2] - positions = tf.range( - past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range" - ) - return super().call(positions) + positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") + return tf.gather(self.weight, positions) # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Marian @@ -247,43 +254,57 @@ class TFMarianAttention(tf.keras.layers.Layer): src_len = 
shape_list(key_states)[1] attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", + ) if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.nn.softmax(attn_weights, axis=-1) if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights, (bsz, self.num_heads, tgt_len, src_len) ) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. 
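# ---------------------------------------------------------------------------
# Usage sketch, not part of the diff, for the TFMarianSinusoidalPositionalEmbedding
# rewrite above: subclassing tf.keras.layers.Layer and filling an add_weight
# variable via assign() replaces the Embedding subclass and its build-time
# set_weights() call -- a pattern this patch moves away from for XLA/AMP
# compliance -- and call() becomes a plain tf.gather. Shapes are illustrative,
# and call()/build() are invoked directly to keep the sketch minimal.
layer = TFMarianSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=64)
layer.build(None)                            # creates `weight`, assigns the sinusoid table
positions = layer.call(input_shape=(2, 10))  # rows 0..9 of the table, shape (10, 64)
print(positions.shape)
# The hunk resumes below.
# ---------------------------------------------------------------------------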
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [bsz * self.num_heads, tgt_len, self.head_dim],
+                message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+            )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -325,11 +346,16 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
         hidden_states, self_attn_weights, _ = self.self_attn(
             hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
         )
-        tf.debugging.assert_equal(
-            shape_list(hidden_states),
-            shape_list(residual),
-            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
-        )
+
+        # The tf.debugging asserts are not compliant with XLA then they
+        # have to be disabled in other modes than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(hidden_states),
+                shape_list(residual),
+                message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+            )
+
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -741,12 +767,15 @@ class TFMarianEncoder(tf.keras.layers.Layer):
         all_attentions = () if inputs["output_attentions"] else None
 
         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA then they
+        # have to be disabled in other modes than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(inputs["head_mask"])[0],
                 len(self.layers),
                 message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
             )
+
         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):
 
@@ -960,12 +989,15 @@ class TFMarianDecoder(tf.keras.layers.Layer):
         present_key_values = ()
 
         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA then they
+        # have to be disabled in other modes than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(inputs["head_mask"])[0],
                 len(self.layers),
                 message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
             )
+
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if inputs["output_hidden_states"]:
diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py
index 6c7726b596..c031f073c8 100644
--- a/src/transformers/models/mbart/modeling_tf_mbart.py
+++ b/src/transformers/models/mbart/modeling_tf_mbart.py
@@ -64,19 +64,16 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int):
     Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
     have a single `decoder_start_token_id` in contrast to other Bart-like models.
     """
-    prev_output_tokens = tf.cast(input_ids, tf.int32)
     assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
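# ---------------------------------------------------------------------------
# Worked example, not part of the diff; values illustrative. MBart target
# sequences end with "</s> <lang_code>", so instead of a fixed
# decoder_start_token_id the shift wraps each row's language code to the
# front. With pad_token_id=1, "</s>"=2 and a made-up language id 250004:
#   input_ids         = [[64, 23, 2, 250004]]
#   language_id_index = (count of non-pad tokens) - 1 = 3 for this row
#   languages_ids     = [250004]               # gather_nd of the last real token
#   shifted_input_ids = [[250004, 64, 23, 2]]  # language code first, last column dropped
# The hunk resumes below.
# ---------------------------------------------------------------------------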
# replace possible -100 values in labels by `pad_token_id` - prev_output_tokens = tf.where( - prev_output_tokens == -100, tf.fill(shape_list(prev_output_tokens), pad_token_id), prev_output_tokens - ) + input_ids = tf.where(input_ids == -100, tf.fill(shape_list(input_ids), pad_token_id), input_ids) language_id_index = ( - tf.reduce_sum(tf.cast(tf.math.not_equal(prev_output_tokens, pad_token_id), tf.int32), axis=-1) - 1 + tf.reduce_sum(tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=input_ids.dtype), axis=-1) - 1 ) language_id_index = tf.stack([tf.range(shape_list(input_ids)[0]), language_id_index], axis=-1) - languages_ids = tf.gather_nd(prev_output_tokens, language_id_index) + languages_ids = tf.gather_nd(input_ids, language_id_index) - shifted_input_ids = tf.concat([tf.expand_dims(languages_ids, axis=-1), prev_output_tokens[:, :-1]], axis=-1) + shifted_input_ids = tf.concat([tf.expand_dims(languages_ids, axis=-1), input_ids[:, :-1]], axis=-1) return shifted_input_ids @@ -87,14 +84,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - mask = tf.cast(mask, tf.float32) if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) @@ -106,9 +102,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values """ src_len = shape_list(mask)[1] tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32) + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - return (1.0 - expanded_mask) * LARGE_NEGATIVE + return (one_cst - expanded_mask) * LARGE_NEGATIVE # Copied from transformers.models.bart.modeling_tf_bart.TFBartLearnedPositionalEmbedding with Bart->MBart @@ -128,9 +126,7 @@ class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings): """Input is expected to be of size [bsz x seqlen].""" bsz, seq_len = input_shape[:2] - positions = tf.range( - past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range" - ) + positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") return super().call(positions + self.offset) @@ -221,43 +217,57 @@ class TFMBartAttention(tf.keras.layers.Layer): src_len = shape_list(key_states)[1] attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. 
+ if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", + ) if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.nn.softmax(attn_weights, axis=-1) if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights, (bsz, self.num_heads, tgt_len, src_len) ) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", + ) attn_output = tf.transpose( tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) @@ -299,11 +309,16 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer): hidden_states, self_attn_weights, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask ) - tf.debugging.assert_equal( - shape_list(hidden_states), - shape_list(residual), - message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", - ) + + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. 
+ if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(hidden_states), + shape_list(residual), + message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}", + ) + hidden_states = self.dropout(hidden_states, training=training) hidden_states = residual + hidden_states @@ -731,12 +746,15 @@ class TFMBartEncoder(tf.keras.layers.Layer): all_attentions = () if inputs["output_attentions"] else None # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None: + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if inputs["head_mask"] is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(inputs["head_mask"])[0], len(self.layers), message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", ) + # encoder layers for idx, encoder_layer in enumerate(self.layers): @@ -956,12 +974,15 @@ class TFMBartDecoder(tf.keras.layers.Layer): present_key_values = () # check if head_mask has a correct number of layers specified if desired - if inputs["head_mask"] is not None: + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if inputs["head_mask"] is not None and tf.executing_eagerly(): tf.debugging.assert_equal( shape_list(inputs["head_mask"])[0], len(self.layers), message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.", ) + for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if inputs["output_hidden_states"]: diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index a94addbb21..536914688a 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -62,8 +62,7 @@ LARGE_NEGATIVE = -1e8 # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): - shifted_input_ids = tf.cast(input_ids, tf.int32) - shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) + shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) # replace possible -100 values in labels by `pad_token_id` @@ -71,12 +70,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids ) - # "Verify that `labels` has only positive values and -100" - assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32)) + if tf.executing_eagerly(): + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0)) - # Make sure the assertion op is called by wrapping the result in an identity no-op - with tf.control_dependencies([assert_gte0]): - shifted_input_ids = tf.identity(shifted_input_ids) + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = 
tf.identity(shifted_input_ids) return shifted_input_ids @@ -87,14 +87,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - mask = tf.cast(mask, tf.float32) if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) @@ -106,33 +105,43 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values """ src_len = shape_list(mask)[1] tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32) + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - return (1.0 - expanded_mask) * LARGE_NEGATIVE + return (one_cst - expanded_mask) * LARGE_NEGATIVE # Copied from transformers.models.marian.modeling_tf_marian.TFMarianSinusoidalPositionalEmbedding with Marian->Pegasus -class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Embedding): +class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Layer): """This module produces sinusoidal positional embeddings of any length.""" def __init__(self, num_positions: int, embedding_dim: int, **kwargs): + super().__init__(**kwargs) if embedding_dim % 2 != 0: raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported") - super().__init__( - num_positions, - embedding_dim, - **kwargs, - ) + + self.embedding_dim = embedding_dim + self.num_positions = num_positions def build(self, input_shape: tf.TensorShape): """ Build shared token embedding layer Shared weights logic adapted from https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24 """ - super().build(input_shape) # Instantiates self.weight so it can be loaded - weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim) - self.set_weights([weight]) # overwrite self.weight to correct value + + weight = self._init_weight(self.num_positions, self.embedding_dim) + + self.weight = self.add_weight( + name="embeddings", + shape=[self.num_positions, self.embedding_dim], + ) + weight = tf.cast(weight, dtype=self.weight.dtype) + + self.weight.assign(weight) + + super().build(input_shape) @staticmethod def _init_weight(n_pos: int, dim: int): @@ -147,7 +156,7 @@ class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Embedding): position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2]) position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2]) # convert to tensor - table = tf.convert_to_tensor(position_enc, dtype=tf.float32) + table = tf.convert_to_tensor(position_enc) tf.stop_gradient(table) return table @@ -155,10 +164,8 @@ class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Embedding): """Input is expected to be of size [bsz x seqlen].""" bsz, seq_len = input_shape[:2] - positions = tf.range( - past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range" - ) - return 
super().call(positions) + positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range") + return tf.gather(self.weight, positions) # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus @@ -248,43 +255,57 @@ class TFPegasusAttention(tf.keras.layers.Layer): src_len = shape_list(key_states)[1] attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", + ) if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_weights = tf.nn.softmax(attn_weights, axis=-1) if layer_head_mask is not None: - tf.debugging.assert_equal( - shape_list(layer_head_mask), - [self.num_heads], - message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}", + ) + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( attn_weights, (bsz, self.num_heads, tgt_len, src_len) ) attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) attn_probs = self.dropout(attn_weights, training=training) - attn_output = tf.matmul(attn_probs, value_states) - tf.debugging.assert_equal( - shape_list(attn_output), - [bsz * self.num_heads, tgt_len, self.head_dim], - message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. 
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [bsz * self.num_heads, tgt_len, self.head_dim],
+                message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+            )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -327,11 +348,16 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
         hidden_states, self_attn_weights, _ = self.self_attn(
             hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
         )
-        tf.debugging.assert_equal(
-            shape_list(hidden_states),
-            shape_list(residual),
-            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
-        )
+
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(hidden_states),
+                shape_list(residual),
+                message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+            )
+
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states
 
@@ -750,12 +776,15 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
         all_attentions = () if inputs["output_attentions"] else None
 
         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(inputs["head_mask"])[0],
                 len(self.layers),
                 message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
             )
+
         # encoder layers
         for idx, encoder_layer in enumerate(self.layers):
 
@@ -972,12 +1001,15 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
         present_key_values = ()
 
         # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
             tf.debugging.assert_equal(
                 shape_list(inputs["head_mask"])[0],
                 len(self.layers),
                 message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
             )
+
         for idx, decoder_layer in enumerate(self.layers):
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if inputs["output_hidden_states"]:
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
index bd2a608e13..e11d14b39d 100644
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py
@@ -1512,8 +1512,7 @@ LARGE_NEGATIVE = -1e8
 
 
 def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
-    shifted_input_ids = tf.cast(input_ids, tf.int32)
-    shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+    shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
     start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
     shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
     # replace possible -100 values in labels by `pad_token_id`
@@ -1521,12 +1520,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
         shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
     )
 
-    # "Verify that `labels` has only positive values and -100"
-    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+    if tf.executing_eagerly():
+        # "Verify that `labels` has only positive values and -100"
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
 
-    # Make sure the assertion op is called by wrapping the result in an identity no-op
-    with tf.control_dependencies([assert_gte0]):
-        shifted_input_ids = tf.identity(shifted_input_ids)
+        # Make sure the assertion op is called by wrapping the result in an identity no-op
+        with tf.control_dependencies([assert_gte0]):
+            shifted_input_ids = tf.identity(shifted_input_ids)
 
     return shifted_input_ids
 
 
@@ -1536,15 +1536,14 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
     Make causal mask used for bi-directional self-attention.
""" bsz, tgt_len = input_ids_shape - mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE + mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE mask_cond = tf.range(shape_list(mask)[-1]) mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask) - mask = tf.cast(mask, tf.float32) if past_key_values_length > 0: - mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1) - + mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1) + return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1)) @@ -1554,9 +1553,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values """ src_len = shape_list(mask)[1] tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32) + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) - return (1.0 - expanded_mask) * LARGE_NEGATIVE + return (one_cst - expanded_mask) * LARGE_NEGATIVE class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedEmbeddings): @@ -1573,7 +1574,7 @@ class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedE bsz, seq_len = input_shape[:2] positions = tf.range( - past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range" + past_key_values_length, seq_len + past_key_values_length, delta=1, name="range" ) return super().call(positions) @@ -1663,18 +1664,25 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer): src_len = shape_list(key_states)[1] attn_weights = tf.matmul(query_states, key_states, transpose_b=True) - tf.debugging.assert_equal( - shape_list(attn_weights), - [bsz * self.num_heads, tgt_len, src_len], - message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. + if tf.executing_eagerly(): + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}", + ) if attention_mask is not None: - tf.debugging.assert_equal( - shape_list(attention_mask), - [bsz, 1, tgt_len, src_len], - message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}", - ) + # The tf.debugging asserts are not compliant with XLA then they + # have to be disabled in other modes than eager. 
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(attention_mask),
+                    [bsz, 1, tgt_len, src_len],
+                    message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+                )
+
             attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
             attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
 
@@ -1684,11 +1692,14 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
 
         attn_output = tf.matmul(attn_probs, value_states)
 
-        tf.debugging.assert_equal(
-            shape_list(attn_output),
-            [bsz * self.num_heads, tgt_len, self.head_dim],
-            message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [bsz * self.num_heads, tgt_len, self.head_dim],
+                message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+            )
 
         attn_output = tf.transpose(
             tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -1727,11 +1738,16 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
         hidden_states, self_attn_weights, _ = self.self_attn(
             hidden_states=hidden_states, attention_mask=attention_mask
         )
-        tf.debugging.assert_equal(
-            shape_list(hidden_states),
-            shape_list(residual),
-            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
-        )
+
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(hidden_states),
+                shape_list(residual),
+                message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+            )
+
         hidden_states = self.dropout(hidden_states, training=training)
         hidden_states = residual + hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -2352,7 +2368,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
                 axis=-1,
             )
         else:
-            attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32)
+            attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length))
 
         return attention_mask, combined_attention_mask
 
diff --git a/tests/test_modeling_tf_bart.py b/tests/test_modeling_tf_bart.py
index 04ecccec0f..3aef4c03f9 100644
--- a/tests/test_modeling_tf_bart.py
+++ b/tests/test_modeling_tf_bart.py
@@ -279,14 +279,6 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make BART float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make BART XLA compliant
-        pass
-
 
 def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
     """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
diff --git a/tests/test_modeling_tf_blenderbot.py b/tests/test_modeling_tf_blenderbot.py
index 9b99f2aa18..050a223f0e 100644
--- a/tests/test_modeling_tf_blenderbot.py
+++ b/tests/test_modeling_tf_blenderbot.py
@@ -214,14 +214,6 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make Blenderbot float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make Blenderbot XLA compliant
-        pass
-
     def test_resize_token_embeddings(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
diff --git a/tests/test_modeling_tf_blenderbot_small.py b/tests/test_modeling_tf_blenderbot_small.py
index 6fe2e18bea..850fb3357b 100644
--- a/tests/test_modeling_tf_blenderbot_small.py
+++ b/tests/test_modeling_tf_blenderbot_small.py
@@ -279,14 +279,6 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make Blenderbot Small float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make Blenderbot Small XLA compliant
-        pass
-
 
 def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
     """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
diff --git a/tests/test_modeling_tf_marian.py b/tests/test_modeling_tf_marian.py
index 49b2d366cc..a49c47bb19 100644
--- a/tests/test_modeling_tf_marian.py
+++ b/tests/test_modeling_tf_marian.py
@@ -247,14 +247,6 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make Marian float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make Marian XLA compliant
-        pass
-
     def test_resize_token_embeddings(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
diff --git a/tests/test_modeling_tf_mbart.py b/tests/test_modeling_tf_mbart.py
index c912c00940..4891b00c38 100644
--- a/tests/test_modeling_tf_mbart.py
+++ b/tests/test_modeling_tf_mbart.py
@@ -214,18 +214,6 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
         name = model.get_bias()
         assert name is None
 
-    def test_saved_model_creation(self):
-        # This test is too long (>30sec) and makes fail the CI
-        pass
-
-    def test_mixed_precision(self):
-        # TODO JP: Make MBart float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make MBart XLA compliant
-        pass
-
     def test_resize_token_embeddings(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -289,6 +277,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
                     models_equal = False
 
         self.assertTrue(models_equal)
 
+    def test_saved_model_creation(self):
+        # This test is too long (>30sec) and makes the CI fail
+        pass
+
 
 def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
     """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
diff --git a/tests/test_modeling_tf_pegasus.py b/tests/test_modeling_tf_pegasus.py
index 23a8f6aa6c..46ff69ec16 100644
--- a/tests/test_modeling_tf_pegasus.py
+++ b/tests/test_modeling_tf_pegasus.py
@@ -245,14 +245,6 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass
 
-    def test_mixed_precision(self):
-        # TODO JP: Make Pegasus float16 compliant
-        pass
-
-    def test_xla_mode(self):
-        # TODO JP: Make Pegasus XLA compliant
-        pass
-
     def test_resize_token_embeddings(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
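For reference, a minimal way to exercise the two properties this patch restores, XLA compilation and mixed-precision execution. The snippet below is not part of the patch: the tiny PegasusConfig values and input ids are invented for illustration, the real tests drive this through TFModelTesterMixin, and `experimental_compile` is the TF 2.3/2.4-era spelling of what later releases call `jit_compile`.

import tensorflow as tf
from transformers import PegasusConfig, TFPegasusModel

# Tiny randomly initialized config so the sketch runs quickly.
config = PegasusConfig(
    vocab_size=100,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_position_embeddings=64,
)
model = TFPegasusModel(config)

input_ids = tf.constant([[5, 6, 7, 8, 2]])
decoder_input_ids = tf.constant([[2, 5, 6, 7, 8]])

# Build the weights with one eager call before compiling.
model(input_ids, decoder_input_ids=decoder_input_ids)

# XLA: with the tf.debugging asserts guarded by tf.executing_eagerly(),
# the forward pass can now be traced and compiled.
@tf.function(experimental_compile=True)  # `jit_compile=True` on TF >= 2.5
def xla_forward(input_ids, decoder_input_ids):
    return model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state

print(xla_forward(input_ids, decoder_input_ids).shape)  # (1, 5, 16)

# AMP: with the hard-coded float32 dtypes gone, a model built under a
# float16 compute policy no longer hits internal dtype mismatches.
tf.keras.mixed_precision.set_global_policy("mixed_float16")  # TF >= 2.4
amp_model = TFPegasusModel(config)
out = amp_model(input_ids, decoder_input_ids=decoder_input_ids)
print(out.last_hidden_state.dtype)  # expected: float16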
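The one structural change in the patch is the sinusoidal positional embedding, which moves from an Embedding subclass to a plain Layer holding an explicit weight that is looked up with tf.gather. A standalone sketch of that computation, using only NumPy and TensorFlow; the helper name is mine, while the maths mirrors _init_weight:

import numpy as np
import tensorflow as tf

def sinusoidal_table(n_pos: int, dim: int) -> np.ndarray:
    # Same construction as _init_weight: sin in the first half of the
    # feature dimension, cos in the second half.
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    table = np.zeros_like(position_enc)
    table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
    table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
    return table

# The patched build() creates the table once, casts it to the variable's
# (possibly float16) dtype, and call() gathers rows by position, which is
# what makes the layer safe under a mixed-precision policy.
weight = tf.constant(sinusoidal_table(64, 16), dtype=tf.float32)
positions = tf.range(0, 5)  # past_key_values_length = 0, seq_len = 5
print(tf.gather(weight, positions).shape)  # (5, 16)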