delete print and make style

This commit is contained in:
Patrick von Platen 2020-03-10 14:32:21 +01:00
parent ca1330f0b2
commit 9b8ee8cea0
1 changed files with 2 additions and 5 deletions

View File

@ -926,7 +926,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if temperature != 1.0:
next_token_logits = next_token_logits / temperature
# calculate log softmax score
# calculate log softmax score
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
# set eos token prob to zero if min_length is not reached
@ -937,9 +937,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
scores = set_tensor_by_indices_to_value(
scores, eos_token_indices_mask, -float("inf")
)
scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
@ -992,7 +990,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
) # (batch_size, num_beams * vocab_size)
next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
print(next_tokens)
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]