[Bert2Bert] allow bert2bert + relative embeddings (#14324)
* [Bert2Bert] allow bert2bert + relative embeddings * up * Update README_ko.md * up * up
This commit is contained in:
parent
e4d8f517b9
commit
e81d8d7fa9
|
@ -224,7 +224,7 @@ class BertEmbeddings(nn.Module):
|
|||
|
||||
|
||||
class BertSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
|
@ -241,7 +241,9 @@ class BertSelfAttention(nn.Module):
|
|||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
@ -363,9 +365,9 @@ class BertSelfOutput(nn.Module):
|
|||
|
||||
|
||||
class BertAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
self.self = BertSelfAttention(config)
|
||||
self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
|
||||
self.output = BertSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
|
@ -451,7 +453,7 @@ class BertLayer(nn.Module):
|
|||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = BertAttention(config)
|
||||
self.crossattention = BertAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
|
||||
|
|
|
@ -216,7 +216,7 @@ class ElectraEmbeddings(nn.Module):
|
|||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra
|
||||
class ElectraSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
|
@ -233,7 +233,9 @@ class ElectraSelfAttention(nn.Module):
|
|||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
@ -357,9 +359,9 @@ class ElectraSelfOutput(nn.Module):
|
|||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra
|
||||
class ElectraAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
self.self = ElectraSelfAttention(config)
|
||||
self.self = ElectraSelfAttention(config, position_embedding_type=position_embedding_type)
|
||||
self.output = ElectraSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
|
@ -448,7 +450,7 @@ class ElectraLayer(nn.Module):
|
|||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = ElectraAttention(config)
|
||||
self.crossattention = ElectraAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = ElectraIntermediate(config)
|
||||
self.output = ElectraOutput(config)
|
||||
|
||||
|
|
|
@ -132,7 +132,7 @@ class LayoutLMEmbeddings(nn.Module):
|
|||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->LayoutLM
|
||||
class LayoutLMSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
|
@ -149,7 +149,9 @@ class LayoutLMSelfAttention(nn.Module):
|
|||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
@ -273,9 +275,9 @@ class LayoutLMSelfOutput(nn.Module):
|
|||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM
|
||||
class LayoutLMAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
self.self = LayoutLMSelfAttention(config)
|
||||
self.self = LayoutLMSelfAttention(config, position_embedding_type=position_embedding_type)
|
||||
self.output = LayoutLMSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
|
@ -364,7 +366,7 @@ class LayoutLMLayer(nn.Module):
|
|||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = LayoutLMAttention(config)
|
||||
self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = LayoutLMIntermediate(config)
|
||||
self.output = LayoutLMOutput(config)
|
||||
|
||||
|
|
|
@ -195,7 +195,7 @@ class MegatronBertEmbeddings(nn.Module):
|
|||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MegatronBert
|
||||
class MegatronBertSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
|
@ -212,7 +212,9 @@ class MegatronBertSelfAttention(nn.Module):
|
|||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
|
|
@ -328,7 +328,6 @@ class RemBertSelfOutput(nn.Module):
|
|||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RemBert
|
||||
class RemBertAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
@ -336,6 +335,7 @@ class RemBertAttention(nn.Module):
|
|||
self.output = RemBertSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
|
@ -354,6 +354,7 @@ class RemBertAttention(nn.Module):
|
|||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
|
@ -409,7 +410,6 @@ class RemBertOutput(nn.Module):
|
|||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RemBert
|
||||
class RemBertLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
@ -425,6 +425,7 @@ class RemBertLayer(nn.Module):
|
|||
self.intermediate = RemBertIntermediate(config)
|
||||
self.output = RemBertOutput(config)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
|
@ -489,6 +490,7 @@ class RemBertLayer(nn.Module):
|
|||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
|
|
|
@ -159,7 +159,7 @@ class RobertaEmbeddings(nn.Module):
|
|||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
|
||||
class RobertaSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
|
@ -176,7 +176,9 @@ class RobertaSelfAttention(nn.Module):
|
|||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
@ -300,9 +302,9 @@ class RobertaSelfOutput(nn.Module):
|
|||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
|
||||
class RobertaAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
self.self = RobertaSelfAttention(config)
|
||||
self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type)
|
||||
self.output = RobertaSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
|
@ -391,7 +393,7 @@ class RobertaLayer(nn.Module):
|
|||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = RobertaAttention(config)
|
||||
self.crossattention = RobertaAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = RobertaIntermediate(config)
|
||||
self.output = RobertaOutput(config)
|
||||
|
||||
|
|
|
@ -367,14 +367,12 @@ class RoFormerSelfOutput(nn.Module):
|
|||
|
||||
|
||||
class RoFormerAttention(nn.Module):
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->RoFormer
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = RoFormerSelfAttention(config)
|
||||
self.output = RoFormerSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
# End Copy
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
|
@ -453,7 +451,6 @@ class RoFormerOutput(nn.Module):
|
|||
|
||||
|
||||
class RoFormerLayer(nn.Module):
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer.__init__ with Bert->RoFormer
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
|
@ -468,7 +465,6 @@ class RoFormerLayer(nn.Module):
|
|||
self.intermediate = RoFormerIntermediate(config)
|
||||
self.output = RoFormerOutput(config)
|
||||
|
||||
# End Copy
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
|
|
|
@ -99,7 +99,7 @@ class SplinterEmbeddings(nn.Module):
|
|||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter
|
||||
class SplinterSelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
|
@ -116,7 +116,9 @@ class SplinterSelfAttention(nn.Module):
|
|||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
@ -240,9 +242,9 @@ class SplinterSelfOutput(nn.Module):
|
|||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter
|
||||
class SplinterAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
self.self = SplinterSelfAttention(config)
|
||||
self.self = SplinterSelfAttention(config, position_embedding_type=position_embedding_type)
|
||||
self.output = SplinterSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
|
@ -331,7 +333,7 @@ class SplinterLayer(nn.Module):
|
|||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = SplinterAttention(config)
|
||||
self.crossattention = SplinterAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = SplinterIntermediate(config)
|
||||
self.output = SplinterOutput(config)
|
||||
|
||||
|
|
|
@ -456,7 +456,6 @@ class TapasSelfOutput(nn.Module):
|
|||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Tapas
|
||||
class TapasAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
@ -464,6 +463,7 @@ class TapasAttention(nn.Module):
|
|||
self.output = TapasSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
|
||||
def prune_heads(self, heads):
|
||||
if len(heads) == 0:
|
||||
return
|
||||
|
@ -482,6 +482,7 @@ class TapasAttention(nn.Module):
|
|||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
|
@ -537,7 +538,6 @@ class TapasOutput(nn.Module):
|
|||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Tapas
|
||||
class TapasLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
@ -553,6 +553,7 @@ class TapasLayer(nn.Module):
|
|||
self.intermediate = TapasIntermediate(config)
|
||||
self.output = TapasOutput(config)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
|
@ -617,6 +618,7 @@ class TapasLayer(nn.Module):
|
|||
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
|
|
|
@ -203,7 +203,7 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
|||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
|
@ -220,7 +220,7 @@ class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
|||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
@ -344,9 +344,9 @@ class {{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
|
|||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
self.self = {{cookiecutter.camelcase_modelname}}SelfAttention(config)
|
||||
self.self = {{cookiecutter.camelcase_modelname}}SelfAttention(config, position_embedding_type=position_embedding_type)
|
||||
self.output = {{cookiecutter.camelcase_modelname}}SelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
|
@ -434,7 +434,7 @@ class {{cookiecutter.camelcase_modelname}}Layer(nn.Module):
|
|||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
|
||||
self.crossattention = {{cookiecutter.camelcase_modelname}}Attention(config)
|
||||
self.crossattention = {{cookiecutter.camelcase_modelname}}Attention(config, position_embedding_type="absolute")
|
||||
self.intermediate = {{cookiecutter.camelcase_modelname}}Intermediate(config)
|
||||
self.output = {{cookiecutter.camelcase_modelname}}Output(config)
|
||||
|
||||
|
|
|
@ -567,6 +567,24 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
def test_relative_position_embeds(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
encoder_config = config_and_inputs["config"]
|
||||
decoder_config = config_and_inputs["decoder_config"]
|
||||
|
||||
encoder_config.position_embedding_type = "relative_key_query"
|
||||
decoder_config.position_embedding_type = "relative_key_query"
|
||||
|
||||
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
|
||||
model = EncoderDecoderModel(config).eval().to(torch_device)
|
||||
|
||||
logits = model(
|
||||
input_ids=config_and_inputs["input_ids"], decoder_input_ids=config_and_inputs["decoder_input_ids"]
|
||||
).logits
|
||||
|
||||
self.assertTrue(logits.shape, (13, 7))
|
||||
|
||||
@slow
|
||||
def test_bert2bert_summarization(self):
|
||||
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
||||
|
|
Loading…
Reference in New Issue