clean up classification model output

This commit is contained in:
thomwolf 2018-11-30 22:54:53 +01:00
parent 7f7c41b0c1
commit 89d47230d7
1 changed files with 1 additions and 1 deletions

View File

@ -546,7 +546,7 @@ def main():
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
loss = model(input_ids, segment_ids, input_mask, label_ids)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
if args.fp16 and args.loss_scale != 1.0: