Update examples/ner/run_ner.py to use AutoModel (#3305)
* Update examples/ner/run_ner.py to use AutoModel * Fix missing code and apply `make style` command
This commit is contained in:
parent
e41212c715
commit
2b60a26b46
|
@ -31,28 +31,15 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (
|
||||
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AlbertConfig,
|
||||
AlbertForTokenClassification,
|
||||
AlbertTokenizer,
|
||||
BertConfig,
|
||||
BertForTokenClassification,
|
||||
BertTokenizer,
|
||||
CamembertConfig,
|
||||
CamembertForTokenClassification,
|
||||
CamembertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForTokenClassification,
|
||||
DistilBertTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForTokenClassification,
|
||||
RobertaTokenizer,
|
||||
XLMRobertaConfig,
|
||||
XLMRobertaForTokenClassification,
|
||||
XLMRobertaTokenizer,
|
||||
AutoConfig,
|
||||
AutoModelForTokenClassification,
|
||||
AutoTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers.modeling_auto import MODEL_MAPPING
|
||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||
|
||||
|
||||
|
@ -64,22 +51,8 @@ except ImportError:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(
|
||||
tuple(conf.pretrained_config_archive_map.keys())
|
||||
for conf in (BertConfig, RobertaConfig, DistilBertConfig, CamembertConfig, XLMRobertaConfig)
|
||||
),
|
||||
(),
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"albert": (AlbertConfig, AlbertForTokenClassification, AlbertTokenizer),
|
||||
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
|
||||
"camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
|
||||
"xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
|
||||
}
|
||||
ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)
|
||||
|
||||
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
||||
|
||||
|
@ -411,7 +384,7 @@ def main():
|
|||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
|
@ -594,8 +567,7 @@ def main():
|
|||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
id2label={str(i): label for i, label in enumerate(labels)},
|
||||
|
@ -604,12 +576,12 @@ def main():
|
|||
)
|
||||
tokenizer_args = {k: v for k, v in vars(args).items() if v is not None and k in TOKENIZER_ARGS}
|
||||
logger.info("Tokenizer arguments: %s", tokenizer_args)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
**tokenizer_args,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
model = AutoModelForTokenClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
|
@ -650,7 +622,7 @@ def main():
|
|||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
|
@ -660,7 +632,7 @@ def main():
|
|||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model = AutoModelForTokenClassification.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev", prefix=global_step)
|
||||
if global_step:
|
||||
|
@ -672,8 +644,8 @@ def main():
|
|||
writer.write("{} = {}\n".format(key, str(results[key])))
|
||||
|
||||
if args.do_predict and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, **tokenizer_args)
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir, **tokenizer_args)
|
||||
model = AutoModelForTokenClassification.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
|
||||
# Save results
|
||||
|
|
Loading…
Reference in New Issue