[Bert2Bert] allow bert2bert + relative embeddings (#14324)

* [Bert2Bert] allow bert2bert + relative embeddings

* up

* Update README_ko.md

* up

* up
Patrick von Platen 2021-11-09 20:26:58 +01:00 committed by GitHub
parent e4d8f517b9
commit e81d8d7fa9
11 changed files with 70 additions and 40 deletions
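
For reference, a minimal usage sketch of what this change enables (illustrative, not part of the commit): building a bert2bert EncoderDecoderModel whose encoder and decoder configs both use relative position embeddings, mirroring the new test_relative_position_embeds test in the diff below. The tiny config sizes and the token ids are arbitrary choices for the sketch; any BERT-style config or checkpoint with position_embedding_type set to "relative_key" or "relative_key_query" behaves the same way.

import torch
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

# Tiny configs purely for illustration; real use would load pretrained checkpoints.
encoder_config = BertConfig(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64,
    position_embedding_type="relative_key_query",
)
decoder_config = BertConfig(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64,
    position_embedding_type="relative_key_query",
    is_decoder=True, add_cross_attention=True,
)
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = EncoderDecoderModel(config).eval()

input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])  # arbitrary BERT token ids
with torch.no_grad():
    logits = model(input_ids=input_ids, decoder_input_ids=input_ids).logits
print(logits.shape)  # torch.Size([1, 6, 30522]) -> (batch, decoder_seq_len, vocab_size)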

View File

@@ -224,7 +224,7 @@ class BertEmbeddings(nn.Module):


 class BertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -241,7 +241,9 @@ class BertSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -363,9 +365,9 @@ class BertSelfOutput(nn.Module):


 class BertAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = BertSelfAttention(config)
+        self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = BertSelfOutput(config)
         self.pruned_heads = set()
@@ -451,7 +453,7 @@ class BertLayer(nn.Module):
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = BertAttention(config)
+            self.crossattention = BertAttention(config, position_embedding_type="absolute")
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
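
A minimal sketch of the override pattern this file introduces (illustrative, not part of the diff): self-attention keeps following the config's position_embedding_type, while the decoder's cross-attention is constructed with position_embedding_type="absolute", presumably because relative key distances between decoder queries and encoder keys are not well defined. The tiny config sizes are arbitrary.

from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertSelfAttention

config = BertConfig(hidden_size=32, num_attention_heads=4, position_embedding_type="relative_key_query")

self_attn = BertSelfAttention(config)                                       # falls back to the config value
cross_attn = BertSelfAttention(config, position_embedding_type="absolute")  # explicit override wins

print(self_attn.position_embedding_type)   # relative_key_query
print(cross_attn.position_embedding_type)  # absolute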

View File

@@ -216,7 +216,7 @@ class ElectraEmbeddings(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra
 class ElectraSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -233,7 +233,9 @@ class ElectraSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -357,9 +359,9 @@ class ElectraSelfOutput(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra
 class ElectraAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = ElectraSelfAttention(config)
+        self.self = ElectraSelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = ElectraSelfOutput(config)
         self.pruned_heads = set()
@@ -448,7 +450,7 @@ class ElectraLayer(nn.Module):
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = ElectraAttention(config)
+            self.crossattention = ElectraAttention(config, position_embedding_type="absolute")
         self.intermediate = ElectraIntermediate(config)
         self.output = ElectraOutput(config)

View File

@@ -132,7 +132,7 @@ class LayoutLMEmbeddings(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->LayoutLM
 class LayoutLMSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -149,7 +149,9 @@ class LayoutLMSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -273,9 +275,9 @@ class LayoutLMSelfOutput(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM
 class LayoutLMAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = LayoutLMSelfAttention(config)
+        self.self = LayoutLMSelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = LayoutLMSelfOutput(config)
         self.pruned_heads = set()
@@ -364,7 +366,7 @@ class LayoutLMLayer(nn.Module):
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = LayoutLMAttention(config)
+            self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute")
         self.intermediate = LayoutLMIntermediate(config)
         self.output = LayoutLMOutput(config)

View File

@@ -195,7 +195,7 @@ class MegatronBertEmbeddings(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MegatronBert
 class MegatronBertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -212,7 +212,9 @@ class MegatronBertSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

View File

@@ -328,7 +328,6 @@ class RemBertSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RemBert
 class RemBertAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -336,6 +335,7 @@ class RemBertAttention(nn.Module):
         self.output = RemBertSelfOutput(config)
         self.pruned_heads = set()

+    # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
@@ -354,6 +354,7 @@ class RemBertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)

+    # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
     def forward(
         self,
         hidden_states,
@@ -409,7 +410,6 @@ class RemBertOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RemBert
 class RemBertLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -425,6 +425,7 @@ class RemBertLayer(nn.Module):
         self.intermediate = RemBertIntermediate(config)
         self.output = RemBertOutput(config)

+    # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
     def forward(
         self,
         hidden_states,
@@ -489,6 +490,7 @@ class RemBertLayer(nn.Module):
         return outputs

+    # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
     def feed_forward_chunk(self, attention_output):
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)

View File

@@ -159,7 +159,7 @@ class RobertaEmbeddings(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
 class RobertaSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -176,7 +176,9 @@ class RobertaSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -300,9 +302,9 @@ class RobertaSelfOutput(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
 class RobertaAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = RobertaSelfAttention(config)
+        self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = RobertaSelfOutput(config)
         self.pruned_heads = set()
@@ -391,7 +393,7 @@ class RobertaLayer(nn.Module):
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = RobertaAttention(config)
+            self.crossattention = RobertaAttention(config, position_embedding_type="absolute")
         self.intermediate = RobertaIntermediate(config)
         self.output = RobertaOutput(config)

View File

@@ -367,14 +367,12 @@ class RoFormerSelfOutput(nn.Module):


 class RoFormerAttention(nn.Module):
-    # Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->RoFormer
     def __init__(self, config):
         super().__init__()
         self.self = RoFormerSelfAttention(config)
         self.output = RoFormerSelfOutput(config)
         self.pruned_heads = set()
-    # End Copy

     # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
     def prune_heads(self, heads):
         if len(heads) == 0:
@@ -453,7 +451,6 @@ class RoFormerOutput(nn.Module):


 class RoFormerLayer(nn.Module):
-    # Copied from transformers.models.bert.modeling_bert.BertLayer.__init__ with Bert->RoFormer
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -468,7 +465,6 @@ class RoFormerLayer(nn.Module):
         self.intermediate = RoFormerIntermediate(config)
         self.output = RoFormerOutput(config)
-    # End Copy

     def forward(
         self,
         hidden_states,

View File

@@ -99,7 +99,7 @@ class SplinterEmbeddings(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter
 class SplinterSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -116,7 +116,9 @@ class SplinterSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -240,9 +242,9 @@ class SplinterSelfOutput(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter
 class SplinterAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = SplinterSelfAttention(config)
+        self.self = SplinterSelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = SplinterSelfOutput(config)
         self.pruned_heads = set()
@@ -331,7 +333,7 @@ class SplinterLayer(nn.Module):
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = SplinterAttention(config)
+            self.crossattention = SplinterAttention(config, position_embedding_type="absolute")
         self.intermediate = SplinterIntermediate(config)
         self.output = SplinterOutput(config)

View File

@@ -456,7 +456,6 @@ class TapasSelfOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Tapas
 class TapasAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -464,6 +463,7 @@ class TapasAttention(nn.Module):
         self.output = TapasSelfOutput(config)
         self.pruned_heads = set()

+    # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
@@ -482,6 +482,7 @@ class TapasAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)

+    # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
     def forward(
         self,
         hidden_states,
@@ -537,7 +538,6 @@ class TapasOutput(nn.Module):
         return hidden_states


-# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Tapas
 class TapasLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -553,6 +553,7 @@ class TapasLayer(nn.Module):
         self.intermediate = TapasIntermediate(config)
         self.output = TapasOutput(config)

+    # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
     def forward(
         self,
         hidden_states,
@@ -617,6 +618,7 @@ class TapasLayer(nn.Module):
         return outputs

+    # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
     def feed_forward_chunk(self, attention_output):
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)

View File

@@ -203,7 +203,7 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
 class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -220,7 +220,7 @@ class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)

         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -344,9 +344,9 @@ class {{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):

 # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->{{cookiecutter.camelcase_modelname}}
 class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = {{cookiecutter.camelcase_modelname}}SelfAttention(config)
+        self.self = {{cookiecutter.camelcase_modelname}}SelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = {{cookiecutter.camelcase_modelname}}SelfOutput(config)
         self.pruned_heads = set()
@@ -434,7 +434,7 @@ class {{cookiecutter.camelcase_modelname}}Layer(nn.Module):
         self.add_cross_attention = config.add_cross_attention
         if self.add_cross_attention:
             assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
-            self.crossattention = {{cookiecutter.camelcase_modelname}}Attention(config)
+            self.crossattention = {{cookiecutter.camelcase_modelname}}Attention(config, position_embedding_type="absolute")
         self.intermediate = {{cookiecutter.camelcase_modelname}}Intermediate(config)
         self.output = {{cookiecutter.camelcase_modelname}}Output(config)

View File

@@ -567,6 +567,24 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
             "labels": decoder_token_labels,
         }

+    def test_relative_position_embeds(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+
+        encoder_config = config_and_inputs["config"]
+        decoder_config = config_and_inputs["decoder_config"]
+
+        encoder_config.position_embedding_type = "relative_key_query"
+        decoder_config.position_embedding_type = "relative_key_query"
+
+        config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
+        model = EncoderDecoderModel(config).eval().to(torch_device)
+
+        logits = model(
+            input_ids=config_and_inputs["input_ids"], decoder_input_ids=config_and_inputs["decoder_input_ids"]
+        ).logits
+
+        self.assertTrue(logits.shape, (13, 7))
+
     @slow
     def test_bert2bert_summarization(self):
         model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")