TFMarian, TFMbart, TFPegasus, TFBlenderbot (#7987)
* Start plumbing * Marian close * Small stubs for all children * Fixed bart * marian working * pegasus test is good, but failing * Checkin tests * More model files * Subtle marian, pegasus integration test failures * Works well * rm print * boom boom * Still failing model2doc * merge master * Equivalence test failing, all others fixed * cleanup * Fix embed_scale * Cleanup marian pipeline test * Undo extra changes * Smaller delta * Cleanup model testers * undo delta * fix tests import structure * cross test decorator * Cleaner set_weights * Respect authorized_unexpected_keys * No warnings * No warnings * style * Nest tf import * black * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * functional dropout * fixup * Fixup * style_doc * embs * shape list * delete slow force_token_id_to_be_generated func * fixup Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
6279072f5f
commit
566b083eb1
|
@ -95,3 +95,12 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward`
|
|||
|
||||
.. autoclass:: transformers.BlenderbotForConditionalGeneration
|
||||
:members:
|
||||
|
||||
|
||||
TFBlenderbotForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
See :obj:`transformers.TFBartForConditionalGeneration` for arguments to `forward` and `generate`
|
||||
|
||||
.. autoclass:: transformers.TFBlenderbotForConditionalGeneration
|
||||
:members:
|
||||
|
|
|
@ -129,3 +129,9 @@ MarianMTModel
|
|||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MarianMTModel
|
||||
|
||||
|
||||
TFMarianMTModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFMarianMTModel
|
||||
|
|
|
@ -79,4 +79,11 @@ MBartForConditionalGeneration
|
|||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MBartForConditionalGeneration
|
||||
:members: forward
|
||||
:members:
|
||||
|
||||
|
||||
TFMBartForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFMBartForConditionalGeneration
|
||||
:members:
|
||||
|
|
|
@ -95,3 +95,9 @@ PegasusForConditionalGeneration
|
|||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.PegasusForConditionalGeneration
|
||||
|
||||
|
||||
TFPegasusForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFPegasusForConditionalGeneration
|
||||
|
|
|
@ -670,6 +670,7 @@ if is_tf_available():
|
|||
TFBertModel,
|
||||
TFBertPreTrainedModel,
|
||||
)
|
||||
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
|
||||
from .modeling_tf_camembert import (
|
||||
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFCamembertForMaskedLM,
|
||||
|
@ -750,6 +751,8 @@ if is_tf_available():
|
|||
TFLxmertPreTrainedModel,
|
||||
TFLxmertVisualFeatureEncoder,
|
||||
)
|
||||
from .modeling_tf_marian import TFMarianMTModel
|
||||
from .modeling_tf_mbart import TFMBartForConditionalGeneration
|
||||
from .modeling_tf_mobilebert import (
|
||||
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFMobileBertForMaskedLM,
|
||||
|
@ -771,6 +774,7 @@ if is_tf_available():
|
|||
TFOpenAIGPTModel,
|
||||
TFOpenAIGPTPreTrainedModel,
|
||||
)
|
||||
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
|
||||
from .modeling_tf_roberta import (
|
||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFRobertaForMaskedLM,
|
||||
|
|
|
@ -427,7 +427,6 @@ class DecoderLayer(nn.Module):
|
|||
output_attentions=False,
|
||||
):
|
||||
residual = x
|
||||
|
||||
if layer_state is None:
|
||||
layer_state = {}
|
||||
if self.normalize_before:
|
||||
|
@ -447,7 +446,7 @@ class DecoderLayer(nn.Module):
|
|||
if not self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
# Cross attention
|
||||
# Cross-Attention Block
|
||||
residual = x
|
||||
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
||||
if self.normalize_before:
|
||||
|
@ -628,7 +627,6 @@ class BartDecoder(nn.Module):
|
|||
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
|
|
|
@ -41,6 +41,10 @@ from .configuration_auto import (
|
|||
XLNetConfig,
|
||||
replace_list_option_in_docstrings,
|
||||
)
|
||||
from .configuration_blenderbot import BlenderbotConfig
|
||||
from .configuration_marian import MarianConfig
|
||||
from .configuration_mbart import MBartConfig
|
||||
from .configuration_pegasus import PegasusConfig
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_albert import (
|
||||
|
@ -63,6 +67,7 @@ from .modeling_tf_bert import (
|
|||
TFBertLMHeadModel,
|
||||
TFBertModel,
|
||||
)
|
||||
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration
|
||||
from .modeling_tf_camembert import (
|
||||
TFCamembertForMaskedLM,
|
||||
TFCamembertForMultipleChoice,
|
||||
|
@ -108,6 +113,8 @@ from .modeling_tf_funnel import (
|
|||
)
|
||||
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
|
||||
from .modeling_tf_longformer import TFLongformerForMaskedLM, TFLongformerForQuestionAnswering, TFLongformerModel
|
||||
from .modeling_tf_marian import TFMarianMTModel
|
||||
from .modeling_tf_mbart import TFMBartForConditionalGeneration
|
||||
from .modeling_tf_mobilebert import (
|
||||
TFMobileBertForMaskedLM,
|
||||
TFMobileBertForMultipleChoice,
|
||||
|
@ -118,6 +125,7 @@ from .modeling_tf_mobilebert import (
|
|||
TFMobileBertModel,
|
||||
)
|
||||
from .modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
|
||||
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration
|
||||
from .modeling_tf_roberta import (
|
||||
TFRobertaForMaskedLM,
|
||||
TFRobertaForMultipleChoice,
|
||||
|
@ -210,6 +218,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
|||
(T5Config, TFT5ForConditionalGeneration),
|
||||
(DistilBertConfig, TFDistilBertForMaskedLM),
|
||||
(AlbertConfig, TFAlbertForMaskedLM),
|
||||
(MarianConfig, TFMarianMTModel),
|
||||
(BartConfig, TFBartForConditionalGeneration),
|
||||
(CamembertConfig, TFCamembertForMaskedLM),
|
||||
(XLMRobertaConfig, TFXLMRobertaForMaskedLM),
|
||||
|
@ -261,8 +270,16 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
|||
]
|
||||
)
|
||||
|
||||
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
|
||||
[(T5Config, TFT5ForConditionalGeneration), (BartConfig, TFBartForConditionalGeneration)]
|
||||
[
|
||||
(T5Config, TFT5ForConditionalGeneration),
|
||||
(MarianConfig, TFMarianMTModel),
|
||||
(MBartConfig, TFMBartForConditionalGeneration),
|
||||
(PegasusConfig, TFPegasusForConditionalGeneration),
|
||||
(BlenderbotConfig, TFBlenderbotForConditionalGeneration),
|
||||
(BartConfig, TFBartForConditionalGeneration),
|
||||
]
|
||||
)
|
||||
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
|
|
|
@ -19,9 +19,10 @@ import random
|
|||
import warnings
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow import Tensor
|
||||
from tensorflow.keras.layers import Dense, LayerNormalization
|
||||
from tensorflow.keras.layers import Dense, Layer, LayerNormalization
|
||||
|
||||
from .activations_tf import ACT2FN
|
||||
from .configuration_bart import BartConfig
|
||||
|
@ -43,7 +44,6 @@ from .utils import logging
|
|||
|
||||
|
||||
_CONFIG_FOR_DOC = "BartConfig"
|
||||
_TOKENIZER_FOR_DOC = "BartTokenizer"
|
||||
|
||||
BART_START_DOCSTRING = r"""
|
||||
|
||||
|
@ -218,22 +218,21 @@ PAST_KV_DEPRECATION_WARNING = (
|
|||
)
|
||||
|
||||
|
||||
class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
class TFEncoderLayer(Layer):
|
||||
def __init__(self, config: BartConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.embed_dim = config.d_model
|
||||
self.self_attn = TFAttention(
|
||||
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
|
||||
)
|
||||
|
||||
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||
self.dropout_wt = tf.keras.layers.Dropout(config.dropout)
|
||||
self.normalize_before = config.normalize_before
|
||||
self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = tf.keras.layers.Dropout(config.activation_dropout)
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.fc1 = Dense(config.encoder_ffn_dim, name="fc1")
|
||||
self.fc2 = Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
self.normalize_before = config.normalize_before
|
||||
|
||||
def call(self, x, encoder_padding_mask, training=False):
|
||||
"""
|
||||
|
@ -251,8 +250,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
|||
if self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
x, self_attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask)
|
||||
assert x.shape == residual.shape, f"Self attn modified the shape of query {residual.shape} to {x.shape}"
|
||||
x = self.dropout_wt(x, training=training)
|
||||
assert shape_list(x) == shape_list(
|
||||
residual
|
||||
), f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}"
|
||||
x = tf.nn.dropout(x, rate=self.dropout if training else 0)
|
||||
x = residual + x
|
||||
if not self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
@ -261,9 +262,9 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
|||
if self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.activation_dropout(x, training=training)
|
||||
x = tf.nn.dropout(x, rate=self.self.activation_dropout if training else 0)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout_wt(x, training=training)
|
||||
x = tf.nn.dropout(x, rate=self.dropout if training else 0)
|
||||
x = residual + x
|
||||
if not self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
|
@ -271,7 +272,7 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
|||
return x, self_attn_weights
|
||||
|
||||
|
||||
class TFBartEncoder(tf.keras.layers.Layer):
|
||||
class TFBartEncoder(Layer):
|
||||
# config_class = BartConfig
|
||||
"""
|
||||
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
||||
|
@ -289,26 +290,30 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
embed_dim = embed_tokens.vocab_size
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
|
||||
self.embed_tokens = embed_tokens
|
||||
self.embed_positions = TFLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
embed_tokens.hidden_size,
|
||||
self.padding_idx,
|
||||
config.extra_pos_embeddings,
|
||||
name="embed_positions",
|
||||
)
|
||||
if config.static_position_embeddings:
|
||||
self.embed_positions = TFSinusoidalPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
name="embed_positions",
|
||||
)
|
||||
else:
|
||||
self.embed_positions = TFLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
config.extra_pos_embeddings,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.layers = [TFEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
|
||||
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||
self.layer_norm = (
|
||||
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
||||
if config.add_final_layer_norm
|
||||
else None
|
||||
self.layernorm_embedding = (
|
||||
LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer()
|
||||
)
|
||||
self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None
|
||||
self.return_dict = config.return_dict
|
||||
|
||||
def call(
|
||||
|
@ -347,7 +352,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||
), f"expected attention_mask._rank() to be a 2D tensor got {attention_mask._rank()}"
|
||||
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
|
||||
attention_mask = (1.0 - attention_mask) * LARGE_NEGATIVE
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
embed_pos = self.embed_positions(input_ids)
|
||||
x = inputs_embeds + embed_pos
|
||||
x = self.layernorm_embedding(x)
|
||||
|
@ -384,7 +389,7 @@ class TFBartEncoder(tf.keras.layers.Layer):
|
|||
return TFBaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
|
||||
|
||||
|
||||
class TFDecoderLayer(tf.keras.layers.Layer):
|
||||
class TFDecoderLayer(Layer):
|
||||
def __init__(self, config: BartConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.embed_dim = config.d_model
|
||||
|
@ -397,8 +402,9 @@ class TFDecoderLayer(tf.keras.layers.Layer):
|
|||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
self.activation_dropout = config.activation_dropout
|
||||
self.normalize_before = config.normalize_before
|
||||
|
||||
self.self_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||
self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
|
||||
self.encoder_attn = TFAttention(
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
|
@ -406,10 +412,10 @@ class TFDecoderLayer(tf.keras.layers.Layer):
|
|||
encoder_decoder_attention=True,
|
||||
name="encoder_attn",
|
||||
)
|
||||
self.encoder_attn_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
|
||||
self.encoder_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
|
||||
self.fc1 = Dense(config.decoder_ffn_dim, name="fc1")
|
||||
self.fc2 = Dense(self.embed_dim, name="fc2")
|
||||
self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm")
|
||||
|
||||
def call(
|
||||
self,
|
||||
|
@ -433,10 +439,12 @@ class TFDecoderLayer(tf.keras.layers.Layer):
|
|||
|
||||
Tuple containing, encoded output of shape `(seq_len, batch, embed_dim)`, self_attn_weights, layer_state
|
||||
"""
|
||||
residual = x # Make a copy of the input tensor to add later.
|
||||
if layer_state is None:
|
||||
layer_state = {}
|
||||
if self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
residual = x # Make a copy of the input tensor to add later.
|
||||
# next line mutates layer state and we need a copy of it
|
||||
x, self_attn_weights = self.self_attn(
|
||||
query=x,
|
||||
|
@ -447,9 +455,12 @@ class TFDecoderLayer(tf.keras.layers.Layer):
|
|||
)
|
||||
x = tf.nn.dropout(x, rate=self.dropout if training else 0)
|
||||
x = residual + x
|
||||
x = self.self_attn_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
# Cross-Attention Block
|
||||
residual = x
|
||||
# Cross-Attention
|
||||
if self.normalize_before:
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
x, _ = self.encoder_attn(
|
||||
query=x,
|
||||
key=encoder_hidden_states,
|
||||
|
@ -458,16 +469,19 @@ class TFDecoderLayer(tf.keras.layers.Layer):
|
|||
)
|
||||
x = tf.nn.dropout(x, rate=self.dropout if training else 0)
|
||||
x = residual + x
|
||||
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
|
||||
if not self.normalize_before:
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
# Fully Connected
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = tf.nn.dropout(x, rate=self.activation_dropout if training else 0)
|
||||
x = self.fc2(x)
|
||||
x = tf.nn.dropout(x, rate=self.dropout if training else 0)
|
||||
x = residual + x
|
||||
x = self.final_layer_norm(x)
|
||||
if not self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
return (
|
||||
x,
|
||||
self_attn_weights,
|
||||
|
@ -475,7 +489,7 @@ class TFDecoderLayer(tf.keras.layers.Layer):
|
|||
) # just self_attn weights for now, following t5, layer_state = cache for decoding
|
||||
|
||||
|
||||
class TFBartDecoder(tf.keras.layers.Layer):
|
||||
class TFBartDecoder(Layer):
|
||||
"""
|
||||
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`TFDecoderLayer`
|
||||
|
||||
|
@ -491,26 +505,27 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_tokens = embed_tokens
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.embed_positions = TFLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
config.extra_pos_embeddings,
|
||||
name="embed_positions",
|
||||
)
|
||||
if config.static_position_embeddings:
|
||||
self.embed_positions = TFSinusoidalPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
name="embed_positions",
|
||||
)
|
||||
else:
|
||||
self.embed_positions = TFLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
self.padding_idx,
|
||||
config.extra_pos_embeddings,
|
||||
name="embed_positions",
|
||||
)
|
||||
self.layers = [TFDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
|
||||
self.layernorm_embedding = (
|
||||
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
|
||||
if config.normalize_embedding
|
||||
else tf.identity
|
||||
)
|
||||
self.layer_norm = (
|
||||
tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")
|
||||
if config.add_final_layer_norm
|
||||
else None
|
||||
LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer()
|
||||
)
|
||||
self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None
|
||||
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||
self.dropout = config.dropout
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.use_cache = config.use_cache
|
||||
|
@ -553,11 +568,11 @@ class TFBartDecoder(tf.keras.layers.Layer):
|
|||
x = self.layernorm_embedding(x) + positions
|
||||
else:
|
||||
x = self.layernorm_embedding(x + positions)
|
||||
x = self.dropout(x)
|
||||
x = tf.nn.dropout(x, rate=self.dropout if training else 0)
|
||||
|
||||
# Convert to Bart output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||
x = tf.transpose(x, perm=(1, 0, 2))
|
||||
assert len(encoder_hidden_states.shape) == 3, "encoder_hidden_states must be a 3D tensor"
|
||||
assert len(shape_list(encoder_hidden_states)) == 3, "encoder_hidden_states must be a 3D tensor"
|
||||
encoder_hidden_states = tf.transpose(encoder_hidden_states, perm=(1, 0, 2))
|
||||
|
||||
# decoder layers
|
||||
|
@ -623,7 +638,7 @@ def _reorder_buffer(attn_cache, new_order):
|
|||
return attn_cache
|
||||
|
||||
|
||||
class TFAttention(tf.keras.layers.Layer):
|
||||
class TFAttention(Layer):
|
||||
"""Multi-headed attention from "Attention Is All You Need"""
|
||||
|
||||
def __init__(
|
||||
|
@ -678,8 +693,10 @@ class TFAttention(tf.keras.layers.Layer):
|
|||
(default: None).
|
||||
"""
|
||||
static_kv = self.encoder_decoder_attention # value=key=encoder_hidden_states,
|
||||
tgt_len, bsz, embed_dim = query.shape
|
||||
assert embed_dim == self.embed_dim, f"query must be shaped {(tgt_len, bsz, self.embed_dim)} got {query.shape}"
|
||||
tgt_len, bsz, embed_dim = shape_list(query)
|
||||
assert (
|
||||
embed_dim == self.embed_dim
|
||||
), f"query must be shaped {(tgt_len, bsz, self.embed_dim)} got {shape_list(query)}"
|
||||
# get here for encoder decoder cause of static_kv
|
||||
if layer_state is not None: # get the last k and v for reuse
|
||||
saved_state = layer_state.get(self.cache_key, {})
|
||||
|
@ -718,7 +735,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||
)
|
||||
|
||||
# Compute multi-headed attention
|
||||
src_len = k.shape[1]
|
||||
src_len = shape_list(k)[1]
|
||||
attn_weights = tf.matmul(q, k, transpose_b=True) # shape (bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if attn_mask is not None:
|
||||
|
@ -770,7 +787,7 @@ class TFLearnedPositionalEmbedding(TFSharedEmbeddings):
|
|||
|
||||
def call(self, input_ids: tf.Tensor, use_cache=False):
|
||||
"""Input is expected to be of size [bsz x seqlen]."""
|
||||
bsz, seq_len = input_ids.shape[:2]
|
||||
bsz, seq_len = shape_list(input_ids)[:2]
|
||||
|
||||
if use_cache:
|
||||
positions = tf.fill((1, 1), seq_len - 1)
|
||||
|
@ -780,6 +797,56 @@ class TFLearnedPositionalEmbedding(TFSharedEmbeddings):
|
|||
return super().call(positions + self.offset) # super object is not callable for some reason
|
||||
|
||||
|
||||
class TFSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
||||
def __init__(self, num_positions, embedding_dim, **kwargs):
|
||||
|
||||
if embedding_dim % 2 != 0:
|
||||
raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
|
||||
super().__init__(
|
||||
num_positions,
|
||||
embedding_dim,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def build(self, input_shape):
|
||||
"""
|
||||
Build shared token embedding layer Shared weights logic adapted from
|
||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||
"""
|
||||
super().build(input_shape) # Instantiates self.weight so it can be loaded
|
||||
weight: np.ndarray = self._init_weight(self.input_dim, self.output_dim)
|
||||
self.set_weights([weight]) # overwrite self.weight to correct value
|
||||
|
||||
@staticmethod
|
||||
def _init_weight(n_pos, dim):
|
||||
"""
|
||||
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
|
||||
the 2nd half of the vector. [dim // 2:]
|
||||
"""
|
||||
position_enc = np.array(
|
||||
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
|
||||
)
|
||||
# index 0 is all zero
|
||||
position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
|
||||
position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
|
||||
# convert to tensor
|
||||
table = tf.convert_to_tensor(position_enc, dtype=tf.float32)
|
||||
tf.stop_gradient(table)
|
||||
return table
|
||||
|
||||
def call(self, input_ids, use_cache=False):
|
||||
"""Input is expected to be of size [bsz x seqlen]."""
|
||||
bsz, seq_len = shape_list(input_ids)[:2]
|
||||
if use_cache:
|
||||
positions = tf.fill((1, 1), seq_len - 1)
|
||||
else:
|
||||
# starts at 0, ends at 1-seq_len
|
||||
positions = tf.range(0, seq_len, delta=1, dtype=tf.int32, name="range")
|
||||
return super().call(positions)
|
||||
|
||||
|
||||
# Public API
|
||||
|
||||
|
||||
|
@ -818,7 +885,7 @@ class TFBartModel(TFPretrainedBartModel):
|
|||
pad_token_id = self.config.pad_token_id
|
||||
if decoder_input_ids is None:
|
||||
decoder_input_ids = self._shift_right(inputs)
|
||||
bsz, tgt_len = decoder_input_ids.shape[:2]
|
||||
bsz, tgt_len = shape_list(decoder_input_ids)[:2]
|
||||
if decoder_attn_mask is None:
|
||||
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
|
||||
else:
|
||||
|
@ -950,16 +1017,20 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
|||
base_model_prefix = "model"
|
||||
authorized_missing_keys = [
|
||||
r"final_logits_bias",
|
||||
r"encoder\.version",
|
||||
r"decoder\.version",
|
||||
"model.encoder.embed_tokens.weight",
|
||||
"model.decoder.embed_tokens.weight",
|
||||
]
|
||||
authorized_unexpected_keys = [
|
||||
r"model.encoder.embed_tokens.weight",
|
||||
r"model.decoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: BartConfig, *args, **kwargs):
|
||||
super().__init__(config, *args, **kwargs)
|
||||
self.model = TFBartModel(config, name="model")
|
||||
self.use_cache = config.use_cache
|
||||
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
|
||||
self.final_logits_bias = self.add_weight(
|
||||
name="/final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
|
||||
)
|
||||
|
||||
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
|
@ -1050,6 +1121,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
|||
return_dict=True, # TODO(SS): this may need to change to support compilation
|
||||
)
|
||||
logits = self.model.shared(outputs.last_hidden_state, mode="linear")
|
||||
logits = logits + self.final_logits_bias
|
||||
loss = None if labels is None else self.compute_loss(labels, logits)
|
||||
|
||||
past = outputs.past_key_values if cast_bool_to_primitive(use_cache, self.config.use_cache) else None
|
||||
|
@ -1096,7 +1168,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
|||
), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past"
|
||||
assert isinstance(
|
||||
encoder_outputs, TFBaseModelOutput
|
||||
), "encoder_outputs should be a TFBaseModelOutput, Instead got "
|
||||
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
|
||||
return {
|
||||
"inputs": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
|
@ -1113,7 +1185,6 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
|||
reordered_past = []
|
||||
for layer_past in decoder_cached_states:
|
||||
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||||
|
||||
layer_past_new = {
|
||||
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
||||
}
|
||||
|
@ -1124,26 +1195,13 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
|
|||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
if cur_len == 1 and self.config.force_bos_token_to_be_generated:
|
||||
logits = self._force_token_id_to_be_generated(logits, self.config.bos_token_id)
|
||||
elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
|
||||
logits = self._force_token_id_to_be_generated(logits, self.config.eos_token_id)
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _force_token_id_to_be_generated(scores, token_id) -> None:
|
||||
"""force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
|
||||
output_list = []
|
||||
|
||||
# Is there a better way to do scores[:, [x for if x != token_id]] = -float("inf") in TF?
|
||||
bs, vocab_size = scores.shape
|
||||
for x in range(vocab_size):
|
||||
if x != token_id:
|
||||
output_list.append(tf.convert_to_tensor([-float("inf")] * bs, dtype=scores.dtype))
|
||||
else:
|
||||
output_list.append(scores[:, x])
|
||||
scores = tf.stack(output_list, axis=1, name="scores")
|
||||
assert scores.shape == (bs, vocab_size)
|
||||
return scores
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != self.config.bos_token_id, LARGE_NEGATIVE, logits)
|
||||
elif cur_len == max_length - 1:
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
return tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||
else:
|
||||
return logits
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.model.shared
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""TF BlenderBot model, ported from the fairseq repo."""
|
||||
from .configuration_blenderbot import BlenderbotConfig
|
||||
from .file_utils import add_start_docstrings, is_tf_available
|
||||
from .modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
|
||||
from .utils import logging
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "BlenderbotConfig"
|
||||
|
||||
START_DOCSTRING = BART_START_DOCSTRING.replace(
|
||||
"inherits from :class:`~transformers.TFPreTrainedModel`",
|
||||
"inherits from :class:`~transformers.TFBartForConditionalGeneration`",
|
||||
).replace("BartConfig", _CONFIG_FOR_DOC)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_start_docstrings("Blenderbot model for open domain dialogue", START_DOCSTRING)
|
||||
class TFBlenderbotForConditionalGeneration(TFBartForConditionalGeneration):
|
||||
config_class = BlenderbotConfig
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
"""Never predict pad_token_id. Predict </s> when max_length is reached."""
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
|
||||
if cur_len == max_length - 1:
|
||||
logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||
return logits
|
|
@ -0,0 +1,52 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""TF Marian model, ported from the fairseq repo."""
|
||||
|
||||
from .configuration_marian import MarianConfig
|
||||
from .file_utils import add_start_docstrings, is_tf_available
|
||||
from .modeling_tf_bart import BART_START_DOCSTRING, LARGE_NEGATIVE, TFBartForConditionalGeneration
|
||||
from .utils import logging
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "MarianConfig"
|
||||
|
||||
START_DOCSTRING = BART_START_DOCSTRING.replace(
|
||||
"inherits from :class:`~transformers.TFPreTrainedModel`",
|
||||
"inherits from :class:`~transformers.TFBartForConditionalGeneration`",
|
||||
).replace("BartConfig", _CONFIG_FOR_DOC)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_start_docstrings("Marian model for machine translation", START_DOCSTRING)
|
||||
class TFMarianMTModel(TFBartForConditionalGeneration):
|
||||
authorized_missing_keys = [
|
||||
r"model.encoder.embed_positions.weight",
|
||||
r"model.decoder.embed_positions.weight",
|
||||
]
|
||||
config_class = MarianConfig
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
|
||||
"""Never predict pad_token_id. Predict </s> when max_length is reached."""
|
||||
vocab_range = tf.constant(range(self.config.vocab_size))
|
||||
logits = tf.where(vocab_range == self.config.pad_token_id, LARGE_NEGATIVE, logits)
|
||||
if cur_len == max_length - 1:
|
||||
logits = tf.where(vocab_range != self.config.eos_token_id, LARGE_NEGATIVE, logits)
|
||||
return logits
|
|
@ -0,0 +1,36 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""TF mBART model, originally from fairseq."""
|
||||
from .configuration_mbart import MBartConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_bart import BART_START_DOCSTRING, TFBartForConditionalGeneration
|
||||
from .utils import logging
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "MBartConfig"
|
||||
|
||||
START_DOCSTRING = BART_START_DOCSTRING.replace(
|
||||
"inherits from :class:`~transformers.TFPreTrainedModel`",
|
||||
"inherits from :class:`~transformers.TFBartForConditionalGeneration`",
|
||||
).replace("BartConfig", _CONFIG_FOR_DOC)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_start_docstrings("mBART (multilingual BART) model for machine translation", START_DOCSTRING)
|
||||
class TFMBartForConditionalGeneration(TFBartForConditionalGeneration):
|
||||
config_class = MBartConfig
|
||||
# All the code is in src/transformers/modeling_tf_bart.py
|
|
@ -0,0 +1,41 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""TF Pegasus model, ported from the fairseq repo."""
|
||||
from .configuration_pegasus import PegasusConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_tf_bart import BART_START_DOCSTRING, TFBartForConditionalGeneration
|
||||
from .utils import logging
|
||||
|
||||
|
||||
_CONFIG_FOR_DOC = "PegasusConfig"
|
||||
|
||||
START_DOCSTRING = BART_START_DOCSTRING.replace(
|
||||
"inherits from :class:`~transformers.TFPreTrainedModel`",
|
||||
"inherits from :class:`~transformers.TFBartForConditionalGeneration`",
|
||||
).replace("BartConfig", _CONFIG_FOR_DOC)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_start_docstrings("Pegasus model for summarization", START_DOCSTRING)
|
||||
class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration):
|
||||
authorized_missing_keys = [
|
||||
r"final_logits_bias",
|
||||
r"model.encoder.embed_positions.weight",
|
||||
r"model.decoder.embed_positions.weight",
|
||||
]
|
||||
config_class = PegasusConfig
|
||||
# All the code is in src/transformers/modeling_tf_bart.py
|
|
@ -325,6 +325,15 @@ class TFBertPreTrainedModel:
|
|||
requires_tf(self)
|
||||
|
||||
|
||||
class TFBlenderbotForConditionalGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
@ -797,6 +806,24 @@ class TFLxmertVisualFeatureEncoder:
|
|||
requires_tf(self)
|
||||
|
||||
|
||||
class TFMarianMTModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
class TFMBartForConditionalGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
@ -922,6 +949,15 @@ class TFOpenAIGPTPreTrainedModel:
|
|||
requires_tf(self)
|
||||
|
||||
|
||||
class TFPegasusForConditionalGeneration:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
|
||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
|
|
@ -138,7 +138,7 @@ class MarianIntegrationTest(unittest.TestCase):
|
|||
)
|
||||
self.assertEqual(self.model.device, model_inputs.input_ids.device)
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2
|
||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
|
||||
)
|
||||
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
return generated_words
|
||||
|
@ -244,6 +244,8 @@ class TestMarian_RU_FR(MarianIntegrationTest):
|
|||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class TestMarian_MT_EN(MarianIntegrationTest):
|
||||
"""Cover low resource/high perplexity setting. This breaks without adjust_logits_generation overwritten"""
|
||||
|
||||
src = "mt"
|
||||
tgt = "en"
|
||||
src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
|
||||
|
|
|
@ -17,7 +17,9 @@
|
|||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_tf_available
|
||||
import numpy as np
|
||||
|
||||
from transformers import BartConfig, BartTokenizer, is_tf_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, slow
|
||||
|
||||
|
@ -28,12 +30,16 @@ from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
|||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import BartConfig, TFBartForConditionalGeneration, TFBartModel
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
from transformers import TFBartForConditionalGeneration, TFBartModel
|
||||
from transformers.modeling_tf_bart import TFSinusoidalPositionalEmbedding
|
||||
|
||||
|
||||
@require_tf
|
||||
class ModelTester:
|
||||
class TFBartModelTester:
|
||||
config_cls = BartConfig
|
||||
config_updates = {}
|
||||
hidden_act = "gelu"
|
||||
|
||||
def __init__(self, parent):
|
||||
self.parent = parent
|
||||
self.batch_size = 13
|
||||
|
@ -45,14 +51,13 @@ class ModelTester:
|
|||
self.num_hidden_layers = 5
|
||||
self.num_attention_heads = 4
|
||||
self.intermediate_size = 37
|
||||
self.hidden_act = "gelu"
|
||||
|
||||
self.hidden_dropout_prob = 0.1
|
||||
self.attention_probs_dropout_prob = 0.1
|
||||
self.max_position_embeddings = 20
|
||||
self.eos_token_ids = [2]
|
||||
self.pad_token_id = 1
|
||||
self.bos_token_id = 0
|
||||
# torch.manual_seed(0)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size)
|
||||
|
@ -60,7 +65,7 @@ class ModelTester:
|
|||
input_ids = tf.concat([input_ids, eos_tensor], axis=1)
|
||||
input_ids = tf.clip_by_value(input_ids, 3, self.vocab_size + 1)
|
||||
|
||||
config = BartConfig(
|
||||
config = self.config_cls(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=self.hidden_size,
|
||||
encoder_layers=self.num_hidden_layers,
|
||||
|
@ -76,6 +81,7 @@ class ModelTester:
|
|||
bos_token_id=self.bos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
decoder_start_token_id=self.pad_token_id,
|
||||
**self.config_updates,
|
||||
)
|
||||
inputs_dict = prepare_bart_inputs_dict(config, input_ids)
|
||||
return config, inputs_dict
|
||||
|
@ -101,9 +107,10 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase):
|
|||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
model_tester_cls = TFBartModelTester
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ModelTester(self)
|
||||
self.model_tester = self.model_tester_cls(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BartConfig)
|
||||
|
||||
def test_config(self):
|
||||
|
@ -120,7 +127,7 @@ class TestTFBart(TFModelTesterMixin, unittest.TestCase):
|
|||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
model_class = TFBartForConditionalGeneration
|
||||
model_class = self.all_generative_model_classes[0]
|
||||
input_ids = {
|
||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||
|
@ -354,3 +361,29 @@ class FasterTFBartModelIntegrationTests(unittest.TestCase):
|
|||
|
||||
expected = np.array([[-0.0828, -0.0251, -0.0674], [0.1277, 0.3311, -0.0255], [0.2613, -0.0840, -0.2763]])
|
||||
assert np.allclose(features[0, :3, :3].numpy(), expected, atol=1e-3)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TestTFSinusoidalPositionalEmbeddings(unittest.TestCase):
|
||||
desired_weights = [
|
||||
[0, 0, 0, 0, 0],
|
||||
[0.84147096, 0.82177866, 0.80180490, 0.78165019, 0.76140374],
|
||||
[0.90929741, 0.93651021, 0.95829457, 0.97505713, 0.98720258],
|
||||
]
|
||||
|
||||
def test_positional_emb_cache_logic(self):
|
||||
input_ids = _long_tensor([[4, 10]])
|
||||
emb1 = TFSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6)
|
||||
no_cache = emb1(input_ids, use_cache=False)
|
||||
yes_cache = emb1(input_ids, use_cache=True)
|
||||
self.assertEqual((1, 1, 6), yes_cache.shape) # extra dim to allow broadcasting, feel free to delete!
|
||||
|
||||
np.testing.assert_almost_equal(no_cache[-1].numpy(), yes_cache[0][0].numpy())
|
||||
|
||||
def test_positional_emb_weights_against_marian(self):
|
||||
emb1 = TFSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512)
|
||||
emb1.build(None)
|
||||
weights = emb1.embeddings.numpy()
|
||||
for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)):
|
||||
for j in range(5):
|
||||
self.assertAlmostEqual(expected_weight[j], actual_weight[j], places=3)
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from tests.test_configuration_common import ConfigTester
|
||||
from tests.test_modeling_tf_bart import TFBartModelTester
|
||||
from tests.test_modeling_tf_common import TFModelTesterMixin
|
||||
from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration
|
||||
|
||||
|
||||
class ModelTester(TFBartModelTester):
|
||||
config_updates = dict(
|
||||
normalize_before=True,
|
||||
static_position_embeddings=True,
|
||||
do_blenderbot_90_layernorm=True,
|
||||
normalize_embeddings=True,
|
||||
)
|
||||
config_cls = BlenderbotConfig
|
||||
|
||||
|
||||
@require_tf
|
||||
class TestTFBlenderbotCommon(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
model_tester_cls = ModelTester
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = self.model_tester_cls(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
# inputs_embeds not supported
|
||||
pass
|
||||
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
model_class = self.all_generative_model_classes[0]
|
||||
input_ids = {
|
||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||
}
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_dict = model(input_ids)
|
||||
hidden_states = outputs_dict[0]
|
||||
|
||||
# Add a dense layer on top to test integration with other keras modules
|
||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||
|
||||
# Compile extended model
|
||||
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
@require_tokenizers
|
||||
class TFBlenderbot90MIntegrationTests(unittest.TestCase):
|
||||
src_text = [
|
||||
"Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
|
||||
]
|
||||
model_name = "facebook/blenderbot-90M"
|
||||
|
||||
@cached_property
|
||||
def tokenizer(self):
|
||||
return BlenderbotSmallTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True)
|
||||
return model
|
||||
|
||||
@slow
|
||||
def test_90_generation_from_long_input(self):
|
||||
model_inputs = self.tokenizer(self.src_text, return_tensors="tf")
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids,
|
||||
attention_mask=model_inputs.attention_mask,
|
||||
num_beams=2,
|
||||
use_cache=True,
|
||||
)
|
||||
generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)[0]
|
||||
assert generated_words in (
|
||||
"i don't know. i just feel like i'm going to throw up. it's not fun.",
|
||||
"i'm not sure. i just feel like i've been feeling like i have to be in a certain place",
|
||||
"i'm not sure. i just feel like i've been in a bad situation.",
|
||||
)
|
|
@ -0,0 +1,197 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
from transformers import AutoTokenizer, MarianConfig, MarianTokenizer, TranslationPipeline, is_tf_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_tf_bart import TFBartModelTester
|
||||
from .test_modeling_tf_common import TFModelTesterMixin
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFAutoModelForSeq2SeqLM, TFMarianMTModel
|
||||
|
||||
|
||||
class ModelTester(TFBartModelTester):
|
||||
config_updates = dict(static_position_embeddings=True, add_bias_logits=True)
|
||||
config_cls = MarianConfig
|
||||
|
||||
|
||||
@require_tf
|
||||
class TestTFMarianCommon(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
||||
model_tester_cls = ModelTester
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = self.model_tester_cls(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MarianConfig)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
# inputs_embeds not supported
|
||||
pass
|
||||
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
pass
|
||||
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
model_class = self.all_generative_model_classes[0]
|
||||
input_ids = {
|
||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||
}
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pre-trained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_dict = model(input_ids)
|
||||
hidden_states = outputs_dict[0]
|
||||
|
||||
# Add a dense layer on top to test integration with other keras modules
|
||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||
|
||||
# Compile extended model
|
||||
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
|
||||
|
||||
class AbstractMarianIntegrationTest(unittest.TestCase):
|
||||
maxDiff = 1000 # show more chars for failing integration tests
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
|
||||
return cls
|
||||
|
||||
@cached_property
|
||||
def tokenizer(self) -> MarianTokenizer:
|
||||
return AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
warnings.simplefilter("error")
|
||||
model: TFMarianMTModel = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True)
|
||||
assert isinstance(model, TFMarianMTModel)
|
||||
c = model.config
|
||||
self.assertListEqual(c.bad_words_ids, [[c.pad_token_id]])
|
||||
self.assertEqual(c.max_length, 512)
|
||||
self.assertEqual(c.decoder_start_token_id, c.pad_token_id)
|
||||
return model
|
||||
|
||||
def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
|
||||
generated_words = self.translate_src_text(**tokenizer_kwargs)
|
||||
self.assertListEqual(self.expected_text, generated_words)
|
||||
|
||||
def translate_src_text(self, **tokenizer_kwargs):
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
|
||||
)
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2, max_length=128
|
||||
)
|
||||
generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)
|
||||
return generated_words
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@is_pt_tf_cross_test
|
||||
class TestMarian_MT_EN(AbstractMarianIntegrationTest):
|
||||
"""Cover low resource/high perplexity setting. This breaks if pad_token_id logits not set to LARGE_NEGATIVE."""
|
||||
|
||||
src = "mt"
|
||||
tgt = "en"
|
||||
src_text = ["Billi messu b'mod ġentili, Ġesù fejjaq raġel li kien milqut bil - marda kerha tal - ġdiem."]
|
||||
expected_text = ["Touching gently, Jesus healed a man who was affected by the sad disease of leprosy."]
|
||||
|
||||
@slow
|
||||
def test_batch_generation_mt_en(self):
|
||||
self._assert_generated_batch_equal_expected()
|
||||
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class TestMarian_en_zh(AbstractMarianIntegrationTest):
|
||||
src = "en"
|
||||
tgt = "zh"
|
||||
src_text = ["My name is Wolfgang and I live in Berlin"]
|
||||
expected_text = ["我叫沃尔夫冈 我住在柏林"]
|
||||
|
||||
@slow
|
||||
def test_batch_generation_en_zh(self):
|
||||
self._assert_generated_batch_equal_expected()
|
||||
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class TestMarian_en_ROMANCE(AbstractMarianIntegrationTest):
|
||||
"""Multilingual on target side."""
|
||||
|
||||
src = "en"
|
||||
tgt = "ROMANCE"
|
||||
src_text = [
|
||||
">>fr<< Don't spend so much time watching TV.",
|
||||
">>pt<< Your message has been sent.",
|
||||
">>es<< He's two years older than me.",
|
||||
]
|
||||
expected_text = [
|
||||
"Ne passez pas autant de temps à regarder la télé.",
|
||||
"A sua mensagem foi enviada.",
|
||||
"Es dos años más viejo que yo.",
|
||||
]
|
||||
|
||||
@slow
|
||||
def test_batch_generation_en_ROMANCE_multi(self):
|
||||
self._assert_generated_batch_equal_expected()
|
||||
|
||||
@slow
|
||||
def test_pipeline(self):
|
||||
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="tf")
|
||||
output = pipeline(self.src_text)
|
||||
self.assertEqual(self.expected_text, [x["translation_text"] for x in output])
|
|
@ -0,0 +1,134 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from tests.test_configuration_common import ConfigTester
|
||||
from tests.test_modeling_tf_bart import TFBartModelTester
|
||||
from tests.test_modeling_tf_common import TFModelTesterMixin
|
||||
from transformers import AutoTokenizer, MBartConfig, is_tf_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFAutoModelForSeq2SeqLM, TFMBartForConditionalGeneration
|
||||
|
||||
|
||||
class ModelTester(TFBartModelTester):
|
||||
config_updates = dict(normalize_before=True, add_final_layer_norm=True)
|
||||
config_cls = MBartConfig
|
||||
|
||||
|
||||
@require_tf
|
||||
class TestTFMBartCommon(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
model_tester_cls = ModelTester
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = self.model_tester_cls(self)
|
||||
self.config_tester = ConfigTester(self, config_class=MBartConfig)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
# inputs_embeds not supported
|
||||
pass
|
||||
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
model_class = self.all_generative_model_classes[0]
|
||||
input_ids = {
|
||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||
}
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_dict = model(input_ids)
|
||||
hidden_states = outputs_dict[0]
|
||||
|
||||
# Add a dense layer on top to test integration with other keras modules
|
||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||
|
||||
# Compile extended model
|
||||
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class TestMBartEnRO(unittest.TestCase):
|
||||
src_text = [
|
||||
" UN Chief Says There Is No Military Solution in Syria",
|
||||
]
|
||||
expected_text = [
|
||||
"Şeful ONU declară că nu există o soluţie militară în Siria",
|
||||
]
|
||||
model_name = "facebook/mbart-large-en-ro"
|
||||
|
||||
@cached_property
|
||||
def tokenizer(self):
|
||||
return AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True)
|
||||
return model
|
||||
|
||||
def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
|
||||
generated_words = self.translate_src_text(**tokenizer_kwargs)
|
||||
self.assertListEqual(self.expected_text, generated_words)
|
||||
|
||||
def translate_src_text(self, **tokenizer_kwargs):
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
|
||||
)
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids, attention_mask=model_inputs.attention_mask, num_beams=2
|
||||
)
|
||||
generated_words = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
return generated_words
|
||||
|
||||
@slow
|
||||
def test_batch_generation_en_ro(self):
|
||||
self._assert_generated_batch_equal_expected()
|
|
@ -0,0 +1,141 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2020 HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer, PegasusConfig, is_tf_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_sentencepiece, require_tf, require_tokenizers, slow
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_pegasus import PGE_ARTICLE, XSUM_ENTRY_LONGER
|
||||
from .test_modeling_tf_bart import TFBartModelTester
|
||||
from .test_modeling_tf_common import TFModelTesterMixin
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFAutoModelForSeq2SeqLM, TFPegasusForConditionalGeneration
|
||||
|
||||
|
||||
class ModelTester(TFBartModelTester):
|
||||
config_updates = dict(
|
||||
normalize_before=True,
|
||||
static_position_embeddings=True,
|
||||
)
|
||||
hidden_act = "relu"
|
||||
config_cls = PegasusConfig
|
||||
|
||||
|
||||
@require_tf
|
||||
class TestTFPegasusCommon(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
||||
model_tester_cls = ModelTester
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = self.model_tester_cls(self)
|
||||
self.config_tester = ConfigTester(self, config_class=PegasusConfig)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
# inputs_embeds not supported
|
||||
pass
|
||||
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
# Should be uncommented during patrick TF refactor
|
||||
pass
|
||||
|
||||
def test_compile_tf_model(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
|
||||
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
model_class = self.all_generative_model_classes[0]
|
||||
input_ids = {
|
||||
"decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"),
|
||||
}
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
|
||||
# Let's load it from the disk to be sure we can use pretrained weights
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_dict = model(input_ids)
|
||||
hidden_states = outputs_dict[0]
|
||||
|
||||
# Add a dense layer on top to test integration with other keras modules
|
||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||
|
||||
# Compile extended model
|
||||
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class TFPegasusIntegrationTests(unittest.TestCase):
|
||||
src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER]
|
||||
expected_text = [
|
||||
"California's largest electricity provider has cut power to hundreds of thousands of customers in an effort to reduce the risk of wildfires.",
|
||||
'N-Dubz have revealed they\'re "grateful" to have been nominated for four Mobo Awards.',
|
||||
] # differs slightly from pytorch, likely due to numerical differences in linear layers
|
||||
model_name = "google/pegasus-xsum"
|
||||
|
||||
@cached_property
|
||||
def tokenizer(self):
|
||||
return AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
model = TFAutoModelForSeq2SeqLM.from_pretrained(self.model_name, from_pt=True)
|
||||
return model
|
||||
|
||||
def _assert_generated_batch_equal_expected(self, **tokenizer_kwargs):
|
||||
generated_words = self.translate_src_text(**tokenizer_kwargs)
|
||||
assert self.expected_text == generated_words
|
||||
|
||||
def translate_src_text(self, **tokenizer_kwargs):
|
||||
model_inputs = self.tokenizer.prepare_seq2seq_batch(
|
||||
src_texts=self.src_text, **tokenizer_kwargs, return_tensors="tf"
|
||||
)
|
||||
generated_ids = self.model.generate(
|
||||
model_inputs.input_ids,
|
||||
attention_mask=model_inputs.attention_mask,
|
||||
num_beams=2,
|
||||
use_cache=True,
|
||||
)
|
||||
generated_words = self.tokenizer.batch_decode(generated_ids.numpy(), skip_special_tokens=True)
|
||||
return generated_words
|
||||
|
||||
@slow
|
||||
def test_batch_generation(self):
|
||||
self._assert_generated_batch_equal_expected()
|
|
@ -67,6 +67,7 @@ MODEL_NAME_TO_DOC_FILE = {
|
|||
"xlm_prophetnet": "xlmprophetnet.rst",
|
||||
"xlm_roberta": "xlmroberta.rst",
|
||||
"bert_generation": "bertgeneration.rst",
|
||||
"marian": "marian.rst",
|
||||
}
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
|
@ -148,7 +149,6 @@ def get_model_doc_files():
|
|||
_ignore_modules = [
|
||||
"auto",
|
||||
"dialogpt",
|
||||
"marian",
|
||||
"retribert",
|
||||
]
|
||||
doc_files = []
|
||||
|
@ -245,6 +245,7 @@ def check_models_are_documented(module, doc_file):
|
|||
def _get_model_name(module):
|
||||
""" Get the model name for the module defining it."""
|
||||
splits = module.__name__.split("_")
|
||||
|
||||
# Secial case for transfo_xl
|
||||
if splits[-1] == "xl":
|
||||
return "_".join(splits[-2:])
|
||||
|
|
Loading…
Reference in New Issue