fix tests

This commit is contained in:
thomwolf 2019-09-18 12:17:21 +02:00
parent 6a083fd447
commit 26497d1199
1 changed files with 6 additions and 2 deletions

View File

@ -262,7 +262,7 @@ class TFCommonTestCases:
# self.assertEqual(len(params_tied_2), len(params_tied))
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32):
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
@ -275,7 +275,11 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32):
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return tf.constant(values, shape=shape, dtype=dtype)
output = tf.constant(values,
shape=shape,
dtype=dtype if dtype is not None else tf.int32)
return output
class TFModelUtilsTest(unittest.TestCase):