Tpu tie weights (#13030)

* Fix tied weights on TPU

* Manually tie weights in no trainer examples

* Fix for test

* One last missing

* Getting owned by my scripts

* Address review comments

* Fix test

* Fix tests

* Fix reformer tests
Sylvain Gugger 2021-08-06 20:41:39 +02:00 committed by GitHub
parent 1bf38611a4
commit 7fcee113c1
10 changed files with 51 additions and 21 deletions


@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -403,6 +403,10 @@ def main():
         model, optimizer, train_dataloader, eval_dataloader
     )
 
+    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
+    if accelerator.distributed_type == DistributedType.TPU:
+        model.tie_weights()
+
     # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
     # shorter in multiprocess)
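
Outside the diff itself, the pattern these example scripts now follow looks roughly like the sketch below. The checkpoint, optimizer and learning rate are placeholders, not part of the commit: `accelerator.prepare()` moves the model to an XLA device on TPU, which disconnects tied parameters such as the shared input/output embeddings, so the ties are restored right afterwards.

import torch
from accelerate import Accelerator, DistributedType
from transformers import AutoModelForCausalLM

accelerator = Accelerator()
model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# prepare() moves the model to the right device; on TPU that is an XLA device.
model, optimizer = accelerator.prepare(model, optimizer)

# The move disconnects tied weights, so restore the ties on TPU.
if accelerator.distributed_type == DistributedType.TPU:
    model.tie_weights()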


@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -448,6 +448,10 @@ def main():
         model, optimizer, train_dataloader, eval_dataloader
     )
 
+    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
+    if accelerator.distributed_type == DistributedType.TPU:
+        model.tie_weights()
+
     # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
     # shorter in multiprocess)


@@ -594,6 +594,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 self = getattr(self, self.base_model_prefix)
             self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
 
+        for module in self.modules():
+            if hasattr(module, "_tie_weights"):
+                module._tie_weights()
+
     @staticmethod
     def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
         uninitialized_encoder_weights: List[str] = []
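
With the loop above, `PreTrainedModel.tie_weights()` now calls `_tie_weights()` on every submodule that defines it, so individual heads can repair their own internal parameter sharing. A minimal sketch of a head opting into this hook (`MyLMHead` is a made-up example, not a class in the library):

import torch
from torch import nn

class MyLMHead(nn.Module):
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # Share a single Parameter between the head-level attribute and the decoder.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(hidden_states)

    def _tie_weights(self):
        # Picked up by PreTrainedModel.tie_weights(): re-link the two names
        # if a device move or an embedding resize broke the sharing.
        self.bias = self.decoder.bias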


@@ -860,8 +860,6 @@ class AlbertMLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.embedding_size)
         self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
         self.activation = ACT2FN[config.hidden_act]
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, hidden_states):
@@ -874,6 +872,10 @@ class AlbertMLMHead(nn.Module):
         return prediction_scores
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 class AlbertSOPHead(nn.Module):
     def __init__(self, config):
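
The direction of the assignment in `_tie_weights` matters: `resize_token_embeddings` swaps in a new decoder (with a new bias), and moving the model to an XLA device recreates its parameters, so it is the head-level `self.bias` that can go stale and must be re-pointed at `self.decoder.bias`. A rough illustration of the identity breakage in plain PyTorch, not library code:

import torch
from torch import nn

decoder = nn.Linear(128, 1000)
bias = nn.Parameter(torch.zeros(1000))
decoder.bias = bias
assert decoder.bias is bias  # one shared Parameter, two names

# Simulate a vocabulary resize: the decoder is rebuilt with a new bias.
decoder = nn.Linear(128, 1010)
assert decoder.bias is not bias  # the old head-level `bias` is now stale

# What the new _tie_weights methods do: re-point the stale name at the live parameter.
bias = decoder.bias
assert decoder.bias is bias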


@@ -430,16 +430,18 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
 class BertGenerationOnlyLMHead(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, hidden_states):
         logits = self.decoder(hidden_states)
         return logits
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 @add_start_docstrings(
     """BertGeneration Model with a `language modeling` head on top for CLM fine-tuning. """,


@@ -948,10 +948,8 @@ class IBertLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, features, **kwargs):
@@ -964,6 +962,10 @@ class IBertLMHead(nn.Module):
         return x
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 @add_start_docstrings(
     """
"""


@@ -1336,10 +1336,8 @@ class LongformerLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, features, **kwargs):
@@ -1352,6 +1350,10 @@ class LongformerLMHead(nn.Module):
         return x
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 class LongformerPreTrainedModel(PreTrainedModel):
     """


@@ -1747,8 +1747,6 @@ class ReformerOnlyLMHead(nn.Module):
         self.chunk_size_lm_head = config.chunk_size_lm_head
         self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, hidden_states):
@@ -1758,6 +1756,10 @@ class ReformerOnlyLMHead(nn.Module):
         hidden_states = self.decoder(hidden_states)
         return hidden_states
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 class ReformerPreTrainedModel(PreTrainedModel):
     """


@@ -1124,10 +1124,8 @@ class RobertaLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, features, **kwargs):
@@ -1140,6 +1138,10 @@ class RobertaLMHead(nn.Module):
         return x
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 @add_start_docstrings(
     """


@@ -364,7 +364,7 @@ class Trainer:
         self.tokenizer = tokenizer
 
         if self.place_model_on_device:
-            model = model.to(args.device)
+            self._move_model_to_device(model, args.device)
 
         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
         if self.is_model_parallel:
@@ -505,6 +505,12 @@ class Trainer:
         """
         self.callback_handler.remove_callback(callback)
 
+    def _move_model_to_device(self, model, device):
+        model = model.to(device)
+        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
+        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
+            model.tie_weights()
+
     def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
         if not self.args.remove_unused_columns:
             return dataset
@@ -1017,7 +1023,7 @@ class Trainer:
         # do_train is not a reliable argument, as it might not be set and .train() still called, so
         # the following is a workaround:
         if args.fp16_full_eval and not args.do_train:
-            self.model = self.model.to(args.device)
+            self._move_model_to_device(self.model, args.device)
 
         if "model_path" in kwargs:
             resume_from_checkpoint = kwargs.pop("model_path")
@@ -1078,7 +1084,7 @@ class Trainer:
         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
             if self.place_model_on_device:
-                self.model = self.model.to(args.device)
+                self._move_model_to_device(self.model, args.device)
             self.model_wrapped = self.model
 
         # Keeping track whether we can can len() on the dataset or not
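
For reference, a rough standalone sketch of what the new `_move_model_to_device` helper does; the checkpoint and `TrainingArguments` below are placeholders, not part of the commit:

from transformers import AutoModelForMaskedLM, TrainingArguments
from transformers.training_args import ParallelMode

args = TrainingArguments(output_dir="out")  # placeholder arguments
model = AutoModelForMaskedLM.from_pretrained("roberta-base")  # placeholder checkpoint

# Roughly what Trainer._move_model_to_device(model, args.device) does:
model = model.to(args.device)
if args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
    # Moving to an XLA device disconnects tied weights, so re-tie them after the move.
    model.tie_weights()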