Update examples/ner/run_ner.py to use AutoModel (#3305)

* Update examples/ner/run_ner.py to use AutoModel

* Fix missing code and apply `make style` command
J.P Lee 2020-03-18 01:30:10 +09:00 committed by GitHub
parent e41212c715
commit 2b60a26b46
1 changed file with 15 additions and 43 deletions


@@ -31,28 +31,15 @@ from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from transformers import (
+ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
WEIGHTS_NAME,
AdamW,
-AlbertConfig,
-AlbertForTokenClassification,
-AlbertTokenizer,
-BertConfig,
-BertForTokenClassification,
-BertTokenizer,
-CamembertConfig,
-CamembertForTokenClassification,
-CamembertTokenizer,
-DistilBertConfig,
-DistilBertForTokenClassification,
-DistilBertTokenizer,
-RobertaConfig,
-RobertaForTokenClassification,
-RobertaTokenizer,
-XLMRobertaConfig,
-XLMRobertaForTokenClassification,
-XLMRobertaTokenizer,
+AutoConfig,
+AutoModelForTokenClassification,
+AutoTokenizer,
get_linear_schedule_with_warmup,
)
+from transformers.modeling_auto import MODEL_MAPPING
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
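For context on what the new imports buy: the Auto* classes dispatch on the checkpoint's configuration, so one generic class replaces the whole per-architecture list removed above. A minimal sketch of that dispatch, with illustrative checkpoint names not taken from this diff:

from transformers import AutoConfig, AutoModelForTokenClassification

# AutoConfig reads the checkpoint's config.json and returns the matching concrete
# config class; AutoModelForTokenClassification then picks the model class from it.
config = AutoConfig.from_pretrained("bert-base-multilingual-cased")
print(type(config).__name__)  # BertConfig

model = AutoModelForTokenClassification.from_pretrained("roberta-base")
print(type(model).__name__)  # RobertaForTokenClassification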
@@ -64,22 +51,8 @@ except ImportError:
logger = logging.getLogger(__name__)
-ALL_MODELS = sum(
-(
-tuple(conf.pretrained_config_archive_map.keys())
-for conf in (BertConfig, RobertaConfig, DistilBertConfig, CamembertConfig, XLMRobertaConfig)
-),
-(),
-)
-MODEL_CLASSES = {
-"albert": (AlbertConfig, AlbertForTokenClassification, AlbertTokenizer),
-"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
-"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
-"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
-"camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
-"xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
-}
+ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
+MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
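A small sketch of what the two rewritten constants hold, assuming the transformers version this commit targets (ALL_PRETRAINED_MODEL_ARCHIVE_MAP and transformers.modeling_auto.MODEL_MAPPING were later removed from the library); the print lines are illustrative additions:

from transformers import ALL_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.modeling_auto import MODEL_MAPPING

# Every shortcut checkpoint name the library knows about ("bert-base-uncased", ...).
ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)

# MODEL_MAPPING maps config classes to model classes; each config class exposes a
# model_type string ("bert", "roberta", ...), giving a flat tuple of type names.
MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)

print(len(ALL_MODELS))
print(", ".join(MODEL_CLASSES))  # the same string shown in --model_type's help text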
@@ -411,7 +384,7 @@ def main():
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
)
parser.add_argument(
"--model_name_or_path",
@@ -594,8 +567,7 @@ def main():
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
args.model_type = args.model_type.lower()
-config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-config = config_class.from_pretrained(
+config = AutoConfig.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path,
num_labels=num_labels,
id2label={str(i): label for i, label in enumerate(labels)},
@@ -604,12 +576,12 @@ def main():
)
tokenizer_args = {k: v for k, v in vars(args).items() if v is not None and k in TOKENIZER_ARGS}
logger.info("Tokenizer arguments: %s", tokenizer_args)
-tokenizer = tokenizer_class.from_pretrained(
+tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
cache_dir=args.cache_dir if args.cache_dir else None,
**tokenizer_args,
)
-model = model_class.from_pretrained(
+model = AutoModelForTokenClassification.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
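Taken together, this hunk reduces model setup to three Auto* calls. A minimal standalone sketch under assumed values: the real script derives labels from get_labels and the checkpoint name and cache_dir from argparse, and the label2id line below simply follows the same pattern as the id2label line shown above.

from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer

labels = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]  # illustrative NER tag set
num_labels = len(labels)

config = AutoConfig.from_pretrained(
    "bert-base-cased",
    num_labels=num_labels,
    id2label={str(i): label for i, label in enumerate(labels)},
    label2id={label: i for i, label in enumerate(labels)},
)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForTokenClassification.from_pretrained(
    "bert-base-cased",
    from_tf=False,  # the script passes True when the path points at a TF ".ckpt"
    config=config,
)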
@@ -650,7 +622,7 @@ def main():
# Evaluation
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
-tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
+tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(
@@ -660,7 +632,7 @@ def main():
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-model = model_class.from_pretrained(checkpoint)
+model = AutoModelForTokenClassification.from_pretrained(checkpoint)
model.to(args.device)
result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev", prefix=global_step)
if global_step:
@@ -672,8 +644,8 @@ def main():
writer.write("{} = {}\n".format(key, str(results[key])))
if args.do_predict and args.local_rank in [-1, 0]:
-tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
-model = model_class.from_pretrained(args.output_dir)
+tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
+model = AutoModelForTokenClassification.from_pretrained(args.output_dir)
model.to(args.device)
result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
# Save results
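Because the prediction branch reloads everything from args.output_dir with the same two Auto* calls, the fine-tuned model is also easy to reuse outside the script. A rough end-to-end sketch with a made-up output directory and sentence, not part of the diff:

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

output_dir = "./ner-output"  # hypothetical --output_dir of a finished run
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model = AutoModelForTokenClassification.from_pretrained(output_dir)
model.eval()

# encode_plus keeps this runnable on the transformers 2.x line this commit targets
inputs = tokenizer.encode_plus("Alice lives in Paris", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs)[0]  # shape: (batch, seq_len, num_labels)
predicted_ids = logits.argmax(dim=-1)[0].tolist()

# id2label keys may round-trip as strings in older versions, so normalize them
id2label = {int(k): v for k, v in model.config.id2label.items()}
print([id2label[i] for i in predicted_ids])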