Fix `AutoModelTest.test_model_from_pretrained` (#20730)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-12-12 15:37:43 +01:00 committed by GitHub
parent a3345c1f13
commit 5ba2dbd9b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 2 deletions

View File

@ -22,7 +22,7 @@ from pathlib import Path
import pytest import pytest
from transformers import BertConfig, GPT2Model, is_torch_available from transformers import BertConfig, GPT2Model, is_safetensors_available, is_torch_available
from transformers.models.auto.configuration_auto import CONFIG_MAPPING from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.testing_utils import ( from transformers.testing_utils import (
DUMMY_UNKNOWN_IDENTIFIER, DUMMY_UNKNOWN_IDENTIFIER,
@ -102,7 +102,10 @@ class AutoModelTest(unittest.TestCase):
self.assertIsInstance(model, BertModel) self.assertIsInstance(model, BertModel)
self.assertEqual(len(loading_info["missing_keys"]), 0) self.assertEqual(len(loading_info["missing_keys"]), 0)
self.assertEqual(len(loading_info["unexpected_keys"]), 8) # When using PyTorch checkpoint, the expected value is `8`. With `safetensors` checkpoint (if it is
# installed), the expected value becomes `7`.
EXPECTED_NUM_OF_UNEXPECTED_KEYS = 7 if is_safetensors_available() else 8
self.assertEqual(len(loading_info["unexpected_keys"]), EXPECTED_NUM_OF_UNEXPECTED_KEYS)
self.assertEqual(len(loading_info["mismatched_keys"]), 0) self.assertEqual(len(loading_info["mismatched_keys"]), 0)
self.assertEqual(len(loading_info["error_msgs"]), 0) self.assertEqual(len(loading_info["error_msgs"]), 0)