fix tests
This commit is contained in:
parent
6a083fd447
commit
26497d1199
|
@ -262,7 +262,7 @@ class TFCommonTestCases:
|
|||
# self.assertEqual(len(params_tied_2), len(params_tied))
|
||||
|
||||
|
||||
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32):
|
||||
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
if rng is None:
|
||||
rng = random.Random()
|
||||
|
@ -275,7 +275,11 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32):
|
|||
for _ in range(total_dims):
|
||||
values.append(rng.randint(0, vocab_size - 1))
|
||||
|
||||
return tf.constant(values, shape=shape, dtype=dtype)
|
||||
output = tf.constant(values,
|
||||
shape=shape,
|
||||
dtype=dtype if dtype is not None else tf.int32)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TFModelUtilsTest(unittest.TestCase):
|
||||
|
|
Loading…
Reference in New Issue