Optimize Token Classification models for TPU (#13096)
* Optimize Token Classification models for TPU

  As per the XLA documentation, XLA cannot handle masked indexing well, so the token classification models for BERT and others use an implementation based on `torch.where`, which works well on TPU. The ALBERT token classification model still uses masked indexing, which causes performance issues on TPU. This PR fixes the issue by following the BERT implementation.

* Same fix for ELECTRA

* Same fix for LayoutLM
parent e02ed0ee7e
commit eae7a96b7d
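The change is mechanical but worth seeing side by side. Below is a minimal sketch (hypothetical shapes and tensors, not taken from the commit) of the two loss-masking strategies: the old boolean-mask indexing produces tensors whose size depends on the mask contents, while the `torch.where` form keeps static shapes and relies on `CrossEntropyLoss`'s default `ignore_index` (-100) to skip padded positions, so both compute the same loss.

```python
import torch
from torch.nn import CrossEntropyLoss

# Hypothetical example: batch of 2 sequences, 4 tokens each, 3 label classes.
logits = torch.randn(2, 4, 3)
labels = torch.randint(0, 3, (2, 4))
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
num_labels = 3

loss_fct = CrossEntropyLoss()                  # ignore_index defaults to -100
active_loss = attention_mask.view(-1) == 1

# Old pattern: boolean-mask indexing. The result size depends on the mask
# contents, which XLA handles poorly (dynamic shapes force recompilation).
old_logits = logits.view(-1, num_labels)[active_loss]
old_labels = labels.view(-1)[active_loss]
old_loss = loss_fct(old_logits, old_labels)

# New pattern: keep every position, but overwrite padded labels with
# ignore_index so CrossEntropyLoss skips them. All shapes stay static.
new_logits = logits.view(-1, num_labels)
new_labels = torch.where(
    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
new_loss = loss_fct(new_logits, new_labels)

assert torch.allclose(old_loss, new_loss)      # same loss, TPU-friendly shapes
```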
src/transformers/models/albert/modeling_albert.py
@@ -1150,8 +1150,10 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
             # Only keep active parts of the loss
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
src/transformers/models/electra/modeling_electra.py
@@ -1259,8 +1259,10 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
             # Only keep active parts of the loss
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.config.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.config.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -1173,8 +1173,10 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):

             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
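To make the XLA motivation concrete, here is a small CPU-only illustration (hypothetical tensors, not from the commit) of why the old pattern is hostile to XLA: the size of a boolean-masked tensor depends on the data in the mask, so every new padding pattern is a new shape for the compiler, whereas the `torch.where` variant always yields the full, static shape.

```python
import torch

labels = torch.randint(0, 3, (2, 4)).view(-1)  # 8 flattened token labels

for n_pad in (1, 2, 3):                        # three different padding patterns
    mask = torch.ones(2, 4)
    mask[:, -n_pad:] = 0                       # pad the last n_pad tokens per row
    active = mask.view(-1) == 1

    masked = labels[active]                    # data-dependent: (6,), (4,), (2,)
    static = torch.where(active, labels, torch.full_like(labels, -100))  # always (8,)
    print(masked.shape, static.shape)
```

On a TPU, each of those masked shapes would trigger a fresh compilation; the static variant compiles once.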