[Pretrained Model] Add resize_position_embeddings (#13559)

* finish

* delete bogus file

* correct some stuff

* finish

* finish
This commit is contained in:
Patrick von Platen 2021-09-15 19:03:56 +02:00 committed by GitHub
parent c783e14887
commit 95f933ea85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 413 additions and 13 deletions

View File

@ -99,6 +99,13 @@ class ModelArguments:
"with private models)."
},
)
# Tri-state switch read in main() when `max_source_length` exceeds `config.max_position_embeddings`:
# None resizes and logs a warning, True resizes without the warning, False raises a ValueError.
resize_position_embeddings: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
"the model's position embeddings."
},
)
@dataclass
@ -366,6 +373,25 @@ def main():
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
# Make sure the model can encode sequences of `max_source_length` tokens. Models whose config has no
# `max_position_embeddings` attribute are left untouched.
if (
    hasattr(model.config, "max_position_embeddings")
    and model.config.max_position_embeddings < data_args.max_source_length
):
    if model_args.resize_position_embeddings is None:
        # Flag not set by the user: resize, but tell them about it.
        logger.warning(
            f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} "
            f"to {data_args.max_source_length}."
        )
        model.resize_position_embeddings(data_args.max_source_length)
    elif model_args.resize_position_embeddings:
        # Explicit opt-in: resize without the warning.
        model.resize_position_embeddings(data_args.max_source_length)
    else:
        # Explicit opt-out: refuse to continue with inputs the model cannot represent.
        raise ValueError(
            f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}"
            f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically "
            "resize the model's position encodings by passing `--resize_position_embeddings`."
        )
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
# Preprocessing the datasets.

View File

@ -887,6 +887,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return new_lm_head
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Placeholder that model classes supporting position-embedding resizing must override.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The requested number of position embeddings. Unused here.

    Raises:
        NotImplementedError: always, naming the class and the module in which the override belongs.
    """
    # `__module__` is already the full dotted module path, so it is reported as-is instead of being
    # wrapped into a made-up `modeling_*.py` file name.
    raise NotImplementedError(
        f"`resize_position_embeddings` is not implemented for {self.__class__}. To implement it, you should "
        f"overwrite this method in the class {self.__class__} in `{self.__class__.__module__}`"
    )
def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding, ...]]:
    """
    Placeholder that model classes exposing their position embeddings must override.

    Returns:
        Overrides return a single :obj:`nn.Embedding` or a tuple of them (one per sub-model).

    Raises:
        NotImplementedError: always, naming the class and the module in which the override belongs.
    """
    # `__module__` is already the full dotted module path, so it is reported as-is instead of being
    # wrapped into a made-up `modeling_*.py` file name.
    raise NotImplementedError(
        f"`get_position_embeddings` is not implemented for {self.__class__}. To implement it, you should "
        f"overwrite this method in the class {self.__class__} in `{self.__class__.__module__}`"
    )
def init_weights(self):
"""
If needed prunes and maybe initializes weights.

View File

@ -2833,7 +2833,7 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
return self.decoder(*args, **kwargs)
# Copied from transformers.models.pegasus.modeling_pegasus.PegasusForCausalLM with Pegasus->BigBirdPegasus, 'facebook/bart-large'->"google/bigbird-pegasus-large-arxiv"
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BigBirdPegasus, 'facebook/bart-large'->"google/bigbird-pegasus-large-arxiv"
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
def __init__(self, config):
super().__init__(config)

View File

@ -442,6 +442,67 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.init_weights()
def get_position_embeddings(self) -> nn.Embedding:
    """
    Gives access to the position embedding module held by the embeddings layer of this model.
    """
    embeddings_layer = self.embeddings
    return embeddings_layer.position_embeddings
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
    config.max_position_embeddings`.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The new number of position embeddings. If position embeddings are learned, increasing the size
            will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
            end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
            size will add correct vectors at the end following the position encoding algorithm, whereas reducing
            the size will remove vectors from the end.
    """
    num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings

    # no resizing needs to be done if the length stays the same
    if num_position_embeds_diff == 0:
        return

    logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
    self.config.max_position_embeddings = new_num_position_embeddings

    # Keep a copy of the current table: it provides the vectors to carry over and the target device.
    old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone()

    self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)

    if self.config.sinusoidal_pos_embds:
        # Fixed sinusoidal encodings: recompute the whole table for the new length.
        if is_deepspeed_zero3_enabled():
            import deepspeed

            with deepspeed.zero.GatheredParameters(self.embeddings.position_embeddings.weight, modifier_rank=0):
                if torch.distributed.get_rank() == 0:
                    create_sinusoidal_embeddings(
                        n_pos=self.config.max_position_embeddings,
                        dim=self.config.dim,
                        out=self.embeddings.position_embeddings.weight,
                    )
        else:
            create_sinusoidal_embeddings(
                n_pos=self.config.max_position_embeddings,
                dim=self.config.dim,
                out=self.embeddings.position_embeddings.weight,
            )
    else:
        with torch.no_grad():
            new_weight = self.embeddings.position_embeddings.weight
            if num_position_embeds_diff > 0:
                # Growing: copy the old vectors into the leading rows, the trailing rows keep their fresh
                # initialization. The source is moved explicitly so the copy never mixes devices.
                new_weight[:-num_position_embeds_diff] = old_position_embeddings_weight.to(new_weight.device)
            else:
                # Shrinking: keep only the leading rows of the old table.
                self.embeddings.position_embeddings.weight = nn.Parameter(
                    old_position_embeddings_weight[:num_position_embeds_diff]
                )

    # `nn.Embedding(...)` is created on the default device; put the new table where the old one lived so
    # that a model already moved to an accelerator can still run a forward pass.
    self.embeddings.position_embeddings.to(old_position_embeddings_weight.device)
def get_input_embeddings(self):
return self.embeddings.word_embeddings
@ -525,6 +586,27 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
self.mlm_loss_fct = nn.CrossEntropyLoss()
def get_position_embeddings(self) -> nn.Embedding:
    """
    Gives access to the position embeddings by asking the wrapped :obj:`distilbert` model for them.
    """
    backbone = self.distilbert
    return backbone.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Forwards the resize request to the wrapped :obj:`distilbert` model.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The new number of position embeddings, passed through unchanged. How vectors are added or
            removed is decided by the wrapped model.
    """
    backbone = self.distilbert
    backbone.resize_position_embeddings(new_num_position_embeddings)
def get_output_embeddings(self):
return self.vocab_projector
@ -608,6 +690,27 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
self.init_weights()
def get_position_embeddings(self) -> nn.Embedding:
    """
    Gives access to the position embeddings by asking the wrapped :obj:`distilbert` model for them.
    """
    backbone = self.distilbert
    return backbone.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Forwards the resize request to the wrapped :obj:`distilbert` model.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The new number of position embeddings, passed through unchanged. How vectors are added or
            removed is decided by the wrapped model.
    """
    backbone = self.distilbert
    backbone.resize_position_embeddings(new_num_position_embeddings)
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
@ -703,6 +806,27 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
self.init_weights()
def get_position_embeddings(self) -> nn.Embedding:
    """
    Gives access to the position embeddings by asking the wrapped :obj:`distilbert` model for them.
    """
    backbone = self.distilbert
    return backbone.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Forwards the resize request to the wrapped :obj:`distilbert` model.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The new number of position embeddings, passed through unchanged. How vectors are added or
            removed is decided by the wrapped model.
    """
    backbone = self.distilbert
    backbone.resize_position_embeddings(new_num_position_embeddings)
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
@ -799,6 +923,27 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
self.init_weights()
def get_position_embeddings(self) -> nn.Embedding:
    """
    Gives access to the position embeddings by asking the wrapped :obj:`distilbert` model for them.
    """
    backbone = self.distilbert
    return backbone.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Forwards the resize request to the wrapped :obj:`distilbert` model.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The new number of position embeddings, passed through unchanged. How vectors are added or
            removed is decided by the wrapped model.
    """
    backbone = self.distilbert
    backbone.resize_position_embeddings(new_num_position_embeddings)
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
@ -883,6 +1028,27 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
self.init_weights()
def get_position_embeddings(self) -> nn.Embedding:
    """
    Gives access to the position embeddings by asking the wrapped :obj:`distilbert` model for them.
    """
    backbone = self.distilbert
    return backbone.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Forwards the resize request to the wrapped :obj:`distilbert` model.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The new number of position embeddings, passed through unchanged. How vectors are added or
            removed is decided by the wrapped model.
    """
    backbone = self.distilbert
    backbone.resize_position_embeddings(new_num_position_embeddings)
@add_start_docstrings_to_model_forward(
DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)

View File

@ -480,17 +480,6 @@ class PegasusPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
dummy_inputs = {
"attention_mask": input_ids.ne(pad_token),
"input_ids": input_ids,
"decoder_input_ids": input_ids,
}
return dummy_inputs
PEGASUS_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
@ -658,6 +647,34 @@ class PegasusEncoder(PegasusPreTrainedModel):
self.init_weights()
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Rebuilds the position embeddings matrix of the encoder for :obj:`new_num_position_embeddings` positions and
    updates :obj:`config.max_position_embeddings` accordingly. The matrix is rebuilt on every call, even when the
    size does not change.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The new number of position embeddings. A new :obj:`PegasusSinusoidalPositionalEmbedding` of that
            length replaces the current one.
    """
    logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
    self.config.max_position_embeddings = new_num_position_embeddings

    # Remember where the current table lives: the replacement is created on the default device.
    previous_device = self.embed_positions.weight.device

    self.embed_positions = PegasusSinusoidalPositionalEmbedding(
        self.config.max_position_embeddings,
        self.config.d_model,
        self.padding_idx,
    )
    # Without this move, a model already placed on an accelerator fails on its next forward pass.
    self.embed_positions.to(previous_device)
def get_position_embeddings(self) -> nn.Embedding:
    """
    Gives access to the position embedding module stored in :obj:`embed_positions`.
    """
    position_embeddings = self.embed_positions
    return position_embeddings
def forward(
self,
input_ids=None,
@ -848,6 +865,34 @@ class PegasusDecoder(PegasusPreTrainedModel):
return combined_attention_mask
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Rebuilds the position embeddings matrix of the decoder for :obj:`new_num_position_embeddings` positions and
    updates :obj:`config.max_position_embeddings` accordingly. The matrix is rebuilt on every call, even when the
    size does not change.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The new number of position embeddings. A new :obj:`PegasusSinusoidalPositionalEmbedding` of that
            length replaces the current one.
    """
    logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
    self.config.max_position_embeddings = new_num_position_embeddings

    # Remember where the current table lives: the replacement is created on the default device.
    previous_device = self.embed_positions.weight.device

    self.embed_positions = PegasusSinusoidalPositionalEmbedding(
        self.config.max_position_embeddings,
        self.config.d_model,
        self.padding_idx,
    )
    # Without this move, a model already placed on an accelerator fails on its next forward pass.
    self.embed_positions.to(previous_device)
def get_position_embeddings(self) -> nn.Embedding:
    """
    Gives access to the position embedding module stored in :obj:`embed_positions`.
    """
    position_embeddings = self.embed_positions
    return position_embeddings
def forward(
self,
input_ids=None,
@ -1097,6 +1142,29 @@ class PegasusModel(PegasusPreTrainedModel):
def get_decoder(self):
return self.decoder
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Records the new size in :obj:`config.max_position_embeddings`, then asks the encoder and the decoder, in
    that order, to resize their position embeddings.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The new number of position embeddings, handed unchanged to both sub-models.
    """
    self.config.max_position_embeddings = new_num_position_embeddings
    for sub_model in (self.encoder, self.decoder):
        sub_model.resize_position_embeddings(new_num_position_embeddings)
def get_position_embeddings(self) -> Tuple[nn.Embedding, nn.Embedding]:
    """
    Returns the position embeddings of both sub-models.

    Returns:
        A pair :obj:`(encoder_position_embeddings, decoder_position_embeddings)`.
    """
    return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@ -1237,6 +1305,29 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Records the new size in :obj:`config.max_position_embeddings`, then asks the encoder and the decoder of the
    wrapped model, in that order, to resize their position embeddings.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The new number of position embeddings, handed unchanged to both sub-models.
    """
    self.config.max_position_embeddings = new_num_position_embeddings
    for sub_model_name in ("encoder", "decoder"):
        sub_model = getattr(self.model, sub_model_name)
        sub_model.resize_position_embeddings(new_num_position_embeddings)
def get_position_embeddings(self) -> Tuple[nn.Embedding, nn.Embedding]:
    """
    Returns the position embeddings of both sub-models of the wrapped model.

    Returns:
        A pair :obj:`(encoder_position_embeddings, decoder_position_embeddings)`.
    """
    return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings())
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
@ -1373,7 +1464,6 @@ class PegasusDecoderWrapper(PegasusPreTrainedModel):
return self.decoder(*args, **kwargs)
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Pegasus
class PegasusForCausalLM(PegasusPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@ -1404,7 +1494,30 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
def get_decoder(self):
return self.model.decoder
def get_position_embeddings(self) -> nn.Embedding:
    """
    Gives access to the position embeddings by asking the decoder of the wrapped model for them.
    """
    decoder = self.model.decoder
    return decoder.get_position_embeddings()
def resize_position_embeddings(self, new_num_position_embeddings: int):
    """
    Records the new size in :obj:`config.max_position_embeddings`, then asks the decoder of the wrapped model
    to resize its position embeddings.

    Arguments:
        new_num_position_embeddings (:obj:`int`):
            The new number of position embeddings, handed unchanged to the decoder.
    """
    self.config.max_position_embeddings = new_num_position_embeddings
    decoder = self.model.decoder
    decoder.resize_position_embeddings(new_num_position_embeddings)
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus
def forward(
self,
input_ids=None,

View File

@ -94,6 +94,7 @@ class ModelTesterMixin:
test_torchscript = True
test_pruning = True
test_resize_embeddings = True
test_resize_position_embeddings = False
test_head_masking = True
test_missing_keys = True
test_model_parallel = False
@ -1067,6 +1068,85 @@ class ModelTesterMixin:
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def test_resize_position_vector_embeddings(self):
if not self.test_resize_position_embeddings:
return
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
if self.model_tester.is_training is False:
model.eval()
max_position_embeddings = config.max_position_embeddings
# Retrieve the embeddings and clone theme
if model.config.is_encoder_decoder:
encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
encoder_cloned_embeddings = encoder_model_embed.weight.clone()
decoder_cloned_embeddings = decoder_model_embed.weight.clone()
else:
model_embed = model.get_position_embeddings()
cloned_embeddings = model_embed.weight.clone()
# Check that resizing the position embeddings with a larger max_position_embeddings increases
# the model's postion embeddings size
model.resize_position_embeddings(max_position_embeddings + 10)
self.assertEqual(model.config.max_position_embeddings, max_position_embeddings + 10)
# Check that it actually resizes the embeddings matrix
if model.config.is_encoder_decoder:
encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
self.assertEqual(encoder_model_embed.weight.shape[0], encoder_cloned_embeddings.shape[0] + 10)
self.assertEqual(decoder_model_embed.weight.shape[0], decoder_cloned_embeddings.shape[0] + 10)
else:
model_embed = model.get_position_embeddings()
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing the position embeddings with a smaller max_position_embeddings decreases
# the model's max_position_embeddings
model.resize_position_embeddings(max_position_embeddings - 5)
self.assertEqual(model.config.max_position_embeddings, max_position_embeddings - 5)
# Check that it actually resizes the embeddings matrix
if model.config.is_encoder_decoder:
encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
self.assertEqual(encoder_model_embed.weight.shape[0], encoder_cloned_embeddings.shape[0] - 5)
self.assertEqual(decoder_model_embed.weight.shape[0], decoder_cloned_embeddings.shape[0] - 5)
else:
model_embed = model.get_position_embeddings()
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 5)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
models_equal = True
if model.config.is_encoder_decoder:
for p1, p2 in zip(encoder_cloned_embeddings, encoder_model_embed.weight):
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
for p1, p2 in zip(decoder_cloned_embeddings, decoder_model_embed.weight):
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
else:
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
def test_resize_tokens_embeddings(self):
(
original_config,

View File

@ -214,6 +214,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
test_torchscript = True
test_resize_embeddings = True
test_sequence_classification_problem_types = True
test_resize_position_embeddings = True
def setUp(self):
self.model_tester = DistilBertModelTester(self)

View File

@ -229,6 +229,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
test_resize_position_embeddings = True
test_pruning = False
test_missing_keys = False
@ -526,6 +527,7 @@ class PegasusStandaloneDecoderModelTester:
class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (PegasusDecoder, PegasusForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (PegasusForCausalLM,) if is_torch_available() else ()
test_resize_position_embeddings = True
test_pruning = False
is_encoder_decoder = False