diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py
index 0c85e45238..f4d86a18d3 100755
--- a/examples/pytorch/language-modeling/run_clm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py
@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -403,6 +403,10 @@ def main():
         model, optimizer, train_dataloader, eval_dataloader
     )
 
+    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
+    if accelerator.distributed_type == DistributedType.TPU:
+        model.tie_weights()
+
     # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
     # shorter in multiprocess)
 
diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
index 6d911e93b1..d2bfdf26a1 100755
--- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py
+++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py
@@ -35,7 +35,7 @@ from torch.utils.data.dataloader import DataLoader
 from tqdm.auto import tqdm
 
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -448,6 +448,10 @@ def main():
         model, optimizer, train_dataloader, eval_dataloader
     )
 
+    # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
+    if accelerator.distributed_type == DistributedType.TPU:
+        model.tie_weights()
+
     # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
     # shorter in multiprocess)
 
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 5d4d6cfd0f..208d526b2c 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -594,6 +594,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 self = getattr(self, self.base_model_prefix)
             self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
 
+        for module in self.modules():
+            if hasattr(module, "_tie_weights"):
+                module._tie_weights()
+
     @staticmethod
     def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
         uninitialized_encoder_weights: List[str] = []
diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index dd07ebebd1..5f8388c57f 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -860,8 +860,6 @@ class AlbertMLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.embedding_size)
         self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
         self.activation = ACT2FN[config.hidden_act]
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, hidden_states):
@@ -874,6 +872,10 @@ class AlbertMLMHead(nn.Module):
 
         return prediction_scores
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 class AlbertSOPHead(nn.Module):
     def __init__(self, config):
diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py
index 120a48c098..75ab959171 100755
--- a/src/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -430,16 +430,18 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
 class BertGenerationOnlyLMHead(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, hidden_states):
         logits = self.decoder(hidden_states)
         return logits
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 @add_start_docstrings(
     """BertGeneration Model with a `language modeling` head on top for CLM fine-tuning. """,
diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py
index 01f6c2d014..68f0521f99 100644
--- a/src/transformers/models/ibert/modeling_ibert.py
+++ b/src/transformers/models/ibert/modeling_ibert.py
@@ -948,10 +948,8 @@ class IBertLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, features, **kwargs):
@@ -964,6 +962,10 @@ class IBertLMHead(nn.Module):
 
         return x
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 @add_start_docstrings(
     """
diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py
index 0fc0e756f3..fa38eb4898 100755
--- a/src/transformers/models/longformer/modeling_longformer.py
+++ b/src/transformers/models/longformer/modeling_longformer.py
@@ -1336,10 +1336,8 @@ class LongformerLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, features, **kwargs):
@@ -1352,6 +1350,10 @@ class LongformerLMHead(nn.Module):
 
         return x
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 class LongformerPreTrainedModel(PreTrainedModel):
     """
diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py
index 77f35b97fa..f91c4bdc96 100755
--- a/src/transformers/models/reformer/modeling_reformer.py
+++ b/src/transformers/models/reformer/modeling_reformer.py
@@ -1747,8 +1747,6 @@ class ReformerOnlyLMHead(nn.Module):
         self.chunk_size_lm_head = config.chunk_size_lm_head
         self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, hidden_states):
@@ -1758,6 +1756,10 @@ class ReformerOnlyLMHead(nn.Module):
         hidden_states = self.decoder(hidden_states)
         return hidden_states
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 class ReformerPreTrainedModel(PreTrainedModel):
     """
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index 8bce52bb40..76a1a6c3d7 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -1124,10 +1124,8 @@ class RobertaLMHead(nn.Module):
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
     def forward(self, features, **kwargs):
@@ -1140,6 +1138,10 @@ class RobertaLMHead(nn.Module):
 
         return x
 
+    def _tie_weights(self):
+        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+        self.bias = self.decoder.bias
+
 
 @add_start_docstrings(
     """
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3119d80716..1eb5f74e1e 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -364,7 +364,7 @@ class Trainer:
         self.tokenizer = tokenizer
 
         if self.place_model_on_device:
-            model = model.to(args.device)
+            self._move_model_to_device(model, args.device)
 
         # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
         if self.is_model_parallel:
@@ -505,6 +505,12 @@ class Trainer:
         """
         self.callback_handler.remove_callback(callback)
 
+    def _move_model_to_device(self, model, device):
+        model = model.to(device)
+        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
+        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
+            model.tie_weights()
+
     def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
         if not self.args.remove_unused_columns:
             return dataset
@@ -1017,7 +1023,7 @@ class Trainer:
         # do_train is not a reliable argument, as it might not be set and .train() still called, so
         # the following is a workaround:
         if args.fp16_full_eval and not args.do_train:
-            self.model = self.model.to(args.device)
+            self._move_model_to_device(self.model, args.device)
 
         if "model_path" in kwargs:
             resume_from_checkpoint = kwargs.pop("model_path")
@@ -1078,7 +1084,7 @@ class Trainer:
         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
             if self.place_model_on_device:
-                self.model = self.model.to(args.device)
+                self._move_model_to_device(self.model, args.device)
             self.model_wrapped = self.model
 
         # Keeping track whether we can can len() on the dataset or not
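
All the modeling changes above follow one pattern: each LM head keeps `bias` and `decoder.bias` pointing at the same parameter and exposes a `_tie_weights()` hook, and the model-level `tie_weights()` (see the `modeling_utils.py` hunk) calls that hook on every submodule to restore the link after it gets broken, e.g. when the model is moved to an XLA device or the bias is resized. Below is a minimal, self-contained sketch of that pattern using a hypothetical `ToyLMHead` class; it is an illustration of the hook, not code from this patch, and it simulates the disconnection instead of requiring a TPU.

```python
import torch
from torch import nn


class ToyLMHead(nn.Module):
    # Hypothetical stand-in for the LM heads touched above (RobertaLMHead, IBertLMHead, ...).
    def __init__(self, hidden_size=8, vocab_size=16):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # Link the two variables so the bias is resized together with `decoder`.
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # Hook picked up by the model-level tie_weights(): re-point self.bias at
        # decoder.bias if the two parameters were disconnected.
        self.bias = self.decoder.bias

    def forward(self, hidden_states):
        return self.decoder(hidden_states)


head = ToyLMHead()
# Simulate the disconnection described in the diff comments (a move to an XLA device):
# decoder.bias becomes a new Parameter object, distinct from head.bias.
head.decoder.bias = nn.Parameter(head.decoder.bias.detach().clone())
assert head.bias is not head.decoder.bias

# The loop added to PreTrainedModel.tie_weights() in modeling_utils.py boils down to this:
for module in head.modules():
    if hasattr(module, "_tie_weights"):
        module._tie_weights()
assert head.bias is head.decoder.bias  # tie restored
```

The same reasoning explains the two example scripts and the `Trainer` change: after `accelerator.prepare(...)` on TPU, or after `Trainer._move_model_to_device` places the model on an XLA device, `model.tie_weights()` is called so the re-tie loop runs over the freshly moved parameters.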