DistilBERT token type ids removed from inputs in run_squad
This commit is contained in:
parent
f24232cd1b
commit
16ce15ed4b
|
@@ -207,11 +207,14 @@ def train(args, train_dataset, model, tokenizer):
|
|||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
|
||||
"token_type_ids": batch[2],
|
||||
"start_positions": batch[3],
|
||||
"end_positions": batch[4],
|
||||
}
|
||||
|
||||
if args.model_type in ["xlm", "roberta", "distilbert"]:
|
||||
del inputs["token_type_ids"]
|
||||
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||
if args.version_2_with_negative:
|
||||
|
@@ -316,8 +319,12 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||
inputs = {
|
||||
"input_ids": batch[0],
|
||||
"attention_mask": batch[1],
|
||||
"token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
|
||||
"token_type_ids": batch[2],
|
||||
}
|
||||
|
||||
if args.model_type in ["xlm", "roberta", "distilbert"]:
|
||||
del inputs["token_type_ids"]
|
||||
|
||||
example_indices = batch[3]
|
||||
|
||||
# XLNet and XLM use more arguments for their predictions
|
||||
|
@@ -427,10 +434,14 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||
)
|
||||
|
||||
# Init features and dataset from cache if it exists
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
features_and_dataset = torch.load(cached_features_file)
|
||||
features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
|
||||
features, dataset, examples = (
|
||||
features_and_dataset["features"],
|
||||
features_and_dataset["dataset"],
|
||||
features_and_dataset["examples"],
|
||||
)
|
||||
else:
|
||||
logger.info("Creating features from dataset file at %s", input_dir)
|
||||
|
||||
|
@@ -465,7 +476,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save({"features": features, "dataset": dataset}, cached_features_file)
|
||||
torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
|
||||
|
||||
if args.local_rank == 0 and not evaluate:
|
||||
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||
|
@@ -776,7 +787,7 @@ def main():
|
|||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir, force_download=True)
|
||||
model = model_class.from_pretrained(args.output_dir) # , force_download=True)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
model.to(args.device)
|
||||
|
||||
|
@@ -801,7 +812,7 @@ def main():
|
|||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint, force_download=True)
|
||||
model = model_class.from_pretrained(checkpoint) # , force_download=True)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
|
|
Loading…
Reference in New Issue