Fix past CI after #24334 (#25113)

update

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-07-26 15:34:42 +02:00 committed by GitHub
parent 224da5df69
commit d30cf3d02f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 1 deletions

View File

@ -725,7 +725,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
).prod(dim=1)
bias[:, last_token] += torch.where(
matching_rows.bool(), sequence_bias, torch.tensor(0.0, device=input_ids.device)
matching_rows.bool(),
torch.tensor(sequence_bias, device=input_ids.device),
torch.tensor(0.0, device=input_ids.device),
)
# 5 - apply the bias to the scores