delete print and make style
This commit is contained in:
parent
ca1330f0b2
commit
9b8ee8cea0
|
@ -926,7 +926,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
if temperature != 1.0:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
|
||||
# calculate log softmax score
|
||||
# calculate log softmax score
|
||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
|
@ -937,9 +937,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
)
|
||||
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
|
||||
|
||||
scores = set_tensor_by_indices_to_value(
|
||||
scores, eos_token_indices_mask, -float("inf")
|
||||
)
|
||||
scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
|
||||
|
||||
if no_repeat_ngram_size > 0:
|
||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
|
@ -992,7 +990,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
) # (batch_size, num_beams * vocab_size)
|
||||
|
||||
next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
|
||||
print(next_tokens)
|
||||
|
||||
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
|
||||
|
||||
|
|
Loading…
Reference in New Issue