From 16ce15ed4bd0865d24a94aa839a44cf0f400ef50 Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Wed, 8 Jan 2020 13:18:30 +0100
Subject: [PATCH] DistilBERT token type ids removed from inputs in run_squad

---
 examples/run_squad.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/examples/run_squad.py b/examples/run_squad.py
index 6595f0464d..0a621f9ee0 100644
--- a/examples/run_squad.py
+++ b/examples/run_squad.py
@@ -207,11 +207,14 @@ def train(args, train_dataset, model, tokenizer):
             inputs = {
                 "input_ids": batch[0],
                 "attention_mask": batch[1],
-                "token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
+                "token_type_ids": batch[2],
                 "start_positions": batch[3],
                 "end_positions": batch[4],
             }
 
+            if args.model_type in ["xlm", "roberta", "distilbert"]:
+                del inputs["token_type_ids"]
+
             if args.model_type in ["xlnet", "xlm"]:
                 inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                 if args.version_2_with_negative:
@@ -316,8 +319,12 @@ def evaluate(args, model, tokenizer, prefix=""):
             inputs = {
                 "input_ids": batch[0],
                 "attention_mask": batch[1],
-                "token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
+                "token_type_ids": batch[2],
             }
+
+            if args.model_type in ["xlm", "roberta", "distilbert"]:
+                del inputs["token_type_ids"]
+
             example_indices = batch[3]
 
             # XLNet and XLM use more arguments for their predictions
@@ -427,10 +434,14 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
     )
 
     # Init features and dataset from cache if it exists
-    if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
+    if os.path.exists(cached_features_file) and not args.overwrite_cache:
         logger.info("Loading features from cached file %s", cached_features_file)
         features_and_dataset = torch.load(cached_features_file)
-        features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
+        features, dataset, examples = (
+            features_and_dataset["features"],
+            features_and_dataset["dataset"],
+            features_and_dataset["examples"],
+        )
     else:
         logger.info("Creating features from dataset file at %s", input_dir)
 
@@ -465,7 +476,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
 
         if args.local_rank in [-1, 0]:
             logger.info("Saving features into cached file %s", cached_features_file)
-            torch.save({"features": features, "dataset": dataset}, cached_features_file)
+            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
 
     if args.local_rank == 0 and not evaluate:
         # Make sure only the first process in distributed training process the dataset, and the others will use the cache
@@ -776,7 +787,7 @@ def main():
         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
 
         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir, force_download=True)
+        model = model_class.from_pretrained(args.output_dir)  # , force_download=True)
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         model.to(args.device)
 
@@ -801,7 +812,7 @@ def main():
         for checkpoint in checkpoints:
             # Reload the model
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-            model = model_class.from_pretrained(checkpoint, force_download=True)
+            model = model_class.from_pretrained(checkpoint)  # , force_download=True)
             model.to(args.device)
 
             # Evaluate
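
Note (follow-up sketch, not part of the patch): the earlier code passed
token_type_ids=None for the xlm, roberta and distilbert model types. That
works for XLM and RoBERTa, whose forward() accepts the argument, but
DistilBertModel.forward() has no token_type_ids parameter at all, so even a
None value raises a TypeError. The patch therefore builds the full inputs
dict and then deletes the key. A minimal standalone sketch of that pattern
(the build_inputs helper and the constant name are illustrative, not taken
from run_squad.py):

    # Hypothetical helper mirroring the pattern in the hunks above.
    # Assumes a SQuAD batch laid out as in run_squad.py:
    # (input_ids, attention_mask, token_type_ids, ...)
    MODELS_WITHOUT_SEGMENT_IDS = {"xlm", "roberta", "distilbert"}

    def build_inputs(batch, model_type):
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
        }
        # Drop the key entirely rather than passing None: DistilBERT's
        # forward() does not take a token_type_ids keyword argument.
        if model_type in MODELS_WITHOUT_SEGMENT_IDS:
            del inputs["token_type_ids"]
        return inputs

The caching hunks are related housekeeping: examples are now stored in the
cache file alongside features and dataset, so load_and_cache_examples can
serve output_examples=True from the cache instead of bypassing it, and the
commented-out force_download=True lets evaluation reuse checkpoints already
on disk.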