small fix and updates to readme

thomwolf 2019-06-19 09:38:38 +02:00
parent f7e2ac01ea
commit 68ab9599ce
5 changed files with 53 additions and 18 deletions

View File

@@ -1322,12 +1322,14 @@ python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json /tmp/debug_squad/pre
{"f1": 88.52381567990474, "exact_match": 81.22043519394512}
```
Here is an example using distributed training on 8 V100 GPUs and Bert Whole Word Masking model to reach a F1 > 93 on SQuAD:
**distributed training**
Here is an example using distributed training on 8 V100 GPUs and the BERT Whole Word Masking uncased model to reach an F1 > 93 on SQuAD:
```bash
python -m torch.distributed.launch --nproc_per_node=8 \
run_squad.py \
--bert_model bert-large-cased-whole-word-masking \
--bert_model bert-large-uncased-whole-word-masking \
--do_train \
--do_predict \
--do_lower_case \
@@ -1337,17 +1339,31 @@ python -m torch.distributed.launch --nproc_per_node=8 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir ../models/train_squad_large_cased_wwm/ \
--output_dir ../models/wwm_uncased_finetuned_squad/ \
--train_batch_size 24 \
--gradient_accumulation_steps 12
```
Training with these hyper-parameters gave us the following results:
```bash
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ../models/train_squad_large_cased_wwm/predictions.json
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ../models/wwm_uncased_finetuned_squad/predictions.json
{"exact_match": 86.91579943235573, "f1": 93.1532499015869}
```
This is the model provided as `bert-large-uncased-whole-word-masking-finetuned-squad`.
And here is the training command for the model provided as `bert-large-cased-whole-word-masking-finetuned-squad`:
```bash
python -m torch.distributed.launch --nproc_per_node=8 \
run_squad.py \
--bert_model bert-large-cased-whole-word-masking \
--do_train \
--do_predict \
--do_lower_case \
--train_file $SQUAD_DIR/train-v1.1.json \
--predict_file $SQUAD_DIR/dev-v1.1.json \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir ../models/wwm_cased_finetuned_squad/ \
--train_batch_size 24 \
--gradient_accumulation_steps 12
```
Training with these hyper-parameters gave us the following results:
```bash
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ../models/wwm_cased_finetuned_squad/predictions.json
{"exact_match": 84.18164616840113, "f1": 91.58645594850135}
```
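Both fine-tuned checkpoints are registered under these names in `PRETRAINED_MODEL_ARCHIVE_MAP` (see the `modeling.py` hunk below), so they can be loaded directly with `from_pretrained`. A minimal inference sketch, not part of this commit, assuming the standard `bert-large-uncased` tokenizer and a made-up question/context pair:
```python
import torch
from pytorch_pretrained_bert import BertForQuestionAnswering, BertTokenizer

# Load the SQuAD fine-tuned whole-word-masking checkpoint by its registered name.
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
model.eval()

# Assumption: the fine-tuned checkpoint keeps the bert-large-uncased vocabulary,
# so the standard uncased tokenizer is reused here.
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)

question = "Who created BERT ?"
context = "BERT was created by researchers at Google ."
q_tokens = tokenizer.tokenize(question)
c_tokens = tokenizer.tokenize(context)
tokens = ['[CLS]'] + q_tokens + ['[SEP]'] + c_tokens + ['[SEP]']
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])
# Segment 0 covers [CLS] + question + [SEP], segment 1 covers context + [SEP].
segment_ids = torch.tensor([[0] * (len(q_tokens) + 2) + [1] * (len(c_tokens) + 1)])

with torch.no_grad():
    start_logits, end_logits = model(input_ids, token_type_ids=segment_ids)

# Most likely answer span (greedy, no length constraint).
start, end = start_logits.argmax().item(), end_logits.argmax().item()
print(' '.join(tokens[start:end + 1]))
```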
#### SWAG
The data for SWAG can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf).

View File

@@ -8,7 +8,7 @@ import torch
import torch.nn.functional as F
import numpy as np
from pytorch_pretrained_bert import BertModel, BertTokenizer
from pytorch_pretrained_bert import BertForSequenceClassification, BertTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
@@ -17,24 +17,33 @@ logger = logging.getLogger(__name__)
def run_model():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', type=str, default='bert-base-uncased',
help='pretrained model name or path to local checkpoint')
parser.add_argument('--model_name_or_path', type=str, default='bert-base-uncased', help='pretrained model name or path to local checkpoint')
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
parser.add_argument("--no_cuda", action='store_true', help="Whether not to use CUDA when available")
args = parser.parse_args()
print(args)
if args.batch_size == -1:
args.batch_size = 1
assert args.nsamples % args.batch_size == 0
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if args.local_rank == -1 or args.no_cuda:
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
torch.cuda.set_device(args.local_rank)
args.device = torch.device("cuda", args.local_rank)
n_gpu = 1
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logging.basicConfig(level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
args.device, n_gpu, bool(args.local_rank != -1), args.fp16))
tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
model = BertModel.from_pretrained(args.model_name_or_path)
model.to(device)
model = BertForSequenceClassification.from_pretrained(args.model_name_or_path)
model.to(args.device)
model.eval()

View File

@@ -187,7 +187,7 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
if not os.path.exists(args.output_dir):
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
task_name = args.task_name.lower()
@@ -361,6 +361,10 @@ def main():
# Load a trained model and vocabulary that you have fine-tuned
model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
# Good practice: save your training arguments together with the trained model
output_args_file = os.path.join(args.output_dir, 'training_args.bin')
torch.save(args, output_args_file)
else:
model = BertForSequenceClassification.from_pretrained(args.bert_model)
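The `training_args.bin` file written a few lines above is just a pickled `argparse.Namespace`, so a later script can reload it to inspect or reproduce the run. A small sketch, with a hypothetical output directory:
```python
import os
import torch

# Hypothetical path; use whatever --output_dir the training run was given.
output_dir = '/tmp/mrpc_output/'

# training_args.bin was written with torch.save(args, ...), so torch.load
# returns the original argparse.Namespace.
train_args = torch.load(os.path.join(output_dir, 'training_args.bin'))
print(train_args.task_name, train_args.learning_rate, train_args.num_train_epochs)
```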

View File

@@ -331,6 +331,10 @@ def main():
# Load a trained model and vocabulary that you have fine-tuned
model = BertForQuestionAnswering.from_pretrained(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
# Good practice: save your training arguments together with the trained model
output_args_file = os.path.join(args.output_dir, 'training_args.bin')
torch.save(args, output_args_file)
else:
model = BertForQuestionAnswering.from_pretrained(args.bert_model)

View File

@@ -46,8 +46,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
'bert-base-uncased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-finetuned-mrpc-pytorch_model.bin",
'bert-large-uncased-whole-word-masking-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-mrpc-pytorch_model.bin",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
@@ -60,6 +59,9 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
}
BERT_CONFIG_NAME = 'bert_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'