Added repetition penalty to PPLM example (#2436)

* Added repetition penalty

* Defaulted PPLM repetition_penalty to a neutral value (1.0)

* Minor modifications to comply with the reviewer's suggestions (renamed j -> token_idx)

* Formatted code with `make style`
Authored by IWillPull on 2020-01-11 06:00:07 +02:00; committed by Julien Chaumond
parent e83d9f1c1d
commit a3085020ed
1 changed file with 23 additions and 1 deletion


@@ -344,6 +344,7 @@ def full_text_generation(
     gamma=1.5,
     gm_scale=0.9,
     kl_scale=0.01,
+    repetition_penalty=1.0,
     **kwargs
 ):
     classifier, class_id = get_classifier(discrim, class_label, device)
@@ -368,7 +369,14 @@ def full_text_generation(
         raise Exception("Specify either a bag of words or a discriminator")

     unpert_gen_tok_text, _, _ = generate_text_pplm(
-        model=model, tokenizer=tokenizer, context=context, device=device, length=length, sample=sample, perturb=False
+        model=model,
+        tokenizer=tokenizer,
+        context=context,
+        device=device,
+        length=length,
+        sample=sample,
+        perturb=False,
+        repetition_penalty=repetition_penalty,
     )
     if device == "cuda":
         torch.cuda.empty_cache()
@@ -401,6 +409,7 @@ def full_text_generation(
             gamma=gamma,
             gm_scale=gm_scale,
             kl_scale=kl_scale,
+            repetition_penalty=repetition_penalty,
         )
         pert_gen_tok_texts.append(pert_gen_tok_text)
         if classifier is not None:
@@ -437,6 +446,7 @@ def generate_text_pplm(
     gamma=1.5,
     gm_scale=0.9,
     kl_scale=0.01,
+    repetition_penalty=1.0,
 ):
     output_so_far = None
     if context:
@@ -508,6 +518,13 @@ def generate_text_pplm(
         pert_logits, past, pert_all_hidden = model(last, past=pert_past)
         pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST
+
+        for token_idx in set(output_so_far[0].tolist()):
+            if pert_logits[0, token_idx] < 0:
+                pert_logits[0, token_idx] *= repetition_penalty
+            else:
+                pert_logits[0, token_idx] /= repetition_penalty
+
         pert_probs = F.softmax(pert_logits, dim=-1)

         if classifier is not None:
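The loop added above rescales the logit of every token that has already been generated: positive logits are divided by `repetition_penalty` and negative logits are multiplied by it, so values above 1.0 push repeated tokens down while the default of 1.0 leaves the distribution untouched. Below is a minimal standalone sketch of the same rule; the four-token vocabulary, the logit values, and the previously generated ids are illustrative, not taken from the example.

```python
import torch
import torch.nn.functional as F

repetition_penalty = 1.2                               # > 1.0 discourages repetition
pert_logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])    # toy logits over a 4-token vocab
output_so_far = torch.tensor([[0, 1]])                 # tokens 0 and 1 were already generated

for token_idx in set(output_so_far[0].tolist()):
    if pert_logits[0, token_idx] < 0:
        pert_logits[0, token_idx] *= repetition_penalty   # more negative -> less likely
    else:
        pert_logits[0, token_idx] /= repetition_penalty   # smaller positive -> less likely

pert_probs = F.softmax(pert_logits, dim=-1)
print(pert_logits)   # logits of tokens 0 and 1 are now lower than before
print(pert_probs)    # their probability mass shifts toward the unseen tokens 2 and 3
```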
@@ -588,6 +605,7 @@ def run_pplm_example(
     seed=0,
     no_cuda=False,
     colorama=False,
+    repetition_penalty=1.0,
 ):
     # set Random seed
     torch.manual_seed(seed)
@@ -655,6 +673,7 @@ def run_pplm_example(
         gamma=gamma,
         gm_scale=gm_scale,
         kl_scale=kl_scale,
+        repetition_penalty=repetition_penalty,
     )

     # untokenize unperturbed text
@@ -767,6 +786,9 @@ if __name__ == "__main__":
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--no_cuda", action="store_true", help="no cuda")
     parser.add_argument("--colorama", action="store_true", help="colors keywords")
+    parser.add_argument(
+        "--repetition_penalty", type=float, default=1.0, help="Penalize repetition. More than 1.0 -> less repetition",
+    )

     args = parser.parse_args()
     run_pplm_example(**vars(args))
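For context, a hedged usage sketch that calls the example's entry point with the new argument. The `bag_of_words`, `cond_text`, and `length` keyword names mirror pre-existing options of `run_pplm.py` that are assumed from the surrounding script and are not part of this diff; the values are only illustrative.

```python
from run_pplm import run_pplm_example

run_pplm_example(
    bag_of_words="military",      # assumed pre-existing bag-of-words option of the example
    cond_text="The potato",       # assumed prompt option
    length=50,                    # assumed generation-length option
    seed=0,
    colorama=True,
    repetition_penalty=1.2,       # new argument: > 1.0 reduces repetition, 1.0 is neutral
)
```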