fix repetition penalty mask in tf
This commit is contained in:
parent
b29fed790b
commit
3e624c64ca
|
@ -70,6 +70,7 @@ class PretrainedConfig(object):
|
|||
# Parameters for sequence generation
|
||||
self.max_length = kwargs.pop("max_length", 20)
|
||||
self.do_sample = kwargs.pop("do_sample", False)
|
||||
self.do_sample = kwargs.pop("early_stopping", False)
|
||||
self.num_beams = kwargs.pop("num_beams", 1)
|
||||
self.temperature = kwargs.pop("temperature", 1.0)
|
||||
self.top_k = kwargs.pop("top_k", 50)
|
||||
|
|
|
@ -460,6 +460,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
input_ids=None,
|
||||
max_length=None,
|
||||
do_sample=True,
|
||||
early_stopping=False,
|
||||
num_beams=None,
|
||||
temperature=None,
|
||||
top_k=None,
|
||||
|
@ -559,11 +560,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
if self.get_output_embeddings() is None:
|
||||
raise AttributeError(
|
||||
"You tried to generate sequences with a model that does not have a LM Head."
|
||||
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"
|
||||
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5WithLMHeadModel`, `TFTransfoXLLMHeadModel`)"
|
||||
)
|
||||
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
top_k = top_k if top_k is not None else self.config.top_k
|
||||
|
@ -586,6 +588,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
|
||||
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
|
||||
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
||||
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
|
||||
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
|
||||
assert temperature > 0, "`temperature` should be strictely positive."
|
||||
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
||||
|
@ -662,6 +665,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
cur_len,
|
||||
max_length,
|
||||
do_sample,
|
||||
early_stopping,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
|
@ -803,6 +807,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
cur_len,
|
||||
max_length,
|
||||
do_sample,
|
||||
early_stopping,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
|
@ -820,7 +825,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
|
||||
# generated hypotheses
|
||||
generated_hyps = [
|
||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
|
||||
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
|
||||
for _ in range(batch_size)
|
||||
]
|
||||
|
||||
# scores for each sentence in the beam
|
||||
|
@ -1058,10 +1064,11 @@ def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
|
|||
prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
|
||||
for i, prev_input_id in enumerate(prev_input_ids):
|
||||
logit_penalized = logits[i].numpy()[prev_input_id]
|
||||
logit_penalties = np.zeros(logit_penalized.shape)
|
||||
# if previous logit score is < 0 then multiply repetition penalty else divide
|
||||
logit_penalized[logit_penalized < 0] = repetition_penalty
|
||||
logit_penalized[logit_penalized > 0] = 1 / repetition_penalty
|
||||
np.put(token_penalties[i], prev_input_id, logit_penalized)
|
||||
logit_penalties[logit_penalized < 0] = repetition_penalty
|
||||
logit_penalties[logit_penalized > 0] = 1 / repetition_penalty
|
||||
np.put(token_penalties[i], prev_input_id, logit_penalties)
|
||||
return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue