fix repetition penalty
This commit is contained in:
parent
c544194611
commit
ecd15667f3
|
@ -139,7 +139,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
|
|||
next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)
|
||||
|
||||
# reptition penalty from CTRL (https://arxiv.org/abs/1909.05858)
|
||||
for _ in set(generated):
|
||||
for _ in set(generated.view(-1).tolist()):
|
||||
next_token_logits[_] /= repetition_penalty
|
||||
|
||||
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
|
||||
|
|
Loading…
Reference in New Issue