Use torch.unique_consecutive to check same element (#13637)
We use `torch.unique` here only to check whether all elements have the same value. Therefore, we can use `torch.unique_consecutive` instead. This function eliminates all but the first element from every consecutive group of equivalent elements. For example, applying it to `[1, 2, 2, 1]` yields `[1, 2, 1]`. As you can see, this is sufficient for checking whether all elements have the same value: the result has length 1 if and only if every element is equal. Since `torch.unique_consecutive` does less work (it skips the sort/deduplication across the whole tensor), it is much faster. On my machine, it is 25x faster on GPU and 15x faster on CPU.
This commit is contained in:
parent
95f888fd6a
commit
a2ef9c5446
|
@ -1457,7 +1457,7 @@ class BartForSequenceClassification(BartPretrainedModel):
|
|||
|
||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||
|
||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
||||
:, -1, :
|
||||
|
|
|
@ -2668,7 +2668,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
|||
|
||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||
|
||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
||||
:, -1, :
|
||||
|
|
|
@ -2522,7 +2522,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
|
|||
|
||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||
|
||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
||||
:, -1, :
|
||||
|
|
|
@ -1463,7 +1463,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
|
|||
|
||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||
|
||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
||||
:, -1, :
|
||||
|
|
|
@ -2972,7 +2972,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
|
|||
|
||||
eos_mask = input_ids.eq(self.config.eos_token_id)
|
||||
|
||||
if len(torch.unique(eos_mask.sum(1))) > 1:
|
||||
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
|
||||
raise ValueError("All examples must have the same number of <eos> tokens.")
|
||||
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
|
||||
:, -1, :
|
||||
|
|
Loading…
Reference in New Issue