Fix bug: train_batch_size not an int.

Division makes args.train_batch_size become a float.
cc @thomwolf
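
For context, a minimal standalone sketch of the failure mode and the fix (variable names mirror the diff; the values are hypothetical):

# Python 3: the / operator performs true division and always returns a float,
# even when both operands are ints and the result is exact.
train_batch_size = 32        # hypothetical value
accumulate_gradients = 4     # hypothetical value

before = train_batch_size / accumulate_gradients        # 8.0 -> float
after = int(train_batch_size / accumulate_gradients)    # 8   -> int

assert isinstance(before, float)
assert isinstance(after, int)

# Floor division (//) would also keep the result an int here,
# since both operands are ints:
assert train_batch_size // accumulate_gradients == after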
VictorSanh 2018-11-04 17:19:40 -05:00
parent d55c3ae83f
commit 649e9774cd
2 changed files with 2 additions and 2 deletions


@@ -426,7 +426,7 @@ def main():
         raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
             args.accumulate_gradients))
-    args.train_batch_size = args.train_batch_size / args.accumulate_gradients
+    args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients)
     random.seed(args.seed)
     np.random.seed(args.seed)


@@ -756,7 +756,7 @@ def main():
         raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
             args.accumulate_gradients))
-    args.train_batch_size = args.train_batch_size / args.accumulate_gradients
+    args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients)
     random.seed(args.seed)
     np.random.seed(args.seed)