Fix bug: train_batch_size is not an int.
Division makes args.train_batch_size become a float. cc @thomwolf
This commit is contained in:
parent
d55c3ae83f
commit
649e9774cd
|
@ -426,7 +426,7 @@ def main():
|
|||
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
|
||||
args.accumulate_gradients))
|
||||
|
||||
args.train_batch_size = args.train_batch_size / args.accumulate_gradients
|
||||
args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients)
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
|
|
@ -756,7 +756,7 @@ def main():
|
|||
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
|
||||
args.accumulate_gradients))
|
||||
|
||||
args.train_batch_size = args.train_batch_size / args.accumulate_gradients
|
||||
args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients)
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
|
Loading…
Reference in New Issue