From 1d8511f8f2c926c72146e87029d5b0acb48bd06e Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 2 Nov 2018 01:12:52 -0400 Subject: [PATCH] FIX small bugs in `run_classifier_pytorch.py` --- run_classifier_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index 7e5e7757ec..64f8a74717 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -512,7 +512,7 @@ def main(): train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) model.train() - for epoch in range(args.num_train_epochs): + for epoch in range(int(args.num_train_epochs)): for input_ids, input_mask, segment_ids, label_ids in train_dataloader: input_ids = input_ids.to(device) input_mask = input_mask.float().to(device)