Encoder-decoder models: move embedding scale to nn.Module (#30410)
* move scaling to nn.Module
* let the test be here for now (need to fix)
* failing tests
* last failing models
* Revert commit 4c14817f38
* clean-up
* oops forgot
* codestyle
* raise NotImplementedError when possible
* Update tests/test_modeling_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* skip tests in respective modeling files
---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
9d31b32e9d
commit
38a4bf79ad
|
@ -132,6 +132,19 @@ class BartLearnedPositionalEmbedding(nn.Embedding):
|
|||
return super().forward(positions + self.offset)
|
||||
|
||||
|
||||
class BartScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying the looked-up embeddings with a constant scale.

    Args:
        num_embeddings (`int`):
            Size of the vocabulary, forwarded to `nn.Embedding`.
        embedding_dim (`int`):
            Size of each embedding vector, forwarded to `nn.Embedding`.
        padding_idx (`int`, *optional*):
            Index of the padding token, forwarded to `nn.Embedding`.
        embed_scale (`float`, *optional*, defaults to 1.0):
            Factor every embedding vector is multiplied by, e.g. `sqrt(embedding_dim)` when the config enables
            embedding scaling. `None` is treated as 1.0 (no scaling).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int],
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # The signature advertises `Optional[float]`; without this fallback a `None` scale would only
        # surface later as a `TypeError` inside `forward`.
        self.embed_scale = 1.0 if embed_scale is None else embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up `input_ids` in the embedding table and return the result multiplied by `embed_scale`."""
        return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
class BartAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
|
@ -1056,9 +1069,11 @@ class BartEncoder(BartPreTrainedModel):
|
|||
embed_dim = config.d_model
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
self.embed_tokens = BartScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -1146,7 +1161,7 @@ class BartEncoder(BartPreTrainedModel):
|
|||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
embed_pos = self.embed_positions(input)
|
||||
embed_pos = embed_pos.to(inputs_embeds.device)
|
||||
|
@ -1238,9 +1253,11 @@ class BartDecoder(BartPreTrainedModel):
|
|||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
self.embed_tokens = BartScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -1369,7 +1386,7 @@ class BartDecoder(BartPreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
|
|
|
@ -90,6 +90,20 @@ class BigBirdPegasusLearnedPositionalEmbedding(nn.Embedding):
|
|||
return super().forward(positions)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BigBirdPegasus
|
||||
class BigBirdPegasusScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying the looked-up embeddings with a constant scale.

    Args:
        num_embeddings (`int`):
            Size of the vocabulary, forwarded to `nn.Embedding`.
        embedding_dim (`int`):
            Size of each embedding vector, forwarded to `nn.Embedding`.
        padding_idx (`int`, *optional*):
            Index of the padding token, forwarded to `nn.Embedding`.
        embed_scale (`float`, *optional*, defaults to 1.0):
            Factor every embedding vector is multiplied by, e.g. `sqrt(embedding_dim)` when the config enables
            embedding scaling. `None` is treated as 1.0 (no scaling).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int],
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # The signature advertises `Optional[float]`; without this fallback a `None` scale would only
        # surface later as a `TypeError` inside `forward`.
        self.embed_scale = 1.0 if embed_scale is None else embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up `input_ids` in the embedding table and return the result multiplied by `embed_scale`."""
        return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
# Copied from transformers.models.big_bird.modeling_big_bird.BigBirdSelfAttention with BigBird->BigBirdPegasus
|
||||
class BigBirdPegasusSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
|
@ -1749,9 +1763,11 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
|
|||
embed_dim = config.d_model
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
self.embed_tokens = BigBirdPegasusScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -1827,7 +1843,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
|
|||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
embed_pos = self.embed_positions(input_shape)
|
||||
|
||||
|
@ -2042,9 +2058,11 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
|||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
self.embed_tokens = BigBirdPegasusScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -2168,7 +2186,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
|
@ -2292,7 +2310,10 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
|
|||
super().__init__(config)
|
||||
|
||||
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
||||
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.shared = BigBirdPegasusScaledWordEmbedding(
|
||||
vocab_size, config.d_model, padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.encoder = BigBirdPegasusEncoder(config, self.shared)
|
||||
self.decoder = BigBirdPegasusDecoder(config, self.shared)
|
||||
|
|
|
@ -75,6 +75,20 @@ class BioGptLearnedPositionalEmbedding(nn.Embedding):
|
|||
return super().forward(positions + self.offset)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BioGpt
|
||||
class BioGptScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying the looked-up embeddings with a constant scale.

    Args:
        num_embeddings (`int`):
            Size of the vocabulary, forwarded to `nn.Embedding`.
        embedding_dim (`int`):
            Size of each embedding vector, forwarded to `nn.Embedding`.
        padding_idx (`int`, *optional*):
            Index of the padding token, forwarded to `nn.Embedding`.
        embed_scale (`float`, *optional*, defaults to 1.0):
            Factor every embedding vector is multiplied by, e.g. `sqrt(embedding_dim)` when the config enables
            embedding scaling. `None` is treated as 1.0 (no scaling).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int],
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # The signature advertises `Optional[float]`; without this fallback a `None` scale would only
        # surface later as a `TypeError` inside `forward`.
        self.embed_scale = 1.0 if embed_scale is None else embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up `input_ids` in the embedding table and return the result multiplied by `embed_scale`."""
        return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BioGpt
|
||||
class BioGptAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
@ -423,9 +437,11 @@ class BioGptModel(BioGptPreTrainedModel):
|
|||
self.dropout = config.hidden_dropout_prob
|
||||
self.embed_dim = config.hidden_size
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, self.embed_dim, self.padding_idx)
|
||||
self.embed_tokens = BioGptScaledWordEmbedding(
|
||||
config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
|
||||
|
||||
self.layers = nn.ModuleList([BioGptDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
@ -482,7 +498,7 @@ class BioGptModel(BioGptPreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
|
|
|
@ -90,6 +90,20 @@ class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
|
|||
return super().forward(positions)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->Blenderbot
|
||||
class BlenderbotScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying the looked-up embeddings with a constant scale.

    Args:
        num_embeddings (`int`):
            Size of the vocabulary, forwarded to `nn.Embedding`.
        embedding_dim (`int`):
            Size of each embedding vector, forwarded to `nn.Embedding`.
        padding_idx (`int`, *optional*):
            Index of the padding token, forwarded to `nn.Embedding`.
        embed_scale (`float`, *optional*, defaults to 1.0):
            Factor every embedding vector is multiplied by, e.g. `sqrt(embedding_dim)` when the config enables
            embedding scaling. `None` is treated as 1.0 (no scaling).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int],
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # The signature advertises `Optional[float]`; without this fallback a `None` scale would only
        # surface later as a `TypeError` inside `forward`.
        self.embed_scale = 1.0 if embed_scale is None else embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up `input_ids` in the embedding table and return the result multiplied by `embed_scale`."""
        return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot
|
||||
class BlenderbotAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
@ -632,12 +646,14 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
|||
embed_dim = config.d_model
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
self.embed_tokens = BlenderbotScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
|
@ -715,7 +731,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
|||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
embed_pos = self.embed_positions(input_shape)
|
||||
|
||||
|
@ -799,12 +815,14 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
self.embed_tokens = BlenderbotScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
|
@ -926,7 +944,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
|
|
|
@ -1325,6 +1325,11 @@ class BridgeTowerModel(BridgeTowerPreTrainedModel):
|
|||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
if inputs_embeds is not None and input_ids is None:
|
||||
raise NotImplementedError(
|
||||
"BridgeTowerModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
|
||||
)
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
image_token_type_idx = image_token_type_idx if image_token_type_idx else 1
|
||||
input_shape = input_ids.size()
|
||||
|
|
|
@ -972,8 +972,7 @@ class FunnelBaseModel(FunnelPreTrainedModel):
|
|||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# TODO: deal with head_mask
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embeddings(input_ids)
|
||||
inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds,
|
||||
|
@ -1048,8 +1047,7 @@ class FunnelModel(FunnelPreTrainedModel):
|
|||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# TODO: deal with head_mask
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embeddings(input_ids)
|
||||
inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds,
|
||||
|
|
|
@ -920,6 +920,10 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
|
|||
device = self.position_embeddings.weight.device
|
||||
if input_ids is None:
|
||||
input_ids = torch.zeros([1, 1]).int().to(device) # dummy for input_ids was None
|
||||
if inputs_embeds is not None:
|
||||
raise NotImplementedError(
|
||||
"GPTSanJapaneseModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
|
||||
)
|
||||
num_pasts_contexts = 0
|
||||
num_batch = input_ids.shape[0]
|
||||
pasts_or_spout_value = None
|
||||
|
|
|
@ -87,6 +87,20 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
|
|||
return incremental_indices.long() + padding_idx
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->M2M100
|
||||
class M2M100ScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying the looked-up embeddings with a constant scale.

    Args:
        num_embeddings (`int`):
            Size of the vocabulary, forwarded to `nn.Embedding`.
        embedding_dim (`int`):
            Size of each embedding vector, forwarded to `nn.Embedding`.
        padding_idx (`int`, *optional*):
            Index of the padding token, forwarded to `nn.Embedding`.
        embed_scale (`float`, *optional*, defaults to 1.0):
            Factor every embedding vector is multiplied by, e.g. `sqrt(embedding_dim)` when the config enables
            embedding scaling. `None` is treated as 1.0 (no scaling).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int],
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # The signature advertises `Optional[float]`; without this fallback a `None` scale would only
        # surface later as a `TypeError` inside `forward`.
        self.embed_scale = 1.0 if embed_scale is None else embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up `input_ids` in the embedding table and return the result multiplied by `embed_scale`."""
        return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
class M2M100SinusoidalPositionalEmbedding(nn.Module):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
||||
|
@ -886,9 +900,11 @@ class M2M100Encoder(M2M100PreTrainedModel):
|
|||
embed_dim = config.d_model
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
self.embed_tokens = M2M100ScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -971,7 +987,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
|
|||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
embed_pos = self.embed_positions(input_ids, inputs_embeds)
|
||||
embed_pos = embed_pos.to(inputs_embeds.device)
|
||||
|
@ -1061,9 +1077,11 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
|||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
self.embed_tokens = M2M100ScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -1183,7 +1201,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
|
@ -1321,7 +1339,8 @@ class M2M100Model(M2M100PreTrainedModel):
|
|||
super().__init__(config)
|
||||
|
||||
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
||||
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.shared = M2M100ScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
|
||||
|
||||
self.encoder = M2M100Encoder(config, self.shared)
|
||||
self.decoder = M2M100Decoder(config, self.shared)
|
||||
|
|
|
@ -118,6 +118,20 @@ class MBartLearnedPositionalEmbedding(nn.Embedding):
|
|||
return super().forward(positions + self.offset)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->MBart
|
||||
class MBartScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying the looked-up embeddings with a constant scale.

    Args:
        num_embeddings (`int`):
            Size of the vocabulary, forwarded to `nn.Embedding`.
        embedding_dim (`int`):
            Size of each embedding vector, forwarded to `nn.Embedding`.
        padding_idx (`int`, *optional*):
            Index of the padding token, forwarded to `nn.Embedding`.
        embed_scale (`float`, *optional*, defaults to 1.0):
            Factor every embedding vector is multiplied by, e.g. `sqrt(embedding_dim)` when the config enables
            embedding scaling. `None` is treated as 1.0 (no scaling).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int],
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # The signature advertises `Optional[float]`; without this fallback a `None` scale would only
        # surface later as a `TypeError` inside `forward`.
        self.embed_scale = 1.0 if embed_scale is None else embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up `input_ids` in the embedding table and return the result multiplied by `embed_scale`."""
        return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart
|
||||
class MBartAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
@ -919,9 +933,11 @@ class MBartEncoder(MBartPreTrainedModel):
|
|||
embed_dim = config.d_model
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
self.embed_tokens = MBartScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -1009,7 +1025,7 @@ class MBartEncoder(MBartPreTrainedModel):
|
|||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
embed_pos = self.embed_positions(input)
|
||||
|
||||
|
@ -1097,9 +1113,11 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
self.embed_tokens = MBartScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -1227,7 +1245,7 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
|
|
|
@ -133,6 +133,20 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T
|
|||
return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
|
||||
|
||||
|
||||
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->NllbMoe
|
||||
class NllbMoeScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying the looked-up embeddings with a constant scale.

    Args:
        num_embeddings (`int`):
            Size of the vocabulary, forwarded to `nn.Embedding`.
        embedding_dim (`int`):
            Size of each embedding vector, forwarded to `nn.Embedding`.
        padding_idx (`int`, *optional*):
            Index of the padding token, forwarded to `nn.Embedding`.
        embed_scale (`float`, *optional*, defaults to 1.0):
            Factor every embedding vector is multiplied by, e.g. `sqrt(embedding_dim)` when the config enables
            embedding scaling. `None` is treated as 1.0 (no scaling).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int],
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # The signature advertises `Optional[float]`; without this fallback a `None` scale would only
        # surface later as a `TypeError` inside `forward`.
        self.embed_scale = 1.0 if embed_scale is None else embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up `input_ids` in the embedding table and return the result multiplied by `embed_scale`."""
        return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
|
||||
class NllbMoeSinusoidalPositionalEmbedding(nn.Module):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
@ -992,9 +1006,11 @@ class NllbMoeEncoder(NllbMoePreTrainedModel):
|
|||
embed_dim = config.d_model
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
self.embed_tokens = NllbMoeScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -1085,7 +1101,7 @@ class NllbMoeEncoder(NllbMoePreTrainedModel):
|
|||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
embed_pos = self.embed_positions(input_ids, inputs_embeds)
|
||||
embed_pos = embed_pos.to(inputs_embeds.device)
|
||||
|
@ -1178,9 +1194,11 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
|
|||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
self.embed_tokens = NllbMoeScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -1309,7 +1327,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
|
@ -1458,7 +1476,8 @@ class NllbMoeModel(NllbMoePreTrainedModel):
|
|||
super().__init__(config)
|
||||
|
||||
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
||||
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.shared = NllbMoeScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
|
||||
|
||||
self.encoder = NllbMoeEncoder(config, self.shared)
|
||||
self.decoder = NllbMoeDecoder(config, self.shared)
|
||||
|
|
|
@ -87,6 +87,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
|
|||
return shifted_input_ids
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PegasusX
|
||||
class PegasusXScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying the looked-up embeddings with a constant scale.

    Args:
        num_embeddings (`int`):
            Size of the vocabulary, forwarded to `nn.Embedding`.
        embedding_dim (`int`):
            Size of each embedding vector, forwarded to `nn.Embedding`.
        padding_idx (`int`, *optional*):
            Index of the padding token, forwarded to `nn.Embedding`.
        embed_scale (`float`, *optional*, defaults to 1.0):
            Factor every embedding vector is multiplied by, e.g. `sqrt(embedding_dim)` when the config enables
            embedding scaling. `None` is treated as 1.0 (no scaling).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int],
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # The signature advertises `Optional[float]`; without this fallback a `None` scale would only
        # surface later as a `TypeError` inside `forward`.
        self.embed_scale = 1.0 if embed_scale is None else embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up `input_ids` in the embedding table and return the result multiplied by `embed_scale`."""
        return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
class PegasusXSinusoidalPositionalEmbedding(nn.Module):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
||||
|
@ -880,13 +894,16 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
|
|||
self.layerdrop = config.encoder_layerdrop
|
||||
|
||||
embed_dim = config.d_model
|
||||
padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.embed_tokens = PegasusXScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_global = nn.Embedding(config.num_global_tokens, embed_dim)
|
||||
self.embed_positions = PegasusXSinusoidalPositionalEmbedding(embed_dim)
|
||||
|
@ -988,7 +1005,7 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
|
|||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
embed_pos = self.embed_positions(inputs_embeds)
|
||||
|
||||
|
@ -1086,12 +1103,15 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
|
|||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
padding_idx = config.pad_token_id
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
|
||||
self.embed_tokens = PegasusXScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = PegasusXSinusoidalPositionalEmbedding(config.d_model)
|
||||
self.layers = nn.ModuleList([PegasusXDecoderLayer(config) for _ in range(config.decoder_layers)])
|
||||
|
@ -1196,7 +1216,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
|
@ -1307,7 +1327,11 @@ class PegasusXModel(PegasusXPreTrainedModel):
|
|||
super().__init__(config)
|
||||
|
||||
vocab_size = config.vocab_size
|
||||
self.shared = nn.Embedding(vocab_size, config.d_model)
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
padding_idx = config.pad_token_id
|
||||
self.shared = PegasusXScaledWordEmbedding(
|
||||
vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.encoder = PegasusXEncoder(config, self.shared)
|
||||
self.decoder = PegasusXDecoder(config, self.shared)
|
||||
|
|
|
@ -102,6 +102,20 @@ class PLBartLearnedPositionalEmbedding(nn.Embedding):
|
|||
return super().forward(positions + self.offset)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PLBart
|
||||
class PLBartScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying the looked-up embeddings with a constant scale.

    Args:
        num_embeddings (`int`):
            Size of the vocabulary, forwarded to `nn.Embedding`.
        embedding_dim (`int`):
            Size of each embedding vector, forwarded to `nn.Embedding`.
        padding_idx (`int`, *optional*):
            Index of the padding token, forwarded to `nn.Embedding`.
        embed_scale (`float`, *optional*, defaults to 1.0):
            Factor every embedding vector is multiplied by, e.g. `sqrt(embedding_dim)` when the config enables
            embedding scaling. `None` is treated as 1.0 (no scaling).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int],
        embed_scale: Optional[float] = 1.0,
    ):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        # The signature advertises `Optional[float]`; without this fallback a `None` scale would only
        # surface later as a `TypeError` inside `forward`.
        self.embed_scale = 1.0 if embed_scale is None else embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up `input_ids` in the embedding table and return the result multiplied by `embed_scale`."""
        return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PLBart
|
||||
class PLBartAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
@ -658,9 +672,11 @@ class PLBartEncoder(PLBartPreTrainedModel):
|
|||
embed_dim = config.d_model
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_source_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
self.embed_tokens = PLBartScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -748,7 +764,7 @@ class PLBartEncoder(PLBartPreTrainedModel):
|
|||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
embed_pos = self.embed_positions(input)
|
||||
embed_pos = embed_pos.to(inputs_embeds.device)
|
||||
|
@ -841,9 +857,11 @@ class PLBartDecoder(PLBartPreTrainedModel):
|
|||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
self.embed_tokens = PLBartScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -972,7 +990,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
|
@ -1122,7 +1140,8 @@ class PLBartModel(PLBartPreTrainedModel):
|
|||
super().__init__(config)
|
||||
|
||||
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
||||
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
|
||||
|
||||
self.encoder = PLBartEncoder(config, self.shared)
|
||||
self.decoder = PLBartDecoder(config, self.shared)
|
||||
|
|
|
@ -989,6 +989,20 @@ class SeamlessM4TConformerAdapter(nn.Module):
|
|||
############ TEXT / UNITS related code ################
|
||||
|
||||
|
||||
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4T
|
||||
class SeamlessM4TScaledWordEmbedding(nn.Embedding):
|
||||
"""
|
||||
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
||||
self.embed_scale = embed_scale
|
||||
|
||||
def forward(self, input_ids: torch.Tensor):
|
||||
return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
|
||||
class SeamlessM4TSinusoidalPositionalEmbedding(nn.Module):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
@ -1631,9 +1645,11 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
|
|||
self.max_source_positions = config.max_position_embeddings
|
||||
|
||||
if not self.is_t2u_encoder:
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
self.embed_tokens = SeamlessM4TScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -1726,7 +1742,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
|
|||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if not self.is_t2u_encoder:
|
||||
embed_pos = self.embed_positions(input)
|
||||
|
@ -1809,14 +1825,18 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
|
|||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
# if embed_tokens defined, use its shape instead
|
||||
self.embed_tokens = nn.Embedding(embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx)
|
||||
self.embed_tokens = SeamlessM4TScaledWordEmbedding(
|
||||
embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.embed_tokens = SeamlessM4TScaledWordEmbedding(
|
||||
self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding(
|
||||
self.max_target_positions,
|
||||
|
@ -1935,7 +1955,7 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
|
|
|
@ -946,6 +946,20 @@ class SeamlessM4Tv2ConformerAdapter(nn.Module):
|
|||
############ TEXT / UNITS related code ################
|
||||
|
||||
|
||||
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4Tv2
|
||||
class SeamlessM4Tv2ScaledWordEmbedding(nn.Embedding):
|
||||
"""
|
||||
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
||||
self.embed_scale = embed_scale
|
||||
|
||||
def forward(self, input_ids: torch.Tensor):
|
||||
return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
|
||||
class SeamlessM4Tv2SinusoidalPositionalEmbedding(nn.Module):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
@ -1753,9 +1767,11 @@ class SeamlessM4Tv2Encoder(SeamlessM4Tv2PreTrainedModel):
|
|||
self.max_source_positions = config.max_position_embeddings
|
||||
|
||||
if not self.is_t2u_encoder:
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
||||
self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding(
|
||||
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
|
@ -1848,7 +1864,7 @@ class SeamlessM4Tv2Encoder(SeamlessM4Tv2PreTrainedModel):
|
|||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if not self.is_t2u_encoder:
|
||||
embed_pos = self.embed_positions(input)
|
||||
|
@ -1932,14 +1948,18 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel):
|
|||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
# if embed_tokens defined, use its shape instead
|
||||
self.embed_tokens = nn.Embedding(embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx)
|
||||
self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding(
|
||||
embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
self.embed_tokens.weight = embed_tokens.weight
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding(
|
||||
self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = SeamlessM4Tv2SinusoidalPositionalEmbedding(
|
||||
self.max_target_positions,
|
||||
|
@ -2058,7 +2078,7 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
|
|
|
@ -63,6 +63,20 @@ class TrOCRLearnedPositionalEmbedding(nn.Embedding):
|
|||
return super().forward(positions + self.offset)
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->TrOCR
|
||||
class TrOCRScaledWordEmbedding(nn.Embedding):
|
||||
"""
|
||||
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
||||
self.embed_scale = embed_scale
|
||||
|
||||
def forward(self, input_ids: torch.Tensor):
|
||||
return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
class TrOCRSinusoidalPositionalEmbedding(nn.Module):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
||||
|
@ -451,9 +465,11 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
|
|||
self.dropout = config.dropout
|
||||
self.layerdrop = config.decoder_layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.embed_tokens = TrOCRScaledWordEmbedding(
|
||||
config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
if config.use_learned_position_embeddings:
|
||||
self.embed_positions = TrOCRLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
|
||||
|
@ -584,7 +600,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
|
|||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if self.config.use_learned_position_embeddings:
|
||||
embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
|
||||
|
|
|
@ -127,6 +127,20 @@ XGLM_INPUTS_DOCSTRING = r"""
|
|||
"""
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->XGLM
|
||||
class XGLMScaledWordEmbedding(nn.Embedding):
|
||||
"""
|
||||
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
||||
self.embed_scale = embed_scale
|
||||
|
||||
def forward(self, input_ids: torch.Tensor):
|
||||
return super().forward(input_ids) * self.embed_scale
|
||||
|
||||
|
||||
class XGLMSinusoidalPositionalEmbedding(nn.Module):
|
||||
"""This module produces sinusoidal positional embeddings of any length."""
|
||||
|
||||
|
@ -490,12 +504,14 @@ class XGLMModel(XGLMPreTrainedModel):
|
|||
self.layerdrop = config.layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
if embed_tokens is not None:
|
||||
self.embed_tokens = embed_tokens
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||
self.embed_tokens = XGLMScaledWordEmbedding(
|
||||
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
||||
)
|
||||
|
||||
self.embed_positions = XGLMSinusoidalPositionalEmbedding(
|
||||
config.max_position_embeddings,
|
||||
|
@ -568,7 +584,7 @@ class XGLMModel(XGLMPreTrainedModel):
|
|||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
|
|
|
@ -167,6 +167,10 @@ class AlignVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="AlignVisionModel does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="AlignVisionModel does not support input and output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
@ -379,6 +383,10 @@ class AlignTextModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Align does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="AlignTextModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
@ -473,6 +481,10 @@ class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Align does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
|
|
@ -579,6 +579,29 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
|
|||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# override as the input arg is called "input_embeds", not "inputs_embeds"
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
with torch.no_grad():
|
||||
out_ids = model(**inputs)[0]
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["input_embeds"] = wte(input_ids)
|
||||
with torch.no_grad():
|
||||
out_embeds = model(**inputs)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
@ -645,6 +668,29 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
|||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# override as the input arg is called "input_embeds", not "inputs_embeds"
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
with torch.no_grad():
|
||||
out_ids = model(**inputs)[0]
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["input_embeds"] = wte(input_ids)
|
||||
with torch.no_grad():
|
||||
out_embeds = model(**inputs)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
@ -709,6 +755,10 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
@unittest.skip("FineModel relies on codebook idx and does not return same logits")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@require_torch_fp16
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
|
|
@ -506,6 +506,10 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Bridge Tower does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
|
|
@ -502,6 +502,10 @@ class CanineModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||
# ViT does not use inputs_embeds
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Canine Tower does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("CANINE does not have a get_input_embeddings() method.")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
|
|
@ -247,6 +247,10 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Conditional DETR does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Conditional DETR does not have a get_input_embeddings method")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
|
|
@ -253,6 +253,10 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Deformable DETR does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Deformable DETR does not have a get_input_embeddings method")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
|
|
@ -303,6 +303,10 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="DETA does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="DETA does not have a get_input_embeddings method")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
|
|
@ -247,6 +247,10 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="DETR does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="DETR does not have a get_input_embeddings method")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
|
|
@ -321,6 +321,10 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Input ids is required for FSMT.")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("model weights aren't tied in FSMT.")
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
|
|
@ -182,6 +182,14 @@ class GPTSanJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCas
|
|||
def test_model_parallelism(self):
|
||||
super().test_model_parallelism()
|
||||
|
||||
@unittest.skip(reason="Gptsan does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Gptsan does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
@ -212,6 +220,14 @@ class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTes
|
|||
def test_model_parallelism(self):
|
||||
super().test_model_parallelism()
|
||||
|
||||
@unittest.skip(reason="Gptsan does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Gptsan does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_logits(self):
|
||||
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
|
||||
|
|
|
@ -382,6 +382,10 @@ class IBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
@unittest.skip("ibert overrides scaling to None if inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class IBertModelIntegrationTest(unittest.TestCase):
|
||||
|
|
|
@ -180,6 +180,10 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
def test_inputs_embeds():
|
||||
pass
|
||||
|
||||
@unittest.skip("input_embeds cannot be passed in without input_ids")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Model does not support padding right")
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
pass
|
||||
|
|
|
@ -466,6 +466,31 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
# override because ImageGPT main input name is `pixel_values`
|
||||
# NOTE: in latest transformers this is deprecated, `input_ids` should be used. TODO
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
with torch.no_grad():
|
||||
out_ids = model(**inputs)[0]
|
||||
|
||||
pixel_values = inputs["pixel_values"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(pixel_values)
|
||||
|
||||
with torch.no_grad():
|
||||
out_embeds = model(**inputs)[0]
|
||||
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
if not self.test_torchscript:
|
||||
return
|
||||
|
|
|
@ -265,6 +265,10 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||
lm_heads = model.get_output_embeddings()
|
||||
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
|
||||
|
||||
@unittest.skip(reason="MusicGen does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
# skip as this model doesn't support all arguments tested
|
||||
def test_model_outputs_equivalence(self):
|
||||
pass
|
||||
|
|
|
@ -268,6 +268,10 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||
lm_heads = model.get_output_embeddings()
|
||||
self.assertTrue(lm_heads is None or isinstance(lm_heads[0], torch.nn.Linear))
|
||||
|
||||
@unittest.skip(reason="MusicGen melody does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("this model doesn't support all arguments tested")
|
||||
def test_model_outputs_equivalence(self):
|
||||
pass
|
||||
|
|
|
@ -463,6 +463,10 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SeamlessM4TSpeechEncoder doesn't have an embedding layer")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Expected missing keys serve when using SeamlessM4TForXXX.from_pretrained from a checkpoint saved by SeamlessM4TModel.save_pretrained."
|
||||
)
|
||||
|
|
|
@ -479,6 +479,10 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="SeamlessM4TSpeechEncoder doesn't have an embedding layer")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Expected missing keys serve when using SeamlessM4Tv2ForXXX.from_pretrained from a checkpoint saved by SeamlessM4Tv2Model.save_pretrained."
|
||||
)
|
||||
|
|
|
@ -261,6 +261,10 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
|
|||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Table Transformer does not use inputs_embeds")
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Table Transformer does not have a get_input_embeddings method")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
|
|
@ -357,6 +357,13 @@ class ViltModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||
def test_model_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="""VilT samples image tokens from a multinomial distribution, resulting in not deterministic
|
||||
hidden states. Cannot test equivalence on logit level"""
|
||||
)
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
pass
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
|
|
@ -2767,6 +2767,51 @@ class ModelTesterMixin:
|
|||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ not in get_values(MODEL_MAPPING_NAMES):
|
||||
continue
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model_forward_args = inspect.signature(model.forward).parameters
|
||||
if "inputs_embeds" not in model_forward_args:
|
||||
self.skipTest("This model doesn't use `inputs_embeds`")
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
if not self.is_encoder_decoder:
|
||||
input_ids = inputs["input_ids"]
|
||||
# some models infer position ids/attn mask differently when input ids
|
||||
# by check if pad_token let's make sure no padding is in input ids
|
||||
not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
|
||||
input_ids[input_ids == pad_token_id] = not_pad_token_id
|
||||
del inputs["input_ids"]
|
||||
inputs_embeds = wte(input_ids)
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
else:
|
||||
encoder_input_ids = inputs["input_ids"]
|
||||
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
|
||||
encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
|
||||
decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
|
||||
del inputs["input_ids"]
|
||||
inputs.pop("decoder_input_ids", None)
|
||||
inputs_embeds = wte(encoder_input_ids)
|
||||
decoder_inputs_embeds = wte(decoder_input_ids)
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs)[0]
|
||||
out_embeds = model(
|
||||
inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **inputs
|
||||
)[0]
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
|
Loading…
Reference in New Issue