From 2c5d993ba48841575d9c58f0754bca00b288431c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 8 Nov 2018 21:22:22 +0100 Subject: [PATCH] update readme - fix SQuAD model on multi-GPU --- README.md | 5 +++++ modeling.py | 8 +++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 41145da8eb..78e88ab5ea 100644 --- a/README.md +++ b/README.md @@ -194,3 +194,8 @@ python run_squad.py \ --doc_stride 128 \ --output_dir ../debug_squad/ ``` + +Training with the previous hyper-parameters and a batch size of 32 (on 4 GPUs) for 2 epochs gave us the following results: +```bash +{"f1": 88.19829549714827, "exact_match": 80.75685903500474} +``` diff --git a/modeling.py b/modeling.py index 433ee2054c..43db3b30fb 100644 --- a/modeling.py +++ b/modeling.py @@ -455,9 +455,11 @@ class BertForQuestionAnswering(nn.Module): end_logits = end_logits.squeeze(-1) if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if not this is a no-op - start_positions = start_positions.squeeze(-1) - end_positions = end_positions.squeeze(-1) + # On multi-GPU, the start/end positions may carry an extra dimension; squeeze it away + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = start_logits.size(1) start_positions.clamp_(0, ignored_index)