DistilBERT token type ids removed from inputs in run_squad

Lysandre 2020-01-08 13:18:30 +01:00
parent f24232cd1b
commit 16ce15ed4b
1 changed file with 18 additions and 7 deletions


@@ -207,11 +207,14 @@ def train(args, train_dataset, model, tokenizer):
             inputs = {
                 "input_ids": batch[0],
                 "attention_mask": batch[1],
-                "token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
+                "token_type_ids": batch[2],
                 "start_positions": batch[3],
                 "end_positions": batch[4],
             }
 
+            if args.model_type in ["xlm", "roberta", "distilbert"]:
+                del inputs["token_type_ids"]
+
             if args.model_type in ["xlnet", "xlm"]:
                 inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                 if args.version_2_with_negative:
@@ -316,8 +319,12 @@ def evaluate(args, model, tokenizer, prefix=""):
             inputs = {
                 "input_ids": batch[0],
                 "attention_mask": batch[1],
-                "token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
+                "token_type_ids": batch[2],
             }
+
+            if args.model_type in ["xlm", "roberta", "distilbert"]:
+                del inputs["token_type_ids"]
+
             example_indices = batch[3]
 
             # XLNet and XLM use more arguments for their predictions
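
The two hunks above swap the `None` placeholder for outright key deletion. Presumably the motivation: DistilBERT's forward() takes no token_type_ids argument at all, so even token_type_ids=None is an unexpected keyword and raises a TypeError. A minimal self-contained sketch of the pattern (model name and texts are illustrative, assuming the transformers API of this era):

import torch
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

# Build the full input dict, then delete the key DistilBERT cannot accept;
# passing it as None would still be an unexpected keyword for forward().
inputs = tokenizer.encode_plus("Who wrote it?", "It was written by X.", return_tensors="pt")
inputs.pop("token_type_ids", None)  # drop the key if this tokenizer emitted it
with torch.no_grad():
    outputs = model(**inputs)  # raises TypeError if token_type_ids is left in
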
@@ -427,10 +434,14 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     )
 
     # Init features and dataset from cache if it exists
-    if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
+    if os.path.exists(cached_features_file) and not args.overwrite_cache:
         logger.info("Loading features from cached file %s", cached_features_file)
         features_and_dataset = torch.load(cached_features_file)
-        features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
+        features, dataset, examples = (
+            features_and_dataset["features"],
+            features_and_dataset["dataset"],
+            features_and_dataset["examples"],
+        )
     else:
         logger.info("Creating features from dataset file at %s", input_dir)
@@ -465,7 +476,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
         if args.local_rank in [-1, 0]:
             logger.info("Saving features into cached file %s", cached_features_file)
-            torch.save({"features": features, "dataset": dataset}, cached_features_file)
+            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
 
     if args.local_rank == 0 and not evaluate:
         # Make sure only the first process in distributed training process the dataset, and the others will use the cache
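
Together with the loading hunk above, this widens the cache payload: the SquadExample objects are now pickled next to the features and the TensorDataset, so evaluation can get at the original contexts and answers without re-parsing the raw JSON. A rough sketch of the round-trip (the helper names are mine, not the script's):

import torch

def save_squad_cache(path, features, dataset, examples):
    # torch.save pickles an arbitrary dict, so one file carries all three objects
    torch.save({"features": features, "dataset": dataset, "examples": examples}, path)

def load_squad_cache(path):
    cached = torch.load(path)
    return cached["features"], cached["dataset"], cached["examples"]

One side effect worth noting: a cache file written before this change has no "examples" key, so loading it would raise a KeyError; re-running with the script's existing --overwrite_cache flag rebuilds the cache in the new format.
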
@@ -776,7 +787,7 @@ def main():
         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
 
         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir, force_download=True)
+        model = model_class.from_pretrained(args.output_dir)  # , force_download=True)
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         model.to(args.device)
@@ -801,7 +812,7 @@ def main():
         for checkpoint in checkpoints:
             # Reload the model
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-            model = model_class.from_pretrained(checkpoint, force_download=True)
+            model = model_class.from_pretrained(checkpoint)  # , force_download=True)
             model.to(args.device)
 
             # Evaluate
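
The last two hunks only affect reloads from local directories: from_pretrained() appears to short-circuit on an existing local path without touching the download machinery, so force_download=True did nothing useful here, and commenting it out avoids a surprising re-download if a hub model name were ever passed instead. An illustrative reload (the checkpoint path is made up):

from transformers import DistilBertForQuestionAnswering

# Reload fine-tuned weights straight from disk; no network access involved.
model = DistilBertForQuestionAnswering.from_pretrained("./output/checkpoint-500")
model.eval()  # inference mode: disables dropout
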