diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index 395081eded..115c7bd3e4 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -1150,8 +1150,10 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
             # Only keep active parts of the loss
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py
index 4a5b03abdd..867a7c0915 100644
--- a/src/transformers/models/electra/modeling_electra.py
+++ b/src/transformers/models/electra/modeling_electra.py
@@ -1259,8 +1259,10 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
             # Only keep active parts of the loss
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.config.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.config.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py
index 918788b447..3e7dfe8560 100644
--- a/src/transformers/models/layoutlm/modeling_layoutlm.py
+++ b/src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -1173,8 +1173,10 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
 
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
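
All three hunks apply the same change: instead of dropping padded positions by boolean indexing (which produces dynamically shaped tensors), the new code keeps the logits at a fixed shape and uses `torch.where` to overwrite the labels of inactive positions with `loss_fct.ignore_index` (`-100` by default), which `CrossEntropyLoss` excludes from its mean. The following is a minimal standalone sketch, not part of the diff, checking that the two formulations produce the same loss; the tensor shapes and values are made up for illustration:

```python
import torch
from torch.nn import CrossEntropyLoss

num_labels = 5
logits = torch.randn(2, 4, num_labels)          # (batch, seq_len, num_labels)
labels = torch.randint(0, num_labels, (2, 4))   # gold label per token
attention_mask = torch.tensor([[1, 1, 1, 0],    # trailing zeros mark
                               [1, 1, 0, 0]])   # padded positions

loss_fct = CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1

# Old approach: select only active positions (dynamically shaped result).
old_loss = loss_fct(
    logits.view(-1, num_labels)[active_loss],
    labels.view(-1)[active_loss],
)

# New approach: keep every position, but set padded labels to ignore_index
# so CrossEntropyLoss skips them when averaging.
active_labels = torch.where(
    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
new_loss = loss_fct(logits.view(-1, num_labels), active_labels)

assert torch.allclose(old_loss, new_loss)
```

Since `CrossEntropyLoss` with the default `reduction="mean"` averages only over non-ignored targets, the two losses match; the `torch.where` form just avoids the data-dependent tensor shapes that the boolean indexing produced.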