Making TF BART-like models XLA and AMP compliant (#10191)
* Update BART
* Update Blenderbot
* Update BlenderbotSmall
* Update Marian
* Update MBart
* Update MBart
* Update Pegasus
* Update template
* Fix Marian and Pegasus
* Apply style
* Default initializer
* Default initializer
* Default initializer
* Remove int32 casts
* Fix template
* Remove more cast
parent 8d79e5ca49
commit 83d803ba02
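For context: "XLA and AMP compliant" here means the forward pass can run under XLA compilation and under a mixed-precision dtype policy without tripping over eager-only ops or hard-coded float32/int32 casts. Below is a minimal sketch of how one might exercise both modes; the checkpoint name and exact kwargs are illustrative assumptions, not part of this commit, and older TensorFlow releases spell the XLA flag `experimental_compile` instead of `jit_compile`.

# Illustrative only (assumed API/checkpoint): run a TF BART model under XLA and AMP.
import tensorflow as tf
from transformers import BartTokenizer, TFBartForConditionalGeneration

tf.keras.mixed_precision.set_global_policy("mixed_float16")  # AMP: float16 compute, float32 variables

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")
batch = tokenizer(["Hello world"], return_tensors="tf")

@tf.function(jit_compile=True)  # XLA: the traced graph must not contain eager-only ops
def forward(input_ids, attention_mask):
    return model(input_ids=input_ids, attention_mask=attention_mask, training=False).logits

logits = forward(batch["input_ids"], batch["attention_mask"])  # computed under the AMP policy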
File: transformers.models.bart.modeling_tf_bart

@@ -60,8 +60,7 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
- shifted_input_ids = tf.cast(input_ids, tf.int32)
- shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+ shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id`

@@ -69,8 +68,9 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
+ if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
- assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+ assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):

@@ -84,14 +84,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
- mask = tf.cast(mask, tf.float32)
if past_key_values_length > 0:
- mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))

@@ -102,9 +101,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
"""
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
- expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
- return (1.0 - expanded_mask) * LARGE_NEGATIVE
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):

@@ -123,9 +124,7 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
- positions = tf.range(
- past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
- )
+ positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(positions + self.offset)

@@ -215,6 +214,9 @@ class TFBartAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],

@@ -222,31 +224,42 @@ class TFBartAttention(tf.keras.layers.Layer):
)
if attention_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],

@@ -292,11 +305,16 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)

@@ -717,12 +735,15 @@ class TFBartEncoder(tf.keras.layers.Layer):
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for idx, encoder_layer in enumerate(self.layers):

@@ -933,12 +954,15 @@ class TFBartDecoder(tf.keras.layers.Layer):
present_key_values = () if inputs["use_cache"] else None
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
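The recurring change above is the guard around the tf.debugging shape checks: those assert ops cannot be lowered by XLA, so they now only run when TensorFlow executes eagerly. A condensed sketch of the pattern, with illustrative function and variable names:

# Sketch of the eager-only assert pattern introduced throughout these files.
import tensorflow as tf

def attention_scores_with_checks(query, key, expected_shape):
    attn_weights = tf.matmul(query, key, transpose_b=True)
    # The tf.debugging asserts are not XLA-compatible, so they are skipped when tracing/compiling.
    if tf.executing_eagerly():
        tf.debugging.assert_equal(
            tf.shape(attn_weights),
            tf.constant(expected_shape),
            message=f"Attention weights should be of size {expected_shape}",
        )
    return attn_weights

Inside tf.function(jit_compile=True), tf.executing_eagerly() is False during tracing, so no assert op ever enters the compiled graph.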
File: transformers.models.blenderbot.modeling_tf_blenderbot

@@ -63,8 +63,7 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
- shifted_input_ids = tf.cast(input_ids, tf.int32)
- shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+ shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id`

@@ -72,8 +71,9 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
+ if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
- assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+ assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):

@@ -88,14 +88,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
- mask = tf.cast(mask, tf.float32)
if past_key_values_length > 0:
- mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))

@@ -107,9 +106,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
"""
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
- expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
- return (1.0 - expanded_mask) * LARGE_NEGATIVE
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings):

@@ -125,9 +126,7 @@ class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
- positions = tf.range(
- past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
- )
+ positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(positions)

@@ -218,6 +217,9 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],

@@ -225,31 +227,42 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
)
if attention_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],

@@ -297,11 +310,16 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states

@@ -719,12 +737,15 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for idx, encoder_layer in enumerate(self.layers):

@@ -943,12 +964,15 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
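One of the AMP-specific lines repeated in every attention module above is `attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)`. Under a mixed_float16 policy the attention scores come out in float16 while the additive mask may still be float32, and adding mismatched dtypes fails; a small sketch of the situation this handles (shapes and values are made up):

# Sketch: dtype alignment of the additive attention mask under mixed precision.
import tensorflow as tf

attn_weights = tf.zeros((2 * 16, 8, 8), dtype=tf.float16)            # scores from a mixed_float16 layer
attention_mask = tf.fill((2, 1, 8, 8), -1.0e4)                        # additive mask built in float32

attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)    # the compliance fix
attn_weights = tf.reshape(attn_weights, (2, 16, 8, 8)) + attention_mask
attn_weights = tf.nn.softmax(tf.reshape(attn_weights, (2 * 16, 8, 8)), axis=-1)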
File: transformers.models.blenderbot_small.modeling_tf_blenderbot_small

@@ -61,8 +61,7 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
- shifted_input_ids = tf.cast(input_ids, tf.int32)
- shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+ shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id`

@@ -70,8 +69,9 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
+ if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
- assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+ assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):

@@ -87,14 +87,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
- mask = tf.cast(mask, tf.float32)
if past_key_values_length > 0:
- mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))

@@ -105,9 +104,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
"""
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
- expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
- return (1.0 - expanded_mask) * LARGE_NEGATIVE
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
# Copied from transformers.models.blenderbot.modeling_tf_blenderbot.TFBlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall

@@ -124,9 +125,7 @@ class TFBlenderbotSmallLearnedPositionalEmbedding(TFSharedEmbeddings):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
- positions = tf.range(
- past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
- )
+ positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(positions)

@@ -217,6 +216,9 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],

@@ -224,31 +226,42 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
)
if attention_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],

@@ -295,11 +308,16 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)

@@ -725,12 +743,15 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for idx, encoder_layer in enumerate(self.layers):

@@ -946,12 +967,15 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
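The shift_tokens_right rewrite shared by BART, Blenderbot, BlenderbotSmall, Marian and Pegasus above drops the unconditional tf.cast(..., tf.int32) and works in whatever integer dtype the labels already have. A small worked sketch of what it computes (token values are made up; tf.shape stands in for the library's shape_list helper):

# Worked sketch of shift_tokens_right after this commit.
import tensorflow as tf

labels = tf.constant([[42, 43, 44, 1, -100]])   # 1 = pad_token_id, -100 = ignored label position
pad_token_id, decoder_start_token_id = 1, 2

shifted = tf.roll(labels, 1, axis=-1)
start_tokens = tf.fill((tf.shape(shifted)[0], 1), decoder_start_token_id)
shifted = tf.concat([start_tokens, shifted[:, 1:]], -1)
shifted = tf.where(shifted == -100, tf.fill(tf.shape(shifted), pad_token_id), shifted)
print(shifted.numpy())  # [[ 2 42 43 44  1]]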
File: transformers.models.marian.modeling_tf_marian

@@ -62,8 +62,7 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
- shifted_input_ids = tf.cast(input_ids, tf.int32)
- shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+ shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id`

@@ -71,8 +70,9 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
+ if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
- assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+ assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):

@@ -87,14 +87,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
- mask = tf.cast(mask, tf.float32)
if past_key_values_length > 0:
- mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))

@@ -106,32 +105,42 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
"""
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
- expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
- return (1.0 - expanded_mask) * LARGE_NEGATIVE
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
- class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
+ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
+ super().__init__(**kwargs)
if embedding_dim % 2 != 0:
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
- super().__init__(
- num_positions,
- embedding_dim,
- **kwargs,
- )
+ self.embedding_dim = embedding_dim
+ self.num_positions = num_positions
def build(self, input_shape: tf.TensorShape):
"""
Build shared token embedding layer Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
- super().build(input_shape) # Instantiates self.weight so it can be loaded
- weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim)
- self.set_weights([weight]) # overwrite self.weight to correct value
+ weight = self._init_weight(self.num_positions, self.embedding_dim)
+ self.weight = self.add_weight(
+ name="embeddings",
+ shape=[self.num_positions, self.embedding_dim],
+ )
+ weight = tf.cast(weight, dtype=self.weight.dtype)
+ self.weight.assign(weight)
+ super().build(input_shape)
@staticmethod
def _init_weight(n_pos: int, dim: int):

@@ -146,7 +155,7 @@ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
# convert to tensor
- table = tf.convert_to_tensor(position_enc, dtype=tf.float32)
+ table = tf.convert_to_tensor(position_enc)
tf.stop_gradient(table)
return table

@@ -154,10 +163,8 @@ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
- positions = tf.range(
- past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
- )
- return super().call(positions)
+ positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
+ return tf.gather(self.weight, positions)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Marian

@@ -247,6 +254,9 @@ class TFMarianAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],

@@ -254,31 +264,42 @@ class TFMarianAttention(tf.keras.layers.Layer):
)
if attention_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],

@@ -325,11 +346,16 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)

@@ -741,12 +767,15 @@ class TFMarianEncoder(tf.keras.layers.Layer):
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for idx, encoder_layer in enumerate(self.layers):

@@ -960,12 +989,15 @@ class TFMarianDecoder(tf.keras.layers.Layer):
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
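The Marian (and, below, Pegasus) sinusoidal positional embedding is reworked from subclassing tf.keras.layers.Embedding into a plain Layer that creates its table in build() with the default initializer, overwrites it with the sinusoidal values, and looks positions up with tf.gather. A condensed, self-contained sketch of that pattern; the class name and call signature here are simplified, not the library's exact API:

# Sketch of the sinusoidal positional embedding rework (simplified names/signature).
import numpy as np
import tensorflow as tf

class SinusoidalPositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.num_positions = num_positions
        self.embedding_dim = embedding_dim  # assumed even

    def build(self, input_shape):
        # Create the variable with the default initializer, then overwrite it.
        weight = self._init_weight(self.num_positions, self.embedding_dim)
        self.weight = self.add_weight(
            name="embeddings", shape=[self.num_positions, self.embedding_dim]
        )
        self.weight.assign(tf.cast(weight, dtype=self.weight.dtype))
        super().build(input_shape)

    @staticmethod
    def _init_weight(n_pos: int, dim: int):
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        table = np.zeros_like(position_enc)
        table[:, : dim // 2] = np.sin(position_enc[:, 0::2])
        table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
        return table

    def call(self, seq_len: int, past_key_values_length: int = 0):
        positions = tf.range(past_key_values_length, seq_len + past_key_values_length)
        return tf.gather(self.weight, positions)

# Usage sketch:
# emb = SinusoidalPositionalEmbedding(num_positions=128, embedding_dim=64)
# emb.build(None); print(emb.call(seq_len=10).shape)  # (10, 64)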
File: transformers.models.mbart.modeling_tf_mbart

@@ -64,19 +64,16 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int):
Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
have a single `decoder_start_token_id` in contrast to other Bart-like models.
"""
- prev_output_tokens = tf.cast(input_ids, tf.int32)
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
# replace possible -100 values in labels by `pad_token_id`
- prev_output_tokens = tf.where(
- prev_output_tokens == -100, tf.fill(shape_list(prev_output_tokens), pad_token_id), prev_output_tokens
- )
+ input_ids = tf.where(input_ids == -100, tf.fill(shape_list(input_ids), pad_token_id), input_ids)
language_id_index = (
- tf.reduce_sum(tf.cast(tf.math.not_equal(prev_output_tokens, pad_token_id), tf.int32), axis=-1) - 1
+ tf.reduce_sum(tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=input_ids.dtype), axis=-1) - 1
)
language_id_index = tf.stack([tf.range(shape_list(input_ids)[0]), language_id_index], axis=-1)
- languages_ids = tf.gather_nd(prev_output_tokens, language_id_index)
+ languages_ids = tf.gather_nd(input_ids, language_id_index)
- shifted_input_ids = tf.concat([tf.expand_dims(languages_ids, axis=-1), prev_output_tokens[:, :-1]], axis=-1)
+ shifted_input_ids = tf.concat([tf.expand_dims(languages_ids, axis=-1), input_ids[:, :-1]], axis=-1)
return shifted_input_ids

@@ -87,14 +84,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
- mask = tf.cast(mask, tf.float32)
if past_key_values_length > 0:
- mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))

@@ -106,9 +102,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
"""
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
- expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
- return (1.0 - expanded_mask) * LARGE_NEGATIVE
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
# Copied from transformers.models.bart.modeling_tf_bart.TFBartLearnedPositionalEmbedding with Bart->MBart

@@ -128,9 +126,7 @@ class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
- positions = tf.range(
- past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
- )
+ positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
return super().call(positions + self.offset)

@@ -221,6 +217,9 @@ class TFMBartAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],

@@ -228,31 +227,42 @@ class TFMBartAttention(tf.keras.layers.Layer):
)
if attention_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],

@@ -299,11 +309,16 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states

@@ -731,12 +746,15 @@ class TFMBartEncoder(tf.keras.layers.Layer):
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for idx, encoder_layer in enumerate(self.layers):

@@ -956,12 +974,15 @@ class TFMBartDecoder(tf.keras.layers.Layer):
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
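MBart's shift_tokens_right differs from the others: instead of prepending a fixed decoder_start_token_id, it wraps the last non-pad token (the language id <LID>) to the front, and after this commit everything stays in the dtype of input_ids. A worked sketch with made-up token values, using tf.shape in place of the library's shape_list helper:

# Worked sketch of MBart's shift_tokens_right after this commit.
import tensorflow as tf

input_ids = tf.constant([[11, 12, 13, 250004, 1, 1]])   # 250004 = <LID>, 1 = pad_token_id
pad_token_id = 1

input_ids = tf.where(input_ids == -100, tf.fill(tf.shape(input_ids), pad_token_id), input_ids)
lid_index = tf.reduce_sum(
    tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=input_ids.dtype), axis=-1
) - 1
gather_index = tf.stack([tf.range(tf.shape(input_ids)[0]), lid_index], axis=-1)
language_ids = tf.gather_nd(input_ids, gather_index)

shifted = tf.concat([tf.expand_dims(language_ids, axis=-1), input_ids[:, :-1]], axis=-1)
print(shifted.numpy())  # [[250004     11     12     13 250004      1]]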
File: transformers.models.pegasus.modeling_tf_pegasus

@@ -62,8 +62,7 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
- shifted_input_ids = tf.cast(input_ids, tf.int32)
- shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+ shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id`

@@ -71,8 +70,9 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
+ if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
- assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+ assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):

@@ -87,14 +87,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
- mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+ mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
- mask = tf.cast(mask, tf.float32)
if past_key_values_length > 0:
- mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+ mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))

@@ -106,33 +105,43 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
"""
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
- expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+ one_cst = tf.constant(1.0)
+ mask = tf.cast(mask, dtype=one_cst.dtype)
+ expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
- return (1.0 - expanded_mask) * LARGE_NEGATIVE
+ return (one_cst - expanded_mask) * LARGE_NEGATIVE
# Copied from transformers.models.marian.modeling_tf_marian.TFMarianSinusoidalPositionalEmbedding with Marian->Pegasus
- class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
+ class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
+ super().__init__(**kwargs)
if embedding_dim % 2 != 0:
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
- super().__init__(
- num_positions,
- embedding_dim,
- **kwargs,
- )
+ self.embedding_dim = embedding_dim
+ self.num_positions = num_positions
def build(self, input_shape: tf.TensorShape):
"""
Build shared token embedding layer Shared weights logic adapted from
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
"""
- super().build(input_shape) # Instantiates self.weight so it can be loaded
- weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim)
- self.set_weights([weight]) # overwrite self.weight to correct value
+ weight = self._init_weight(self.num_positions, self.embedding_dim)
+ self.weight = self.add_weight(
+ name="embeddings",
+ shape=[self.num_positions, self.embedding_dim],
+ )
+ weight = tf.cast(weight, dtype=self.weight.dtype)
+ self.weight.assign(weight)
+ super().build(input_shape)
@staticmethod
def _init_weight(n_pos: int, dim: int):

@@ -147,7 +156,7 @@ class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
# convert to tensor
- table = tf.convert_to_tensor(position_enc, dtype=tf.float32)
+ table = tf.convert_to_tensor(position_enc)
tf.stop_gradient(table)
return table

@@ -155,10 +164,8 @@ class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input_shape[:2]
- positions = tf.range(
- past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
- )
- return super().call(positions)
+ positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
+ return tf.gather(self.weight, positions)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus

@@ -248,6 +255,9 @@ class TFPegasusAttention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],

@@ -255,31 +265,42 @@ class TFPegasusAttention(tf.keras.layers.Layer):
)
if attention_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
+ attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_weights = tf.nn.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(layer_head_mask),
[self.num_heads],
message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
)
attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
attn_weights, (bsz, self.num_heads, tgt_len, src_len)
)
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
attn_probs = self.dropout(attn_weights, training=training)
attn_output = tf.matmul(attn_probs, value_states)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],

@@ -327,11 +348,16 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
)
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states

@@ -750,12 +776,15 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
all_attentions = () if inputs["output_attentions"] else None
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
# encoder layers
for idx, encoder_layer in enumerate(self.layers):

@@ -972,12 +1001,15 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
present_key_values = ()
# check if head_mask has a correct number of layers specified if desired
- if inputs["head_mask"] is not None:
+ # The tf.debugging asserts are not compliant with XLA then they
+ # have to be disabled in other modes than eager.
+ if inputs["head_mask"] is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(inputs["head_mask"])[0],
len(self.layers),
message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
@ -1512,8 +1512,7 @@ LARGE_NEGATIVE = -1e8
|
|||
|
||||
|
||||
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||
shifted_input_ids = tf.cast(input_ids, tf.int32)
|
||||
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
|
||||
shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
|
||||
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
|
||||
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
|
@ -1521,8 +1520,9 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
|
|||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
)
|
||||
|
||||
if tf.executing_eagerly():
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
|
@ -1536,14 +1536,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])

mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
mask = tf.cast(mask, tf.float32)

if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
@ -1554,9 +1553,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
"""
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
one_cst = tf.constant(1.0)
mask = tf.cast(mask, dtype=one_cst.dtype)
expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

return (1.0 - expanded_mask) * LARGE_NEGATIVE
return (one_cst - expanded_mask) * LARGE_NEGATIVE


class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedEmbeddings):
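The reworked _expand_mask routes every dtype decision through one_cst: the incoming 0/1 mask is cast to the constant's dtype and the additive mask is built from the same constant, so no explicit tf.float32 appears in the expression. A quick eager check of the behaviour with a made-up 1x3 padding mask and tgt_len=2 (values chosen only for illustration):

import tensorflow as tf

LARGE_NEGATIVE = -1e8

mask = tf.constant([[1, 1, 0]])                  # 1 = attend, 0 = padding
one_cst = tf.constant(1.0)
mask = tf.cast(mask, dtype=one_cst.dtype)        # integer mask -> float, no hard-coded tf.float32
expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, 2, 1))

additive_mask = (one_cst - expanded_mask) * LARGE_NEGATIVE
print(additive_mask.numpy())                     # 0.0 for visible tokens, -1e8 for padded ones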
@ -1573,7 +1574,7 @@ class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedE
bsz, seq_len = input_shape[:2]

positions = tf.range(
past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
past_key_values_length, seq_len + past_key_values_length, delta=1, name="range"
)
return super().call(positions)
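Dropping dtype=tf.int32 from tf.range relies on TensorFlow's dtype inference: with plain Python integer bounds the result is already int32, so the explicit dtype was redundant. A small eager check with arbitrary values:

import tensorflow as tf

past_key_values_length, seq_len = 2, 5  # arbitrary example values

positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
print(positions.dtype)    # int32, inferred from the Python ints without an explicit cast
print(positions.numpy())  # [2 3 4 5 6]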
@ -1663,6 +1664,9 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
@ -1670,11 +1674,15 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
)

if attention_mask is not None:
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)

attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
@ -1684,6 +1692,9 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):

attn_output = tf.matmul(attn_probs, value_states)

# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
@ -1727,11 +1738,16 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask
)

# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)

hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
@ -2352,7 +2368,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
axis=-1,
)
else:
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32)
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length))

return attention_mask, combined_attention_mask
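Note that without the dtype argument tf.ones yields float32 rather than int32; that is harmless here since _expand_mask (shown earlier) now casts whatever mask it receives before using it. A two-line eager check with an arbitrary shape:

import tensorflow as tf

attention_mask = tf.ones((2, 7))  # arbitrary (batch, seq + past) shape
print(attention_mask.dtype)       # float32, the tf.ones default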
@ -279,14 +279,6 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass

def test_mixed_precision(self):
# TODO JP: Make BART float16 compliant
pass

def test_xla_mode(self):
# TODO JP: Make BART XLA compliant
pass


def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
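Removing the test_mixed_precision and test_xla_mode overrides in this and the following test files lets the models fall back to the shared tests on TFModelTesterMixin instead of skipping them. Very roughly, those shared tests exercise a forward pass under a float16 policy and under XLA compilation along the following lines (a sketch of the idea only, not the actual mixin code; build_model and sample_inputs are hypothetical stand-ins):

import tensorflow as tf

def run_mixed_precision(build_model, sample_inputs):
    # Run one forward pass with the global mixed-precision policy enabled.
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    try:
        model = build_model()
        model(sample_inputs)  # must not crash under float16 activations
    finally:
        tf.keras.mixed_precision.set_global_policy("float32")

def run_xla(build_model, sample_inputs):
    # Run one forward pass through an XLA-compiled wrapper.
    model = build_model()
    compiled = tf.function(model, jit_compile=True)
    compiled(sample_inputs)  # must compile and run under XLA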
@ -214,14 +214,6 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass

def test_mixed_precision(self):
# TODO JP: Make Blenderbot float16 compliant
pass

def test_xla_mode(self):
# TODO JP: Make Blenderbot XLA compliant
pass

def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -279,14 +279,6 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass

def test_mixed_precision(self):
# TODO JP: Make Blenderbot Small float16 compliant
pass

def test_xla_mode(self):
# TODO JP: Make Blenderbot Small XLA compliant
pass


def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
@ -247,14 +247,6 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass

def test_mixed_precision(self):
# TODO JP: Make Marian float16 compliant
pass

def test_xla_mode(self):
# TODO JP: Make Marian XLA compliant
pass

def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@ -214,18 +214,6 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None

def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass

def test_mixed_precision(self):
# TODO JP: Make MBart float16 compliant
pass

def test_xla_mode(self):
# TODO JP: Make MBart XLA compliant
pass

def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@ -289,6 +277,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
models_equal = False
self.assertTrue(models_equal)

def test_saved_model_creation(self):
# This test is too long (>30sec) and makes fail the CI
pass


def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
@ -245,14 +245,6 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes fail the CI
pass

def test_mixed_precision(self):
# TODO JP: Make Pegasus float16 compliant
pass

def test_xla_mode(self):
# TODO JP: Make Pegasus XLA compliant
pass

def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()