Fix dpr<>bart config for RAG (#8808)
* correct dpr test and bert pos fault * fix dpr bert config problem * fix layoutlm * add config to dpr as well
This commit is contained in:
parent
a2cf37595e
commit
a7d46a0609
|
@ -214,7 +214,7 @@ class AlbertEmbeddings(nn.Module):
|
|||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
||||
|
@ -268,7 +268,7 @@ class AlbertAttention(nn.Module):
|
|||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.pruned_heads = set()
|
||||
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
|
|
@ -178,7 +178,7 @@ class BertEmbeddings(nn.Module):
|
|||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
|
||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
||||
if input_ids is not None:
|
||||
|
@ -225,7 +225,7 @@ class BertSelfAttention(nn.Module):
|
|||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
|
|
@ -71,6 +71,13 @@ class DPRConfig(PretrainedConfig):
|
|||
The epsilon used by the layer normalization layers.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
|
||||
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
|
||||
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
|
||||
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
|
||||
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
|
||||
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
|
||||
<https://arxiv.org/abs/2009.13658>`__.
|
||||
projection_dim (:obj:`int`, `optional`, defaults to 0):
|
||||
Dimension of the projection for the context and question encoders. If it is set to zero (default), then no
|
||||
projection is done.
|
||||
|
@ -93,6 +100,7 @@ class DPRConfig(PretrainedConfig):
|
|||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
gradient_checkpointing=False,
|
||||
position_embedding_type="absolute",
|
||||
projection_dim: int = 0,
|
||||
**kwargs
|
||||
):
|
||||
|
@ -112,3 +120,4 @@ class DPRConfig(PretrainedConfig):
|
|||
self.layer_norm_eps = layer_norm_eps
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
self.projection_dim = projection_dim
|
||||
self.position_embedding_type = position_embedding_type
|
||||
|
|
|
@ -165,7 +165,7 @@ class ElectraEmbeddings(nn.Module):
|
|||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
|
||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
||||
|
@ -214,7 +214,7 @@ class ElectraSelfAttention(nn.Module):
|
|||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
|
|
@ -146,7 +146,7 @@ class LayoutLMSelfAttention(nn.Module):
|
|||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
|
|
@ -83,7 +83,7 @@ class RobertaEmbeddings(nn.Module):
|
|||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
|
||||
# End copy
|
||||
self.padding_idx = config.pad_token_id
|
||||
|
@ -162,7 +162,7 @@ class RobertaSelfAttention(nn.Module):
|
|||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
|
|
@ -26,7 +26,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
|
|||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||
from transformers import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||
from transformers.models.dpr.modeling_dpr import (
|
||||
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
|
@ -104,7 +104,8 @@ class DPRModelTester:
|
|||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = BertConfig(
|
||||
config = DPRConfig(
|
||||
projection_dim=self.projection_dim,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
|
@ -115,14 +116,12 @@ class DPRModelTester:
|
|||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def create_and_check_dpr_context_encoder(
|
||||
def create_and_check_context_encoder(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = DPRContextEncoder(config=config)
|
||||
|
@ -133,7 +132,7 @@ class DPRModelTester:
|
|||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))
|
||||
|
||||
def create_and_check_dpr_question_encoder(
|
||||
def create_and_check_question_encoder(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = DPRQuestionEncoder(config=config)
|
||||
|
@ -144,7 +143,7 @@ class DPRModelTester:
|
|||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))
|
||||
|
||||
def create_and_check_dpr_reader(
|
||||
def create_and_check_reader(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = DPRReader(config=config)
|
||||
|
@ -199,17 +198,17 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_dpr_context_encoder_model(self):
|
||||
def test_context_encoder_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_dpr_context_encoder(*config_and_inputs)
|
||||
self.model_tester.create_and_check_context_encoder(*config_and_inputs)
|
||||
|
||||
def test_dpr_question_encoder_model(self):
|
||||
def test_question_encoder_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_dpr_question_encoder(*config_and_inputs)
|
||||
self.model_tester.create_and_check_question_encoder(*config_and_inputs)
|
||||
|
||||
def test_dpr_reader_model(self):
|
||||
def test_reader_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_dpr_reader(*config_and_inputs)
|
||||
self.model_tester.create_and_check_reader(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
|
|
Loading…
Reference in New Issue