Improve mismatched sizes management when loading a pretrained model (#17257)

- Add --ignore_mismatched_sizes argument to classification examples

- Expand the error message raised when loading a model whose head dimensions differ from the expected ones
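For illustration, here is a minimal sketch of the resulting behavior (the checkpoint name and label count below are only examples, not part of this commit): loading a checkpoint whose classification head was trained for a different number of labels now fails with the expanded hint unless the new option is passed.

```python
from transformers import AutoModelForSequenceClassification

# Illustrative only: this public checkpoint ships a 2-label classification head,
# so asking for 3 labels makes the head weights mismatch the checkpoint.
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"

# Without the option, loading raises a RuntimeError that now suggests the fix;
# with it, the mismatched head is dropped and re-initialized for 3 labels.
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=3,
    ignore_mismatched_sizes=True,
)
```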
This commit is contained in:
regisss 2022-05-17 17:58:14 +02:00 committed by GitHub
parent 1f13ba818e
commit 28a0811652
13 changed files with 64 additions and 9 deletions

View File

@@ -18,13 +18,13 @@ limitations under the License.
The following examples showcase how to fine-tune `Wav2Vec2` for audio classification using PyTorch.
Speech recognition models that have been pretrained in an unsupervised fashion on audio data alone,
*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html),
[HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html),
[XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html), have been shown to require only
very little annotated data to yield good performance on speech classification datasets.
## Single-GPU
The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) on the 🗣️ [Keyword Spotting subset](https://huggingface.co/datasets/superb#ks) of the SUPERB dataset.
@@ -63,7 +63,9 @@ On a single V100 GPU (16GB), this script should run in ~14 minutes and yield acc
👀 See the results here: [anton-l/wav2vec2-base-ft-keyword-spotting](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting)
> If the dimensions of your model's classification head do not match the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
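The flag is simply forwarded to `from_pretrained`. A rough programmatic equivalent, reusing the fine-tuned checkpoint linked above on a hypothetical dataset with a different number of classes (the label count is made up):

```python
from transformers import AutoModelForAudioClassification

# The keyword-spotting checkpoint already carries a classification head; a new
# dataset with a different class count needs that head dropped and re-initialized.
model = AutoModelForAudioClassification.from_pretrained(
    "anton-l/wav2vec2-base-ft-keyword-spotting",
    num_labels=7,  # hypothetical number of classes in your dataset
    ignore_mismatched_sizes=True,
)
```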
## Multi-GPU
The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) for 🌎 **Language Identification** on the [CommonLanguage dataset](https://huggingface.co/datasets/anton-l/common_language).
@@ -139,7 +141,7 @@ It has been verified that the script works for the following datasets:
| Dataset | Pretrained Model | # transformer layers | Accuracy on eval | GPU setup | Training time | Fine-tuned Model & Logs |
|---------|------------------|----------------------|------------------|-----------|---------------|--------------------------|
| Keyword Spotting | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 2 | 0.9706 | 1 V100 GPU | 11min | [here](https://huggingface.co/anton-l/distilhubert-ft-keyword-spotting) |
| Keyword Spotting | [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) | 12 | 0.9826 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting) |
| Keyword Spotting | [facebook/hubert-base-ls960](https://huggingface.co/facebook/hubert-base-ls960) | 12 | 0.9819 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/hubert-base-ft-keyword-spotting) |
| Keyword Spotting | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 24 | 0.9757 | 1 V100 GPU | 15min | [here](https://huggingface.co/anton-l/sew-mid-100k-ft-keyword-spotting) |

View File

@@ -163,6 +163,10 @@ class ModelArguments:
freeze_feature_extractor: Optional[bool] = field(
default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
)
ignore_mismatched_sizes: bool = field(
default=False,
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
)
def __post_init__(self):
if not self.freeze_feature_extractor and self.freeze_feature_encoder:
@@ -333,6 +337,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# freeze the convolutional waveform encoder

View File

@@ -62,9 +62,11 @@ python run_image_classification.py \
Note that you can replace the model and dataset by setting the `model_name_or_path` and `dataset_name` arguments, respectively, to any model or dataset from the [hub](https://huggingface.co/). For an overview of all possible arguments, we refer to the [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) of the `TrainingArguments`, which can be passed as flags.
> If the dimensions of your model's classification head do not match the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
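As a sketch of what the flag does (the checkpoint and label count are only illustrative): a checkpoint that already has a classification head, such as an ImageNet-pretrained ViT, can be adapted to a dataset with a different number of classes.

```python
from transformers import AutoModelForImageClassification

# Illustrative: this ViT checkpoint has a 1000-class ImageNet head, so targeting
# a 5-class dataset requires re-initializing the head instead of loading it.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=5,
    ignore_mismatched_sizes=True,
)
```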
### Using your own data
To use your own dataset, there are 2 ways:
- you can either provide your own folders as `--train_dir` and/or `--validation_dir` arguments
- you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
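For reference, a rough sketch of what both options look like with the 🤗 `datasets` library (the paths and dataset name are placeholders):

```python
from datasets import load_dataset

# Option 1: local folders, as with `--train_dir`; labels are inferred from subfolder names.
dataset = load_dataset("imagefolder", data_dir="path/to/your/images")

# Option 2: a dataset hosted on the hub, as with `--dataset_name`.
dataset = load_dataset("cifar10")
```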

View File

@@ -150,6 +150,10 @@ class ModelArguments:
)
},
)
ignore_mismatched_sizes: bool = field(
default=False,
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
)
def collate_fn(examples):
@@ -269,6 +273,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path,

View File

@@ -165,6 +165,11 @@ def parse_args():
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
parser.add_argument(
"--ignore_mismatched_sizes",
action="store_true",
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
args = parser.parse_args()
# Sanity checks
@@ -278,6 +283,7 @@ def main():
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
ignore_mismatched_sizes=args.ignore_mismatched_sizes,
)
# Preprocessing the datasets

View File

@@ -22,7 +22,7 @@ Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models)
and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a CSV or a JSON file
(the script might need some tweaks in that case, refer to the comments inside for help).
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
@@ -79,6 +79,8 @@ python run_glue.py \
--output_dir /tmp/imdb/
```
> If the dimensions of your model's classification head do not match the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
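A typical case, sketched below with illustrative names: starting from a checkpoint fine-tuned on a 2-label task (e.g. SST-2) and fine-tuning it on a task with a different label count (e.g. MNLI, 3 labels).

```python
from transformers import AutoModelForSequenceClassification

# Illustrative: the SST-2 checkpoint has a 2-label head while MNLI needs 3 labels,
# so the head is re-initialized instead of loaded from the checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english",
    num_labels=3,
    ignore_mismatched_sizes=True,
)
```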
### Mixed precision training

View File

@@ -196,6 +196,10 @@ class ModelArguments:
)
},
)
ignore_mismatched_sizes: bool = field(
default=False,
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
)
def main():
@@ -364,6 +368,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# Preprocessing the raw_datasets

View File

@@ -170,6 +170,11 @@ def parse_args():
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
parser.add_argument(
"--ignore_mismatched_sizes",
action="store_true",
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
args = parser.parse_args()
# Sanity checks
@@ -288,6 +293,7 @@ def main():
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
ignore_mismatched_sizes=args.ignore_mismatched_sizes,
)
# Preprocessing the datasets

View File

@@ -162,6 +162,10 @@ class ModelArguments:
)
},
)
ignore_mismatched_sizes: bool = field(
default=False,
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
)
def main():
@@ -291,6 +295,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# Preprocessing the datasets

View File

@@ -55,6 +55,8 @@ uses special features of those tokenizers. You can check if your favorite model has a fast tokenizer in
[this table](https://huggingface.co/transformers/index.html#supported-frameworks); if it doesn't, you can still use the old version
of the script.
> If the dimensions of your model's classification head do not match the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
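For instance (the checkpoint and label count are only illustrative), a publicly available NER checkpoint whose head was trained on the CoNLL-2003 tag set can be reused on a dataset with a different tag set:

```python
from transformers import AutoModelForTokenClassification

# Illustrative: the checkpoint's head is sized for the CoNLL-2003 tags; a dataset
# with a different number of tags needs the head re-initialized.
model = AutoModelForTokenClassification.from_pretrained(
    "dslim/bert-base-NER",
    num_labels=5,  # hypothetical size of your tag set
    ignore_mismatched_sizes=True,
)
```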
## Old version of the script
You can find the old version of the PyTorch script [here](https://github.com/huggingface/transformers/blob/main/examples/legacy/token-classification/run_ner.py).

View File

@@ -87,6 +87,10 @@ class ModelArguments:
)
},
)
ignore_mismatched_sizes: bool = field(
default=False,
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
)
@dataclass
@@ -364,6 +368,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
# Tokenizer check: this script requires a fast tokenizer.

View File

@@ -223,6 +223,11 @@ def parse_args():
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
parser.add_argument(
"--ignore_mismatched_sizes",
action="store_true",
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)
args = parser.parse_args()
# Sanity checks
@@ -383,6 +388,7 @@ def main():
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
ignore_mismatched_sizes=args.ignore_mismatched_sizes,
)
else:
logger.info("Training new model from scratch")

View File

@@ -2253,6 +2253,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
error_msg += (
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if len(unexpected_keys) > 0:
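A hypothetical reproduction of the code path above (the checkpoint and label count are only examples): loading a checkpoint whose head does not match the requested `num_labels` raises a `RuntimeError`, and its message now ends with the `ignore_mismatched_sizes=True` suggestion.

```python
from transformers import AutoModelForSequenceClassification

# Sketch: the checkpoint's 2-label head does not match num_labels=5, so loading
# fails, and the error message now points at the `ignore_mismatched_sizes` escape hatch.
try:
    AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english", num_labels=5
    )
except RuntimeError as err:
    assert "size mismatch" in str(err)
    assert "ignore_mismatched_sizes=True" in str(err)
```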