Use torch.unique_consecutive to check same element (#13637)

We use `torch.unique` here only to check whether all elements have
the same value.
Therefore, we can use `torch.unique_consecutive` instead.

This function eliminates all but the first element from every consecutive
group of equivalent elements.
For example, applying it to `[1, 2, 2, 1]` results in `[1, 2, 1]`.

As you can see, this is enough to check whether all elements have the
same value.
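
In code, the all-equal check looks like this (a minimal sketch; `x` stands for any 1-D tensor):

```python
import torch

x = torch.tensor([3, 3, 3, 3])

# torch.unique deduplicates across the whole tensor (and sorts by default).
all_same = len(torch.unique(x)) == 1

# torch.unique_consecutive only collapses adjacent duplicates, which is
# sufficient: the result has length 1 if and only if all elements are equal.
all_same_fast = len(torch.unique_consecutive(x)) == 1

assert all_same == all_same_fast  # both True here
```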

Since `torch.unique_consecutive` does less work (it only collapses adjacent
duplicates and skips the sorting and global deduplication that `torch.unique`
performs), it is much faster.
On my computer, it is 25x faster on GPU and 15x faster on CPU.
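
The exact benchmark script is not part of this commit; a minimal sketch of how such a comparison could be reproduced (the tensor size and iteration count below are assumptions) might look like:

```python
import time
import torch

x = torch.full((4096,), 7, dtype=torch.long)  # assumed size; all elements equal

def bench(fn, n=1000):
    # For GPU tensors, call torch.cuda.synchronize() before and after
    # the loop so the timing reflects actual kernel execution.
    start = time.perf_counter()
    for _ in range(n):
        fn(x)
    return time.perf_counter() - start

print("unique:            ", bench(torch.unique))
print("unique_consecutive:", bench(torch.unique_consecutive))
```
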
Tommy Chiang, 2021-09-24 16:31:23 +08:00 (committed by GitHub)
parent 95f888fd6a
commit a2ef9c5446
5 changed files with 5 additions and 5 deletions

@@ -1457,7 +1457,7 @@ class BartForSequenceClassification(BartPretrainedModel):
         eos_mask = input_ids.eq(self.config.eos_token_id)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
             :, -1, :
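
For context, `eos_mask.sum(1)` counts the `<eos>` tokens in each example, so the check raises exactly when those counts differ across the batch. A toy illustration (the token IDs below are made up):

```python
import torch

eos_token_id = 2  # hypothetical id
input_ids = torch.tensor([[5, 9, 2],    # one <eos>
                          [5, 2, 2]])   # two <eos>

eos_mask = input_ids.eq(eos_token_id)
counts = eos_mask.sum(1)                # tensor([1, 2])
print(len(torch.unique_consecutive(counts)) > 1)  # True -> ValueError in the model
```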

@@ -2668,7 +2668,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
         eos_mask = input_ids.eq(self.config.eos_token_id)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
             :, -1, :

@@ -2522,7 +2522,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
         eos_mask = input_ids.eq(self.config.eos_token_id)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
             :, -1, :

@@ -1463,7 +1463,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
         eos_mask = input_ids.eq(self.config.eos_token_id)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
             :, -1, :

@@ -2972,7 +2972,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt
         eos_mask = input_ids.eq(self.config.eos_token_id)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
             :, -1, :