diff --git a/examples/pytorch/text-classification/run_glue.py b/examples/pytorch/text-classification/run_glue.py
index 257a593834..7ecee51c03 100755
--- a/examples/pytorch/text-classification/run_glue.py
+++ b/examples/pytorch/text-classification/run_glue.py
@@ -380,6 +380,9 @@ def main():
     if label_to_id is not None:
         model.config.label2id = label_to_id
         model.config.id2label = {id: label for label, id in config.label2id.items()}
+    elif data_args.task_name is not None and not is_regression:
+        model.config.label2id = {l: i for i, l in enumerate(label_list)}
+        model.config.id2label = {id: label for label, id in config.label2id.items()}
 
     if data_args.max_seq_length > tokenizer.model_max_length:
         logger.warning(
diff --git a/examples/pytorch/text-classification/run_glue_no_trainer.py b/examples/pytorch/text-classification/run_glue_no_trainer.py
index 75f4ad5bbc..4af109a4bb 100644
--- a/examples/pytorch/text-classification/run_glue_no_trainer.py
+++ b/examples/pytorch/text-classification/run_glue_no_trainer.py
@@ -288,6 +288,9 @@ def main():
     if label_to_id is not None:
         model.config.label2id = label_to_id
         model.config.id2label = {id: label for label, id in config.label2id.items()}
+    elif args.task_name is not None and not is_regression:
+        model.config.label2id = {l: i for i, l in enumerate(label_list)}
+        model.config.id2label = {id: label for label, id in config.label2id.items()}
 
     padding = "max_length" if args.pad_to_max_length else False
 
diff --git a/examples/tensorflow/text-classification/run_glue.py b/examples/tensorflow/text-classification/run_glue.py
index ff482e2cf0..0bc2f8d7e1 100644
--- a/examples/tensorflow/text-classification/run_glue.py
+++ b/examples/tensorflow/text-classification/run_glue.py
@@ -355,6 +355,9 @@ def main():
     if label_to_id is not None:
         config.label2id = label_to_id
         config.id2label = {id: label for label, id in config.label2id.items()}
+    elif data_args.task_name is not None and not is_regression:
+        config.label2id = {l: i for i, l in enumerate(label_list)}
+        config.id2label = {id: label for label, id in config.label2id.items()}
 
     if data_args.max_seq_length > tokenizer.model_max_length:
         logger.warning(
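
Note (not part of the diff): all three hunks make the same change. When the dataset comes from a named GLUE task and the task is classification rather than regression (STS-B is the one regression task), the config previously fell through with the generic LABEL_0/LABEL_1 names; the new elif branch derives label2id from the task's label_list and inverts it for id2label, so saved checkpoints report human-readable labels. A minimal sketch of that mapping, using SST-2's label names as an assumed stand-in for label_list:

# Stand-in for the label_list a GLUE classification task provides
# (SST-2 shown here; order assumed to match the dataset's label ids).
label_list = ["negative", "positive"]

# Same dict comprehensions as the added lines, applied outside the config:
# enumerate() assigns each label its index, then .items() is flipped to
# invert the mapping (the source's use of "id" shadows the builtin).
label2id = {l: i for i, l in enumerate(label_list)}
id2label = {id: label for label, id in label2id.items()}

print(label2id)  # {'negative': 0, 'positive': 1}
print(id2label)  # {0: 'negative', 1: 'positive'}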