Fix `AutoModelTest.test_model_from_pretrained` (#20730)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
a3345c1f13
commit
5ba2dbd9b1
|
@ -22,7 +22,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers import BertConfig, GPT2Model, is_torch_available
|
from transformers import BertConfig, GPT2Model, is_safetensors_available, is_torch_available
|
||||||
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
DUMMY_UNKNOWN_IDENTIFIER,
|
DUMMY_UNKNOWN_IDENTIFIER,
|
||||||
|
@ -102,7 +102,10 @@ class AutoModelTest(unittest.TestCase):
|
||||||
self.assertIsInstance(model, BertModel)
|
self.assertIsInstance(model, BertModel)
|
||||||
|
|
||||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||||
self.assertEqual(len(loading_info["unexpected_keys"]), 8)
|
# When using PyTorch checkpoint, the expected value is `8`. With `safetensors` checkpoint (if it is
|
||||||
|
# installed), the expected value becomes `7`.
|
||||||
|
EXPECTED_NUM_OF_UNEXPECTED_KEYS = 7 if is_safetensors_available() else 8
|
||||||
|
self.assertEqual(len(loading_info["unexpected_keys"]), EXPECTED_NUM_OF_UNEXPECTED_KEYS)
|
||||||
self.assertEqual(len(loading_info["mismatched_keys"]), 0)
|
self.assertEqual(len(loading_info["mismatched_keys"]), 0)
|
||||||
self.assertEqual(len(loading_info["error_msgs"]), 0)
|
self.assertEqual(len(loading_info["error_msgs"]), 0)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue