FIX small bugs in `run_classifier_pytorch.py`

This commit is contained in:
VictorSanh 2018-11-02 01:12:52 -04:00
parent 936eb4c3ad
commit 1d8511f8f2
1 changed files with 1 additions and 1 deletions

View File

@ -512,7 +512,7 @@ def main():
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train()
for epoch in range(args.num_train_epochs):
for epoch in range(int(args.num_train_epochs)):
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)