Tpu tie weights (#13030)
* Fix tied weights on TPU * Manually tie weights in no trainer examples * Fix for test * One last missing * Gettning owned by my scripts * Address review comments * Fix test * Fix tests * Fix reformer tests
This commit is contained in:
parent
1bf38611a4
commit
7fcee113c1
|
@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
|
|||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
|
@ -403,6 +403,10 @@ def main():
|
|||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
|
||||
if accelerator.distributed_type == DistributedType.TPU:
|
||||
model.tie_weights()
|
||||
|
||||
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
|
||||
# shorter in multiprocess)
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
|
|||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
|
@ -448,6 +448,10 @@ def main():
|
|||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
|
||||
if accelerator.distributed_type == DistributedType.TPU:
|
||||
model.tie_weights()
|
||||
|
||||
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
|
||||
# shorter in multiprocess)
|
||||
|
||||
|
|
|
@ -594,6 +594,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
self = getattr(self, self.base_model_prefix)
|
||||
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
|
||||
|
||||
for module in self.modules():
|
||||
if hasattr(module, "_tie_weights"):
|
||||
module._tie_weights()
|
||||
|
||||
@staticmethod
|
||||
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
|
||||
uninitialized_encoder_weights: List[str] = []
|
||||
|
|
|
@ -860,8 +860,6 @@ class AlbertMLMHead(nn.Module):
|
|||
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
|
||||
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
|
||||
self.activation = ACT2FN[config.hidden_act]
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
@ -874,6 +872,10 @@ class AlbertMLMHead(nn.Module):
|
|||
|
||||
return prediction_scores
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
class AlbertSOPHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
|
|
|
@ -430,16 +430,18 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
|||
class BertGenerationOnlyLMHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
logits = self.decoder(hidden_states)
|
||||
return logits
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""BertGeneration Model with a `language modeling` head on top for CLM fine-tuning. """,
|
||||
|
|
|
@ -948,10 +948,8 @@ class IBertLMHead(nn.Module):
|
|||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
|
@ -964,6 +962,10 @@ class IBertLMHead(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
|
|
@ -1336,10 +1336,8 @@ class LongformerLMHead(nn.Module):
|
|||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
|
@ -1352,6 +1350,10 @@ class LongformerLMHead(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
class LongformerPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
|
|
|
@ -1747,8 +1747,6 @@ class ReformerOnlyLMHead(nn.Module):
|
|||
self.chunk_size_lm_head = config.chunk_size_lm_head
|
||||
self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
@ -1758,6 +1756,10 @@ class ReformerOnlyLMHead(nn.Module):
|
|||
hidden_states = self.decoder(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
class ReformerPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
|
|
|
@ -1124,10 +1124,8 @@ class RobertaLMHead(nn.Module):
|
|||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
|
@ -1140,6 +1138,10 @@ class RobertaLMHead(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
|
|
@ -364,7 +364,7 @@ class Trainer:
|
|||
self.tokenizer = tokenizer
|
||||
|
||||
if self.place_model_on_device:
|
||||
model = model.to(args.device)
|
||||
self._move_model_to_device(model, args.device)
|
||||
|
||||
# Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
|
||||
if self.is_model_parallel:
|
||||
|
@ -505,6 +505,12 @@ class Trainer:
|
|||
"""
|
||||
self.callback_handler.remove_callback(callback)
|
||||
|
||||
def _move_model_to_device(self, model, device):
|
||||
model = model.to(device)
|
||||
# Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
|
||||
if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
|
||||
model.tie_weights()
|
||||
|
||||
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
||||
if not self.args.remove_unused_columns:
|
||||
return dataset
|
||||
|
@ -1017,7 +1023,7 @@ class Trainer:
|
|||
# do_train is not a reliable argument, as it might not be set and .train() still called, so
|
||||
# the following is a workaround:
|
||||
if args.fp16_full_eval and not args.do_train:
|
||||
self.model = self.model.to(args.device)
|
||||
self._move_model_to_device(self.model, args.device)
|
||||
|
||||
if "model_path" in kwargs:
|
||||
resume_from_checkpoint = kwargs.pop("model_path")
|
||||
|
@ -1078,7 +1084,7 @@ class Trainer:
|
|||
# If model was re-initialized, put it on the right device and update self.model_wrapped
|
||||
if model_reloaded:
|
||||
if self.place_model_on_device:
|
||||
self.model = self.model.to(args.device)
|
||||
self._move_model_to_device(self.model, args.device)
|
||||
self.model_wrapped = self.model
|
||||
|
||||
# Keeping track whether we can can len() on the dataset or not
|
||||
|
|
Loading…
Reference in New Issue