fix repetition penalty

This commit is contained in:
leo-du 2019-10-17 11:04:34 -07:00 committed by Lysandre Debut
parent c544194611
commit ecd15667f3
1 changed files with 1 additions and 1 deletions

View File

@ -139,7 +139,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)
# reptition penalty from CTRL (https://arxiv.org/abs/1909.05858)
for _ in set(generated):
for _ in set(generated.view(-1).tolist()):
next_token_logits[_] /= repetition_penalty
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)