FIX small bugs in `run_classifier_pytorch.py`
This commit is contained in:
parent
cc228089ef
commit
936eb4c3ad
|
@ -410,8 +410,8 @@ def input_fn_builder(features, seq_length, train_batch_size):
|
|||
|
||||
input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.Long)
|
||||
input_mask_tensor = torch.tensor(all_input_mask, dtype=torch.Long)
|
||||
segment_tensor = torch.tensor(all_segment, dtype=torch.Long)
|
||||
label_tensor = torch.tensor(all_label, dtype=torch.Long)
|
||||
segment_tensor = torch.tensor(all_segment_ids, dtype=torch.Long)
|
||||
label_tensor = torch.tensor(all_label_ids, dtype=torch.Long)
|
||||
|
||||
train_data = TensorDataset(input_ids_tensor, input_mask_tensor,
|
||||
segment_tensor, label_tensor)
|
||||
|
@ -512,7 +512,7 @@ def main():
|
|||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
model.train()
|
||||
for epoch in args.num_train_epochs:
|
||||
for epoch in range(args.num_train_epochs):
|
||||
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
|
|
Loading…
Reference in New Issue