Added repetition penalty to PPLM example (#2436)
* Added repetition penalty
* Default PPLM repetition_penalty to neutral
* Minor modifications to comply with reviewer's suggestions (j -> token_idx)
* Formatted code with `make style`
commit a3085020ed
parent e83d9f1c1d
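For context, the heart of this change is the repetition-penalty rule applied to the next-token logits. Below is a minimal standalone sketch of that rule; the helper name `apply_repetition_penalty` and the sample token ids are illustrative, and only the loop body mirrors the code added in this commit.

```python
import torch
import torch.nn.functional as F


def apply_repetition_penalty(logits, generated_token_ids, repetition_penalty=1.0):
    # logits: tensor of shape (1, vocab_size) for the next token.
    # Every token that has already been generated is made less likely;
    # 1.0 is a no-op, which is why the commit defaults to it.
    for token_idx in set(generated_token_ids):
        if logits[0, token_idx] < 0:
            # negative logits are pushed further below zero -> lower probability
            logits[0, token_idx] *= repetition_penalty
        else:
            # positive logits are shrunk toward zero -> lower probability
            logits[0, token_idx] /= repetition_penalty
    return logits


logits = torch.randn(1, 50257)  # e.g. a GPT-2-sized vocabulary
logits = apply_repetition_penalty(logits, [464, 616, 464], repetition_penalty=1.2)
probs = F.softmax(logits, dim=-1)  # the next token is then sampled from these probabilities
```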
@@ -344,6 +344,7 @@ def full_text_generation(
     gamma=1.5,
     gm_scale=0.9,
     kl_scale=0.01,
+    repetition_penalty=1.0,
     **kwargs
 ):
     classifier, class_id = get_classifier(discrim, class_label, device)

@@ -368,7 +369,14 @@ def full_text_generation(
         raise Exception("Specify either a bag of words or a discriminator")
 
     unpert_gen_tok_text, _, _ = generate_text_pplm(
-        model=model, tokenizer=tokenizer, context=context, device=device, length=length, sample=sample, perturb=False
+        model=model,
+        tokenizer=tokenizer,
+        context=context,
+        device=device,
+        length=length,
+        sample=sample,
+        perturb=False,
+        repetition_penalty=repetition_penalty,
     )
     if device == "cuda":
         torch.cuda.empty_cache()

@@ -401,6 +409,7 @@ def full_text_generation(
             gamma=gamma,
             gm_scale=gm_scale,
             kl_scale=kl_scale,
+            repetition_penalty=repetition_penalty,
         )
         pert_gen_tok_texts.append(pert_gen_tok_text)
         if classifier is not None:

@@ -437,6 +446,7 @@ def generate_text_pplm(
     gamma=1.5,
     gm_scale=0.9,
     kl_scale=0.01,
+    repetition_penalty=1.0,
 ):
     output_so_far = None
     if context:

@@ -508,6 +518,13 @@ def generate_text_pplm(
 
         pert_logits, past, pert_all_hidden = model(last, past=pert_past)
         pert_logits = pert_logits[:, -1, :] / temperature  # + SMALL_CONST
+
+        for token_idx in set(output_so_far[0].tolist()):
+            if pert_logits[0, token_idx] < 0:
+                pert_logits[0, token_idx] *= repetition_penalty
+            else:
+                pert_logits[0, token_idx] /= repetition_penalty
+
         pert_probs = F.softmax(pert_logits, dim=-1)
 
         if classifier is not None:

@@ -588,6 +605,7 @@ def run_pplm_example(
     seed=0,
     no_cuda=False,
     colorama=False,
+    repetition_penalty=1.0,
 ):
     # set Random seed
     torch.manual_seed(seed)

@@ -655,6 +673,7 @@ def run_pplm_example(
         gamma=gamma,
         gm_scale=gm_scale,
         kl_scale=kl_scale,
+        repetition_penalty=repetition_penalty,
     )
 
     # untokenize unperturbed text

@@ -767,6 +786,9 @@ if __name__ == "__main__":
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--no_cuda", action="store_true", help="no cuda")
     parser.add_argument("--colorama", action="store_true", help="colors keywords")
+    parser.add_argument(
+        "--repetition_penalty", type=float, default=1.0, help="Penalize repetition. More than 1.0 -> less repetition",
+    )
 
     args = parser.parse_args()
     run_pplm_example(**vars(args))
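Assuming the example script is invoked as `python run_pplm.py` (the file name and its other flags are not shown in this diff), the new option can simply be appended to an existing invocation, e.g. `python run_pplm.py --colorama --repetition_penalty 1.2`. Values above 1.0 penalize tokens that have already been generated; the neutral default of 1.0 leaves generation unchanged.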