Vectorize RepetitionPenaltyLogitsProcessor to improve performance (#8598)

* refactored existing nested loops into a vectorized implementation

* replaced explicit indexing with torch.where

* modified scores for previous input_ids only
Binoy Dalal authored on 2020-11-20 13:59:06 -05:00; committed by GitHub
parent 2594bd8b73
commit 29bdb88368
1 changed file with 7 additions and 7 deletions


@@ -146,13 +146,13 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
         self.penalty = penalty
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        for i in range(scores.shape[0]):
-            for previous_token in set(input_ids[i].tolist()):
-                # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-                if scores[i, previous_token] < 0:
-                    scores[i, previous_token] *= self.penalty
-                else:
-                    scores[i, previous_token] /= self.penalty
+        ranges = torch.arange(scores.shape[0])
+        score = scores[ranges[:, None], input_ids]
+
+        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+
+        scores[ranges[:, None], input_ids] = score
         return scores
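
For context, here is a minimal sketch checking the vectorized path against the loop formulation it replaces. The toy tensors and the penalty value 1.2 are illustrative assumptions, not part of the commit:

import torch

from transformers import RepetitionPenaltyLogitsProcessor

# Toy inputs (hypothetical): a batch of 2 sequences over a 6-token vocabulary.
# Token 1 is repeated on purpose to exercise the duplicate-token case.
input_ids = torch.tensor([[0, 1, 1], [3, 4, 5]])
scores = torch.randn(2, 6)

# Vectorized path from this commit.
processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)
penalized = processor(input_ids, scores.clone())

# Reference: the nested-loop formulation this commit removes.
expected = scores.clone()
for i in range(expected.shape[0]):
    for previous_token in set(input_ids[i].tolist()):
        if expected[i, previous_token] < 0:
            expected[i, previous_token] *= 1.2
        else:
            expected[i, previous_token] /= 1.2

assert torch.allclose(penalized, expected)

Note that repeated tokens do not compound the penalty: score is gathered from the unmodified scores tensor, so duplicate positions in input_ids write back the same penalized value, matching the set(...) deduplication in the loop version.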