fix repetition penalty mask in tf

Patrick von Platen 2020-03-09 14:55:11 +01:00
parent b29fed790b
commit 3e624c64ca
2 changed files with 13 additions and 5 deletions


@@ -70,6 +70,7 @@ class PretrainedConfig(object):
# Parameters for sequence generation
self.max_length = kwargs.pop("max_length", 20)
self.do_sample = kwargs.pop("do_sample", False)
self.do_sample = kwargs.pop("early_stopping", False)
self.num_beams = kwargs.pop("num_beams", 1)
self.temperature = kwargs.pop("temperature", 1.0)
self.top_k = kwargs.pop("top_k", 50)
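These generation defaults are plain attributes popped off the config kwargs, so `early_stopping` can be set when a config is constructed and falls back to `False` otherwise. A minimal sketch of the pattern (the `GenerationDefaultsSketch` class below is a hypothetical stand-in, not library code):

```python
# Hypothetical stand-in illustrating the kwargs.pop pattern above: each
# generation default is read from the constructor kwargs with a fallback.
class GenerationDefaultsSketch:
    def __init__(self, **kwargs):
        self.max_length = kwargs.pop("max_length", 20)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)

config = GenerationDefaultsSketch(early_stopping=True, num_beams=3)
print(config.early_stopping, config.do_sample)  # True False
```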


@@ -460,6 +460,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
input_ids=None,
max_length=None,
do_sample=True,
+ early_stopping=False,
num_beams=None,
temperature=None,
top_k=None,
@@ -559,11 +560,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if self.get_output_embeddings() is None:
raise AttributeError(
"You tried to generate sequences with a model that does not have a LM Head."
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5WithLMHeadModel`, `TFTransfoXLLMHeadModel`)"
)
max_length = max_length if max_length is not None else self.config.max_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k
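Each `x if x is not None else self.config.x` fallback means an explicit argument to `generate()` wins over the config default, so `early_stopping` can now be toggled per call. A hedged usage sketch (the checkpoint and tokenizer names are illustrative assumptions, not taken from this diff):

```python
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("The weather is", return_tensors="tf")
# early_stopping passed here overrides model.config.early_stopping;
# omitted arguments fall back to the config defaults shown above.
output_ids = model.generate(input_ids, num_beams=3, early_stopping=True, max_length=20)
print(tokenizer.decode(output_ids[0].numpy().tolist()))
```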
@@ -586,6 +588,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
+ assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
@@ -662,6 +665,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
cur_len,
max_length,
do_sample,
+ early_stopping,
temperature,
top_k,
top_p,
@@ -803,6 +807,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
cur_len,
max_length,
do_sample,
+ early_stopping,
temperature,
top_k,
top_p,
@@ -820,7 +825,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# generated hypotheses
generated_hyps = [
- BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
+ BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
+ for _ in range(batch_size)
]
# scores for each sentence in the beam
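The one-line change above is what makes the new flag effective: `early_stopping=False` was previously hard-coded into every `BeamHypotheses`. The flag controls when a batch element stops accepting new hypotheses, roughly as in this simplified sketch (condensed from the library's `is_done` logic; details may differ from the exact source):

```python
class BeamHypothesesSketch:
    """Simplified stand-in for BeamHypotheses, illustrating early_stopping only."""

    def __init__(self, num_beams, max_length, length_penalty, early_stopping=False):
        self.num_beams = num_beams
        self.max_length = max_length
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.hyps = []            # finished (score, tokens) pairs
        self.worst_score = 1e9    # score of the worst finished hypothesis

    def is_done(self, best_sum_logprobs, cur_len):
        # Not enough finished hypotheses yet: keep searching.
        if len(self.hyps) < self.num_beams:
            return False
        if self.early_stopping:
            # Stop as soon as num_beams hypotheses are finished, even if a
            # still-open beam could eventually score better.
            return True
        # Otherwise, stop only when no open beam can beat the worst
        # finished hypothesis under the length penalty.
        return self.worst_score >= best_sum_logprobs / cur_len ** self.length_penalty
```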
@@ -1058,10 +1064,11 @@ def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
for i, prev_input_id in enumerate(prev_input_ids):
logit_penalized = logits[i].numpy()[prev_input_id]
+ logit_penalties = np.zeros(logit_penalized.shape)
# if previous logit score is < 0 then multiply by the repetition penalty, else divide by it
logit_penalized[logit_penalized < 0] = repetition_penalty
logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
np.put(token_penalties[i], prev_input_id, logit_penalized)
logit_penalties[logit_penalized < 0] = repetition_penalty
logit_penalties[logit_penalized > 0] = 1 / repetition_penalty
np.put(token_penalties[i], prev_input_id, logit_penalties)
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
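The last hunk is the actual bug fix named in the commit title. In the old code the two boolean masks were applied to `logit_penalized` in place: the first assignment flipped the negative entries positive, so the second mask re-selected them, and every previously seen token ended up with the same `1 / repetition_penalty` factor. Writing into a fresh `logit_penalties` array keeps both masks evaluated against the untouched logits. A standalone NumPy demonstration of the difference:

```python
import numpy as np

repetition_penalty = 2.0
logits = np.array([-1.5, 0.7, 3.2])  # scores of previously generated tokens

# Old behaviour: masking in place. The first assignment turns the negative
# entry positive, so the second mask re-selects it and every entry
# collapses to 1 / repetition_penalty.
buggy = logits.copy()
buggy[buggy < 0] = repetition_penalty
buggy[buggy > 0] = 1 / repetition_penalty
print(buggy)  # [0.5 0.5 0.5] -- sign information lost

# Fixed behaviour: both masks are evaluated against the untouched logits
# and the penalties go into a fresh array, as in the hunk above.
penalties = np.zeros_like(logits)
penalties[logits < 0] = repetition_penalty
penalties[logits > 0] = 1 / repetition_penalty
print(penalties)  # [2.  0.5 0.5] -- negative logits multiplied, positive divided
```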