Making TF BART-like models XLA and AMP compliant (#10191)

* Update BART

* Update Blenderbot

* Update BlenderbotSmall

* Update Marian

* Update MBart

* Update MBart

* Update Pegasus

* Update template

* Fix Marian and Pegasus

* Apply style

* Default initializer

* Default initializer

* Default initializer

* Remove int32 casts

* Fix template

* Remove more cast
Julien Plu 2021-02-17 17:48:56 +01:00 committed by GitHub
parent 8d79e5ca49
commit 83d803ba02
13 changed files with 492 additions and 367 deletions
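
The recurring pattern across all thirteen files is the same: tf.debugging assertion ops cannot be lowered by XLA, so every shape check is now wrapped in an `if tf.executing_eagerly():` guard, and hard-coded float32/int32 casts are dropped so tensors keep the dtype chosen by the AMP policy. A minimal, self-contained sketch of the guard pattern (the function below is illustrative, not part of the diff):

import tensorflow as tf


def checked_square(x: tf.Tensor) -> tf.Tensor:
    # Inside tf.function/XLA tracing, tf.executing_eagerly() returns False,
    # so the assert op is skipped and the graph stays compilable.
    if tf.executing_eagerly():
        tf.debugging.assert_greater_equal(x, tf.constant(0))
    return x * x


print(checked_square(tf.constant(3)))  # eager mode: the assert runs

# jit_compile=True forces XLA (TF >= 2.5; older releases spelled it
# experimental_compile); with the guard, tracing no longer hits the assert.
compiled = tf.function(checked_square, jit_compile=True)
print(compiled(tf.constant(3)))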

src/transformers/models/bart/modeling_tf_bart.py

@@ -60,8 +60,7 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
-    shifted_input_ids = tf.cast(input_ids, tf.int32)
-    shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+    shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
    start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
    shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
    # replace possible -100 values in labels by `pad_token_id`
@@ -69,12 +68,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
    )

-    # "Verify that `labels` has only positive values and -100"
-    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+    if tf.executing_eagerly():
+        # "Verify that `labels` has only positive values and -100"
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))

-    # Make sure the assertion op is called by wrapping the result in an identity no-op
-    with tf.control_dependencies([assert_gte0]):
-        shifted_input_ids = tf.identity(shifted_input_ids)
+        # Make sure the assertion op is called by wrapping the result in an identity no-op
+        with tf.control_dependencies([assert_gte0]):
+            shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids
@@ -84,14 +84,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
-    mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
    mask_cond = tf.range(shape_list(mask)[-1])
    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
-    mask = tf.cast(mask, tf.float32)

    if past_key_values_length > 0:
-        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
@@ -102,9 +101,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
    """
    src_len = shape_list(mask)[1]
    tgt_len = tgt_len if tgt_len is not None else src_len
-    expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+    one_cst = tf.constant(1.0)
+    mask = tf.cast(mask, dtype=one_cst.dtype)
+    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

-    return (1.0 - expanded_mask) * LARGE_NEGATIVE
+    return (one_cst - expanded_mask) * LARGE_NEGATIVE
class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
@@ -123,9 +124,7 @@ class TFBartLearnedPositionalEmbedding(TFSharedEmbeddings):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_shape[:2]

-        positions = tf.range(
-            past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
-        )
+        positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
        return super().call(positions + self.offset)
@@ -215,43 +214,57 @@ class TFBartAttention(tf.keras.layers.Layer):
        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-        tf.debugging.assert_equal(
-            shape_list(attn_weights),
-            [bsz * self.num_heads, tgt_len, src_len],
-            message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_weights),
+                [bsz * self.num_heads, tgt_len, src_len],
+                message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+            )

        if attention_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(attention_mask),
-                [bsz, 1, tgt_len, src_len],
-                message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(attention_mask),
+                    [bsz, 1, tgt_len, src_len],
+                    message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+                )
+            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(layer_head_mask),
-                [self.num_heads],
-                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(layer_head_mask),
+                    [self.num_heads],
+                    message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+                )
            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)
            )
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

-        tf.debugging.assert_equal(
-            shape_list(attn_output),
-            [bsz * self.num_heads, tgt_len, self.head_dim],
-            message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [bsz * self.num_heads, tgt_len, self.head_dim],
+                message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+            )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -292,11 +305,16 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

-        tf.debugging.assert_equal(
-            shape_list(hidden_states),
-            shape_list(residual),
-            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(hidden_states),
+                shape_list(residual),
+                message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+            )

        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -717,12 +735,15 @@ class TFBartEncoder(tf.keras.layers.Layer):
        all_attentions = () if inputs["output_attentions"] else None

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        # encoder layers
        for idx, encoder_layer in enumerate(self.layers):
@@ -933,12 +954,15 @@ class TFBartDecoder(tf.keras.layers.Layer):
        present_key_values = () if inputs["use_cache"] else None

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if inputs["output_hidden_states"]:
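
The dtype side of the change is visible in the mask helpers above: instead of hard-coding tf.float32, `_expand_mask` anchors its dtype to a plain constant, and the attention layer casts the mask to the logits' dtype before adding it. A hedged sketch of why that cast matters once a Keras mixed-precision policy is active (layer and shapes are illustrative):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

dense = tf.keras.layers.Dense(4)
attn_weights = dense(tf.random.uniform((2, 3, 4)))
print(attn_weights.dtype)  # float16 under the mixed policy

# An additive mask built in float32, as _expand_mask does above.
mask = (tf.constant(1.0) - tf.constant([1.0, 1.0, 0.0, 0.0])) * -1e8

# Adding float32 to float16 raises an error in TF, hence the new
# attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) line.
# (-1e8 saturates to -inf in float16, which still zeroes the masked
# positions after softmax.)
attn_weights = attn_weights + tf.cast(mask, dtype=attn_weights.dtype)

tf.keras.mixed_precision.set_global_policy("float32")  # restore the default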

src/transformers/models/blenderbot/modeling_tf_blenderbot.py

@@ -63,8 +63,7 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
-    shifted_input_ids = tf.cast(input_ids, tf.int32)
-    shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+    shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
    start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
    shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
    # replace possible -100 values in labels by `pad_token_id`
@@ -72,12 +71,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
    )

-    # "Verify that `labels` has only positive values and -100"
-    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+    if tf.executing_eagerly():
+        # "Verify that `labels` has only positive values and -100"
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))

-    # Make sure the assertion op is called by wrapping the result in an identity no-op
-    with tf.control_dependencies([assert_gte0]):
-        shifted_input_ids = tf.identity(shifted_input_ids)
+        # Make sure the assertion op is called by wrapping the result in an identity no-op
+        with tf.control_dependencies([assert_gte0]):
+            shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids
@@ -88,14 +88,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
-    mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
    mask_cond = tf.range(shape_list(mask)[-1])
    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
-    mask = tf.cast(mask, tf.float32)

    if past_key_values_length > 0:
-        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
@@ -107,9 +106,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
    """
    src_len = shape_list(mask)[1]
    tgt_len = tgt_len if tgt_len is not None else src_len
-    expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+    one_cst = tf.constant(1.0)
+    mask = tf.cast(mask, dtype=one_cst.dtype)
+    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

-    return (1.0 - expanded_mask) * LARGE_NEGATIVE
+    return (one_cst - expanded_mask) * LARGE_NEGATIVE
class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings):
@@ -125,9 +126,7 @@ class TFBlenderbotLearnedPositionalEmbedding(TFSharedEmbeddings):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_shape[:2]

-        positions = tf.range(
-            past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
-        )
+        positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
        return super().call(positions)
@@ -218,43 +217,57 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-        tf.debugging.assert_equal(
-            shape_list(attn_weights),
-            [bsz * self.num_heads, tgt_len, src_len],
-            message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_weights),
+                [bsz * self.num_heads, tgt_len, src_len],
+                message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+            )

        if attention_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(attention_mask),
-                [bsz, 1, tgt_len, src_len],
-                message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(attention_mask),
+                    [bsz, 1, tgt_len, src_len],
+                    message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+                )
+            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(layer_head_mask),
-                [self.num_heads],
-                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(layer_head_mask),
+                    [self.num_heads],
+                    message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+                )
            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)
            )
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

-        tf.debugging.assert_equal(
-            shape_list(attn_output),
-            [bsz * self.num_heads, tgt_len, self.head_dim],
-            message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [bsz * self.num_heads, tgt_len, self.head_dim],
+                message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+            )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -297,11 +310,16 @@ class TFBlenderbotEncoderLayer(tf.keras.layers.Layer):
        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

-        tf.debugging.assert_equal(
-            shape_list(hidden_states),
-            shape_list(residual),
-            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(hidden_states),
+                shape_list(residual),
+                message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+            )

        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states
@@ -719,12 +737,15 @@ class TFBlenderbotEncoder(tf.keras.layers.Layer):
        all_attentions = () if inputs["output_attentions"] else None

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        # encoder layers
        for idx, encoder_layer in enumerate(self.layers):
@@ -943,12 +964,15 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
        present_key_values = ()

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if inputs["output_hidden_states"]:
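
With the asserts guarded, a full forward pass of these models can be traced and compiled. A usage sketch, assuming a standard Blenderbot checkpoint from the Hub (the checkpoint name is illustrative; any Blenderbot weights should behave the same way):

import tensorflow as tf
from transformers import BlenderbotTokenizer, TFBlenderbotForConditionalGeneration

name = "facebook/blenderbot-400M-distill"
tokenizer = BlenderbotTokenizer.from_pretrained(name)
model = TFBlenderbotForConditionalGeneration.from_pretrained(name)

inputs = tokenizer("Hello there!", return_tensors="tf")

# Before this change, XLA tracing failed on the tf.debugging ops; with the
# eager-only guards the graph compiles.
xla_forward = tf.function(model, jit_compile=True)
outputs = xla_forward(dict(inputs))
print(outputs.logits.shape)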

src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py

@@ -61,8 +61,7 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
-    shifted_input_ids = tf.cast(input_ids, tf.int32)
-    shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+    shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
    start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
    shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
    # replace possible -100 values in labels by `pad_token_id`
@@ -70,12 +69,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
    )

-    # "Verify that `labels` has only positive values and -100"
-    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+    if tf.executing_eagerly():
+        # "Verify that `labels` has only positive values and -100"
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))

-    # Make sure the assertion op is called by wrapping the result in an identity no-op
-    with tf.control_dependencies([assert_gte0]):
-        shifted_input_ids = tf.identity(shifted_input_ids)
+        # Make sure the assertion op is called by wrapping the result in an identity no-op
+        with tf.control_dependencies([assert_gte0]):
+            shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids
@@ -86,14 +86,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
-    mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
    mask_cond = tf.range(shape_list(mask)[-1])
    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
-    mask = tf.cast(mask, tf.float32)

    if past_key_values_length > 0:
-        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
@@ -105,9 +104,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
    """
    src_len = shape_list(mask)[1]
    tgt_len = tgt_len if tgt_len is not None else src_len
-    expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+    one_cst = tf.constant(1.0)
+    mask = tf.cast(mask, dtype=one_cst.dtype)
+    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

-    return (1.0 - expanded_mask) * LARGE_NEGATIVE
+    return (one_cst - expanded_mask) * LARGE_NEGATIVE
# Copied from transformers.models.blenderbot.modeling_tf_blenderbot.TFBlenderbotLearnedPositionalEmbedding with Blenderbot->BlenderbotSmall
@@ -124,9 +125,7 @@ class TFBlenderbotSmallLearnedPositionalEmbedding(TFSharedEmbeddings):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_shape[:2]

-        positions = tf.range(
-            past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
-        )
+        positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
        return super().call(positions)
@@ -217,43 +216,57 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-        tf.debugging.assert_equal(
-            shape_list(attn_weights),
-            [bsz * self.num_heads, tgt_len, src_len],
-            message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_weights),
+                [bsz * self.num_heads, tgt_len, src_len],
+                message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+            )

        if attention_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(attention_mask),
-                [bsz, 1, tgt_len, src_len],
-                message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(attention_mask),
+                    [bsz, 1, tgt_len, src_len],
+                    message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+                )
+            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(layer_head_mask),
-                [self.num_heads],
-                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(layer_head_mask),
+                    [self.num_heads],
+                    message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+                )
            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)
            )
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

-        tf.debugging.assert_equal(
-            shape_list(attn_output),
-            [bsz * self.num_heads, tgt_len, self.head_dim],
-            message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [bsz * self.num_heads, tgt_len, self.head_dim],
+                message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+            )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -295,11 +308,16 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

-        tf.debugging.assert_equal(
-            shape_list(hidden_states),
-            shape_list(residual),
-            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(hidden_states),
+                shape_list(residual),
+                message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+            )

        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -725,12 +743,15 @@ class TFBlenderbotSmallEncoder(tf.keras.layers.Layer):
        all_attentions = () if inputs["output_attentions"] else None

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        # encoder layers
        for idx, encoder_layer in enumerate(self.layers):
@@ -946,12 +967,15 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
        present_key_values = ()

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if inputs["output_hidden_states"]:

src/transformers/models/marian/modeling_tf_marian.py

@@ -62,8 +62,7 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
-    shifted_input_ids = tf.cast(input_ids, tf.int32)
-    shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+    shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
    start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
    shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
    # replace possible -100 values in labels by `pad_token_id`
@@ -71,12 +70,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
    )

-    # "Verify that `labels` has only positive values and -100"
-    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+    if tf.executing_eagerly():
+        # "Verify that `labels` has only positive values and -100"
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))

-    # Make sure the assertion op is called by wrapping the result in an identity no-op
-    with tf.control_dependencies([assert_gte0]):
-        shifted_input_ids = tf.identity(shifted_input_ids)
+        # Make sure the assertion op is called by wrapping the result in an identity no-op
+        with tf.control_dependencies([assert_gte0]):
+            shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids
@@ -87,14 +87,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
-    mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
    mask_cond = tf.range(shape_list(mask)[-1])
    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
-    mask = tf.cast(mask, tf.float32)

    if past_key_values_length > 0:
-        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
@@ -106,32 +105,42 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
    """
    src_len = shape_list(mask)[1]
    tgt_len = tgt_len if tgt_len is not None else src_len
-    expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+    one_cst = tf.constant(1.0)
+    mask = tf.cast(mask, dtype=one_cst.dtype)
+    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

-    return (1.0 - expanded_mask) * LARGE_NEGATIVE
+    return (one_cst - expanded_mask) * LARGE_NEGATIVE


-class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
+class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
+        super().__init__(**kwargs)

        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
-        super().__init__(
-            num_positions,
-            embedding_dim,
-            **kwargs,
-        )
+        self.embedding_dim = embedding_dim
+        self.num_positions = num_positions

    def build(self, input_shape: tf.TensorShape):
        """
        Build shared token embedding layer Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
-        super().build(input_shape)  # Instantiates self.weight so it can be loaded
-        weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim)
-        self.set_weights([weight])  # overwrite self.weight to correct value
+        weight = self._init_weight(self.num_positions, self.embedding_dim)
+        self.weight = self.add_weight(
+            name="embeddings",
+            shape=[self.num_positions, self.embedding_dim],
+        )
+        weight = tf.cast(weight, dtype=self.weight.dtype)
+        self.weight.assign(weight)
+        super().build(input_shape)

    @staticmethod
    def _init_weight(n_pos: int, dim: int):
@@ -146,7 +155,7 @@ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
        position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
        position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
        # convert to tensor
-        table = tf.convert_to_tensor(position_enc, dtype=tf.float32)
+        table = tf.convert_to_tensor(position_enc)
        tf.stop_gradient(table)
        return table
@@ -154,10 +163,8 @@ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_shape[:2]

-        positions = tf.range(
-            past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
-        )
-        return super().call(positions)
+        positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
+        return tf.gather(self.weight, positions)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Marian
@@ -247,43 +254,57 @@ class TFMarianAttention(tf.keras.layers.Layer):
        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-        tf.debugging.assert_equal(
-            shape_list(attn_weights),
-            [bsz * self.num_heads, tgt_len, src_len],
-            message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_weights),
+                [bsz * self.num_heads, tgt_len, src_len],
+                message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+            )

        if attention_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(attention_mask),
-                [bsz, 1, tgt_len, src_len],
-                message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(attention_mask),
+                    [bsz, 1, tgt_len, src_len],
+                    message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+                )
+            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(layer_head_mask),
-                [self.num_heads],
-                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(layer_head_mask),
+                    [self.num_heads],
+                    message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+                )
            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)
            )
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

-        tf.debugging.assert_equal(
-            shape_list(attn_output),
-            [bsz * self.num_heads, tgt_len, self.head_dim],
-            message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [bsz * self.num_heads, tgt_len, self.head_dim],
+                message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+            )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -325,11 +346,16 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

-        tf.debugging.assert_equal(
-            shape_list(hidden_states),
-            shape_list(residual),
-            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(hidden_states),
+                shape_list(residual),
+                message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+            )

        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -741,12 +767,15 @@ class TFMarianEncoder(tf.keras.layers.Layer):
        all_attentions = () if inputs["output_attentions"] else None

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        # encoder layers
        for idx, encoder_layer in enumerate(self.layers):
@@ -960,12 +989,15 @@ class TFMarianDecoder(tf.keras.layers.Layer):
        present_key_values = ()

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if inputs["output_hidden_states"]:
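
The Marian sinusoidal embedding above no longer subclasses tf.keras.layers.Embedding: the variable is created by add_weight with the default initializer (the "Default initializer" commits in the log) and then overwritten with the sinusoid table cast to the variable's dtype, so the weight can follow the active precision policy. A self-contained sketch of the same pattern, simplified from the diff:

import numpy as np
import tensorflow as tf


class SinusoidalPositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
        super().__init__(**kwargs)
        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
        self.num_positions = num_positions
        self.embedding_dim = embedding_dim

    def build(self, input_shape):
        # Default-initialized variable first, then assign the cast table.
        self.weight = self.add_weight(name="embeddings", shape=[self.num_positions, self.embedding_dim])
        table = self._init_weight(self.num_positions, self.embedding_dim)
        self.weight.assign(tf.cast(table, dtype=self.weight.dtype))
        super().build(input_shape)

    @staticmethod
    def _init_weight(n_pos: int, dim: int):
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        table = np.zeros_like(position_enc)
        table[:, : dim // 2] = np.sin(position_enc[:, 0::2])
        table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
        return tf.convert_to_tensor(table)

    def call(self, seq_len, past_key_values_length=0):
        positions = tf.range(past_key_values_length, seq_len + past_key_values_length, name="range")
        return tf.gather(self.weight, positions)


emb = SinusoidalPositionalEmbedding(64, 16)
print(emb(5).shape)  # (5, 16)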

src/transformers/models/mbart/modeling_tf_mbart.py

@@ -64,19 +64,16 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int):
    Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
    have a single `decoder_start_token_id` in contrast to other Bart-like models.
    """
-    prev_output_tokens = tf.cast(input_ids, tf.int32)
    assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."

    # replace possible -100 values in labels by `pad_token_id`
-    prev_output_tokens = tf.where(
-        prev_output_tokens == -100, tf.fill(shape_list(prev_output_tokens), pad_token_id), prev_output_tokens
-    )
+    input_ids = tf.where(input_ids == -100, tf.fill(shape_list(input_ids), pad_token_id), input_ids)
    language_id_index = (
-        tf.reduce_sum(tf.cast(tf.math.not_equal(prev_output_tokens, pad_token_id), tf.int32), axis=-1) - 1
+        tf.reduce_sum(tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=input_ids.dtype), axis=-1) - 1
    )
    language_id_index = tf.stack([tf.range(shape_list(input_ids)[0]), language_id_index], axis=-1)
-    languages_ids = tf.gather_nd(prev_output_tokens, language_id_index)
+    languages_ids = tf.gather_nd(input_ids, language_id_index)

-    shifted_input_ids = tf.concat([tf.expand_dims(languages_ids, axis=-1), prev_output_tokens[:, :-1]], axis=-1)
+    shifted_input_ids = tf.concat([tf.expand_dims(languages_ids, axis=-1), input_ids[:, :-1]], axis=-1)

    return shifted_input_ids
@@ -87,14 +84,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
-    mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
    mask_cond = tf.range(shape_list(mask)[-1])
    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
-    mask = tf.cast(mask, tf.float32)

    if past_key_values_length > 0:
-        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
@@ -106,9 +102,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
    """
    src_len = shape_list(mask)[1]
    tgt_len = tgt_len if tgt_len is not None else src_len
-    expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+    one_cst = tf.constant(1.0)
+    mask = tf.cast(mask, dtype=one_cst.dtype)
+    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

-    return (1.0 - expanded_mask) * LARGE_NEGATIVE
+    return (one_cst - expanded_mask) * LARGE_NEGATIVE
# Copied from transformers.models.bart.modeling_tf_bart.TFBartLearnedPositionalEmbedding with Bart->MBart
@@ -128,9 +126,7 @@ class TFMBartLearnedPositionalEmbedding(TFSharedEmbeddings):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_shape[:2]

-        positions = tf.range(
-            past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
-        )
+        positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
        return super().call(positions + self.offset)
@@ -221,43 +217,57 @@ class TFMBartAttention(tf.keras.layers.Layer):
        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-        tf.debugging.assert_equal(
-            shape_list(attn_weights),
-            [bsz * self.num_heads, tgt_len, src_len],
-            message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_weights),
+                [bsz * self.num_heads, tgt_len, src_len],
+                message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+            )

        if attention_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(attention_mask),
-                [bsz, 1, tgt_len, src_len],
-                message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(attention_mask),
+                    [bsz, 1, tgt_len, src_len],
+                    message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+                )
+            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(layer_head_mask),
-                [self.num_heads],
-                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(layer_head_mask),
+                    [self.num_heads],
+                    message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+                )
            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)
            )
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

-        tf.debugging.assert_equal(
-            shape_list(attn_output),
-            [bsz * self.num_heads, tgt_len, self.head_dim],
-            message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [bsz * self.num_heads, tgt_len, self.head_dim],
+                message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+            )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -299,11 +309,16 @@ class TFMBartEncoderLayer(tf.keras.layers.Layer):
        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

-        tf.debugging.assert_equal(
-            shape_list(hidden_states),
-            shape_list(residual),
-            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(hidden_states),
+                shape_list(residual),
+                message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+            )

        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states
@@ -731,12 +746,15 @@ class TFMBartEncoder(tf.keras.layers.Layer):
        all_attentions = () if inputs["output_attentions"] else None

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        # encoder layers
        for idx, encoder_layer in enumerate(self.layers):
@@ -956,12 +974,15 @@ class TFMBartDecoder(tf.keras.layers.Layer):
        present_key_values = ()

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if inputs["output_hidden_states"]:
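
MBart keeps its own shift_tokens_right because the decoder starts with the target-language id rather than a fixed decoder_start_token_id; the rewrite above finds that id (the last non-pad token) without forcing int32. A worked example of the index arithmetic, with 1 standing in for pad_token_id and 250004 for a language id such as en_XX:

import tensorflow as tf

pad_token_id = 1
input_ids = tf.constant([[17, 27, 37, 250004, 1, 1],
                         [42, 250004, 1, 1, 1, 1]])

# Row-wise position of the last non-pad token, in input_ids' own dtype.
language_id_index = (
    tf.reduce_sum(tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=input_ids.dtype), axis=-1) - 1
)
language_id_index = tf.stack([tf.range(tf.shape(input_ids)[0]), language_id_index], axis=-1)
languages_ids = tf.gather_nd(input_ids, language_id_index)  # [250004, 250004]

shifted_input_ids = tf.concat([tf.expand_dims(languages_ids, axis=-1), input_ids[:, :-1]], axis=-1)
print(shifted_input_ids.numpy())
# [[250004 17 27 37 250004 1]
#  [250004 42 250004 1 1 1]]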

src/transformers/models/pegasus/modeling_tf_pegasus.py

@@ -62,8 +62,7 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
-    shifted_input_ids = tf.cast(input_ids, tf.int32)
-    shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
+    shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
    start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
    shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
    # replace possible -100 values in labels by `pad_token_id`
@@ -71,12 +70,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
        shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
    )

-    # "Verify that `labels` has only positive values and -100"
-    assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
+    if tf.executing_eagerly():
+        # "Verify that `labels` has only positive values and -100"
+        assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))

-    # Make sure the assertion op is called by wrapping the result in an identity no-op
-    with tf.control_dependencies([assert_gte0]):
-        shifted_input_ids = tf.identity(shifted_input_ids)
+        # Make sure the assertion op is called by wrapping the result in an identity no-op
+        with tf.control_dependencies([assert_gte0]):
+            shifted_input_ids = tf.identity(shifted_input_ids)

    return shifted_input_ids
@@ -87,14 +87,13 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
-    mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
+    mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
    mask_cond = tf.range(shape_list(mask)[-1])
    mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
-    mask = tf.cast(mask, tf.float32)

    if past_key_values_length > 0:
-        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
+        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)

    return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
@@ -106,33 +105,43 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
    """
    src_len = shape_list(mask)[1]
    tgt_len = tgt_len if tgt_len is not None else src_len
-    expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
+    one_cst = tf.constant(1.0)
+    mask = tf.cast(mask, dtype=one_cst.dtype)
+    expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))

-    return (1.0 - expanded_mask) * LARGE_NEGATIVE
+    return (one_cst - expanded_mask) * LARGE_NEGATIVE


# Copied from transformers.models.marian.modeling_tf_marian.TFMarianSinusoidalPositionalEmbedding with Marian->Pegasus
-class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
+class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
+        super().__init__(**kwargs)

        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
-        super().__init__(
-            num_positions,
-            embedding_dim,
-            **kwargs,
-        )
+        self.embedding_dim = embedding_dim
+        self.num_positions = num_positions

    def build(self, input_shape: tf.TensorShape):
        """
        Build shared token embedding layer Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
-        super().build(input_shape)  # Instantiates self.weight so it can be loaded
-        weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim)
-        self.set_weights([weight])  # overwrite self.weight to correct value
+        weight = self._init_weight(self.num_positions, self.embedding_dim)
+        self.weight = self.add_weight(
+            name="embeddings",
+            shape=[self.num_positions, self.embedding_dim],
+        )
+        weight = tf.cast(weight, dtype=self.weight.dtype)
+        self.weight.assign(weight)
+        super().build(input_shape)

    @staticmethod
    def _init_weight(n_pos: int, dim: int):
@@ -147,7 +156,7 @@ class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
        position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
        position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
        # convert to tensor
-        table = tf.convert_to_tensor(position_enc, dtype=tf.float32)
+        table = tf.convert_to_tensor(position_enc)
        tf.stop_gradient(table)
        return table
@@ -155,10 +164,8 @@ class TFPegasusSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input_shape[:2]

-        positions = tf.range(
-            past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
-        )
-        return super().call(positions)
+        positions = tf.range(past_key_values_length, seq_len + past_key_values_length, delta=1, name="range")
+        return tf.gather(self.weight, positions)
# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with Bart->Pegasus
@@ -248,43 +255,57 @@ class TFPegasusAttention(tf.keras.layers.Layer):
        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

-        tf.debugging.assert_equal(
-            shape_list(attn_weights),
-            [bsz * self.num_heads, tgt_len, src_len],
-            message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_weights),
+                [bsz * self.num_heads, tgt_len, src_len],
+                message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
+            )

        if attention_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(attention_mask),
-                [bsz, 1, tgt_len, src_len],
-                message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(attention_mask),
+                    [bsz, 1, tgt_len, src_len],
+                    message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
+                )
+            attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
-            tf.debugging.assert_equal(
-                shape_list(layer_head_mask),
-                [self.num_heads],
-                message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
-            )
+            # The tf.debugging asserts are not compliant with XLA, so they
+            # have to be disabled in modes other than eager.
+            if tf.executing_eagerly():
+                tf.debugging.assert_equal(
+                    shape_list(layer_head_mask),
+                    [self.num_heads],
+                    message=f"Head mask for a single layer should be of size {(self.num_heads)}, but is {shape_list(layer_head_mask)}",
+                )
            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)
            )
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

-        tf.debugging.assert_equal(
-            shape_list(attn_output),
-            [bsz * self.num_heads, tgt_len, self.head_dim],
-            message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(attn_output),
+                [bsz * self.num_heads, tgt_len, self.head_dim],
+                message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
+            )

        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -327,11 +348,16 @@ class TFPegasusEncoderLayer(tf.keras.layers.Layer):
        hidden_states, self_attn_weights, _ = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask
        )

-        tf.debugging.assert_equal(
-            shape_list(hidden_states),
-            shape_list(residual),
-            message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
-        )
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if tf.executing_eagerly():
+            tf.debugging.assert_equal(
+                shape_list(hidden_states),
+                shape_list(residual),
+                message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
+            )

        hidden_states = self.dropout(hidden_states, training=training)
        hidden_states = residual + hidden_states
@@ -750,12 +776,15 @@ class TFPegasusEncoder(tf.keras.layers.Layer):
        all_attentions = () if inputs["output_attentions"] else None

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        # encoder layers
        for idx, encoder_layer in enumerate(self.layers):
@@ -972,12 +1001,15 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
        present_key_values = ()

        # check if head_mask has a correct number of layers specified if desired
-        if inputs["head_mask"] is not None:
+        # The tf.debugging asserts are not compliant with XLA, so they
+        # have to be disabled in modes other than eager.
+        if inputs["head_mask"] is not None and tf.executing_eagerly():
            tf.debugging.assert_equal(
                shape_list(inputs["head_mask"])[0],
                len(self.layers),
                message=f"The head_mask should be specified for {len(self.layers)} layers, but it is for {shape_list(inputs['head_mask'])[0]}.",
            )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if inputs["output_hidden_states"]:
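
The shared shift_tokens_right helper, after this change, rolls input_ids directly (no int32 cast) and asserts only in eager mode. A standalone sketch of its behavior on labels containing the -100 ignore index; the control-dependency wrapper from the diff is omitted here since an eager assert raises immediately:

import tensorflow as tf


def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    shifted = tf.roll(input_ids, 1, axis=-1)
    start_tokens = tf.fill((tf.shape(shifted)[0], 1), decoder_start_token_id)
    shifted = tf.concat([start_tokens, shifted[:, 1:]], -1)
    # replace possible -100 values in labels by pad_token_id
    shifted = tf.where(shifted == -100, tf.fill(tf.shape(shifted), pad_token_id), shifted)
    if tf.executing_eagerly():
        tf.debugging.assert_greater_equal(shifted, tf.constant(0))
    return shifted


labels = tf.constant([[42, 43, -100, -100]])
print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2).numpy())
# [[ 2 42 43  1]]  (start token leads; the trailing -100 became pad)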

templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py

@ -1512,8 +1512,7 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
shifted_input_ids = tf.cast(input_ids, tf.int32)
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
shifted_input_ids = tf.roll(input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id`
@@ -1521,12 +1520,13 @@ def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_to
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
)
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
if tf.executing_eagerly():
# "Verify that `labels` has only positive values and -100"
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids
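For context, shift_tokens_right builds decoder inputs from labels: roll right by one, prepend the decoder start token, then map the loss-masking value -100 back to the pad token. A hedged walk-through with made-up ids (pad_token_id=1, decoder_start_token_id=2):

import tensorflow as tf

labels = tf.constant([[5, 6, -100, -100]])  # -100 marks positions ignored by the loss
decoder_input_ids = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
# roll    -> [[-100, 5, 6, -100]]
# prepend -> [[   2, 5, 6, -100]]
# un-mask -> [[   2, 5, 6,    1]]

Note that the assert_greater_equal check is itself now eager-only, for the same XLA reason as the attention asserts.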
@@ -1536,15 +1536,14 @@ def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: i
Make a causal mask used for uni-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = tf.ones((tgt_len, tgt_len), dtype=tf.float32) * LARGE_NEGATIVE
mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
mask_cond = tf.range(shape_list(mask)[-1])
mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
mask = tf.cast(mask, tf.float32)
if past_key_values_length > 0:
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=tf.float32), mask], axis=-1)
mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
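A quick worked example of the mask this produces (values illustrative): each target position receives an additive bias of 0 for the positions it may attend to and LARGE_NEGATIVE for future positions.

import tensorflow as tf

mask = _make_causal_mask(tf.TensorShape([2, 3]), past_key_values_length=0)
# mask has shape (2, 1, 3, 3); each (3, 3) slice is
# [[0.0, -1e8, -1e8],
#  [0.0,  0.0, -1e8],
#  [0.0,  0.0,  0.0]]
# With past_key_values_length=k, k all-zero columns are prepended so cached
# positions remain fully visible.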
@@ -1554,9 +1553,11 @@ def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None, past_key_values
"""
src_len = shape_list(mask)[1]
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = tf.cast(tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)), tf.float32)
one_cst = tf.constant(1.0)
mask = tf.cast(mask, dtype=one_cst.dtype)
expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
return (1.0 - expanded_mask) * LARGE_NEGATIVE
return (one_cst - expanded_mask) * LARGE_NEGATIVE
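The AMP-relevant change here is the dtype handling: rather than casting the expanded mask to a hard-coded tf.float32, the mask is cast to the dtype of a single constant, keeping the float32 literal out of the arithmetic. A small usage sketch (mask values illustrative):

import tensorflow as tf

padding_mask = tf.constant([[1, 1, 0]])  # 1 = real token, 0 = padding
bias = _expand_mask(padding_mask, tgt_len=2)
# bias has shape (1, 1, 2, 3): 0.0 where attention is allowed and
# LARGE_NEGATIVE over the padded position, ready to add to attention logits.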
class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedEmbeddings):
@@ -1573,7 +1574,7 @@ class TF{{cookiecutter.camelcase_modelname}}LearnedPositionalEmbedding(TFSharedE
bsz, seq_len = input_shape[:2]
positions = tf.range(
past_key_values_length, seq_len + past_key_values_length, delta=1, dtype=tf.int32, name="range"
past_key_values_length, seq_len + past_key_values_length, delta=1, name="range"
)
return super().call(positions)
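Dropping dtype=tf.int32 is safe because tf.range with Python integer bounds already yields int32; the explicit dtype was redundant, and removing it presumably belongs to the commit-wide cleanup of hard-coded int32 casts. A quick check:

import tensorflow as tf

positions = tf.range(0, 4, delta=1, name="range")
print(positions.dtype)  # int32, the default for integer bounds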
@@ -1663,18 +1664,25 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
src_len = shape_list(key_states)[1]
attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
)
# The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in modes other than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_weights),
[bsz * self.num_heads, tgt_len, src_len],
message=f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {shape_list(attn_weights)}",
)
if attention_mask is not None:
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
# The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in modes other than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attention_mask),
[bsz, 1, tgt_len, src_len],
message=f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}",
)
attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
@@ -1684,11 +1692,14 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
attn_output = tf.matmul(attn_probs, value_states)
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
)
# The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in modes other than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_output),
[bsz * self.num_heads, tgt_len, self.head_dim],
message=f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {shape_list(attn_output)}",
)
attn_output = tf.transpose(
tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
@@ -1727,11 +1738,16 @@ class TF{{cookiecutter.camelcase_modelname}}EncoderLayer(tf.keras.layers.Layer):
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask
)
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
# The tf.debugging asserts are not compliant with XLA, so they
# have to be disabled in modes other than eager.
if tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(hidden_states),
shape_list(residual),
message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
@@ -2352,7 +2368,7 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
axis=-1,
)
else:
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32)
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length))
return attention_mask, combined_attention_mask
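Same theme for the fallback attention mask: tf.ones defaults to float32, which is what _expand_mask and the additive attention bias consume, so the mask no longer needs to start life as int32 and be cast later. Illustrative shapes:

import tensorflow as tf

input_shape, past_key_values_length = (2, 5), 3
attention_mask = tf.ones((input_shape[0], input_shape[1] + past_key_values_length))
print(attention_mask.dtype)  # float32, tf.ones' default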

View File

@@ -279,14 +279,6 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes the CI fail
pass
def test_mixed_precision(self):
# TODO JP: Make BART float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make BART XLA compliant
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""

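Deleting these pass overrides is the payoff of the commit: the test_mixed_precision and test_xla_mode checks inherited from TFModelTesterMixin now run for BART, and the hunks below do the same for the other BART-like models. Roughly what those common tests exercise, as a hedged sketch rather than the mixin's actual code:

import tensorflow as tf

def check_xla_and_amp(model, inputs):
    # Sketch only; the real TFModelTesterMixin tests differ in detail.
    # XLA: the forward pass must trace and compile now that the
    # tf.debugging asserts are guarded by tf.executing_eagerly().
    xla_forward = tf.function(model, experimental_compile=True)
    xla_forward(inputs)
    # AMP: the graph must not hard-code float32, so the same forward
    # pass runs under the mixed_float16 policy.
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    model(inputs)
    tf.keras.mixed_precision.set_global_policy("float32")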
View File

@@ -214,14 +214,6 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes the CI fail
pass
def test_mixed_precision(self):
# TODO JP: Make Blenderbot float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -279,14 +279,6 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes the CI fail
pass
def test_mixed_precision(self):
# TODO JP: Make Blenderbot Small float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Blenderbot Small XLA compliant
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""

View File

@@ -247,14 +247,6 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes the CI fail
pass
def test_mixed_precision(self):
# TODO JP: Make Marian float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Marian XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -214,18 +214,6 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes the CI fail
pass
def test_mixed_precision(self):
# TODO JP: Make MBart float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make MBart XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -289,6 +277,10 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
models_equal = False
self.assertTrue(models_equal)
def test_saved_model_creation(self):
# This test is too long (>30sec) and makes the CI fail
pass
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""

View File

@@ -245,14 +245,6 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
# This test is too long (>30sec) and makes the CI fail
pass
def test_mixed_precision(self):
# TODO JP: Make Pegasus float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make Pegasus XLA compliant
pass
def test_resize_token_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()