update Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
224da5df69
commit
d30cf3d02f
|
@ -725,7 +725,9 @@ class SequenceBiasLogitsProcessor(LogitsProcessor):
|
|||
torch.tensor(sequence_ids[:-1], dtype=input_ids.dtype, device=input_ids.device),
|
||||
).prod(dim=1)
|
||||
bias[:, last_token] += torch.where(
|
||||
matching_rows.bool(), sequence_bias, torch.tensor(0.0, device=input_ids.device)
|
||||
matching_rows.bool(),
|
||||
torch.tensor(sequence_bias, device=input_ids.device),
|
||||
torch.tensor(0.0, device=input_ids.device),
|
||||
)
|
||||
|
||||
# 5 - apply the bias to the scores
|
||||
|
|
Loading…
Reference in New Issue