More AutoConfig tests
This commit is contained in:
parent
6bb3edc300
commit
cf8a70bf68
|
@ -20,6 +20,8 @@ from transformers.configuration_auto import AutoConfig
|
|||
from transformers.configuration_bert import BertConfig
|
||||
from transformers.configuration_roberta import RobertaConfig
|
||||
|
||||
from .utils import DUMMY_UNKWOWN_IDENTIFIER
|
||||
|
||||
|
||||
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
|
||||
|
||||
|
@ -29,10 +31,14 @@ class AutoConfigTest(unittest.TestCase):
|
|||
config = AutoConfig.from_pretrained("bert-base-uncased")
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
def test_config_from_model_type(self):
|
||||
def test_config_model_type_from_local_file(self):
|
||||
config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
|
||||
self.assertIsInstance(config, RobertaConfig)
|
||||
|
||||
def test_config_model_type_from_model_identifier(self):
|
||||
config = AutoConfig.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
||||
self.assertIsInstance(config, RobertaConfig)
|
||||
|
||||
def test_config_for_model_str(self):
|
||||
config = AutoConfig.for_model("roberta")
|
||||
self.assertIsInstance(config, RobertaConfig)
|
||||
|
|
|
@ -19,7 +19,7 @@ import unittest
|
|||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .utils import SMALL_MODEL_IDENTIFIER, require_torch, slow
|
||||
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
@ -30,6 +30,7 @@ if is_torch_available():
|
|||
BertModel,
|
||||
AutoModelWithLMHead,
|
||||
BertForMaskedLM,
|
||||
RobertaForMaskedLM,
|
||||
AutoModelForSequenceClassification,
|
||||
BertForSequenceClassification,
|
||||
AutoModelForQuestionAnswering,
|
||||
|
@ -102,3 +103,10 @@ class AutoModelTest(unittest.TestCase):
|
|||
self.assertIsInstance(model, BertForMaskedLM)
|
||||
self.assertEqual(model.num_parameters(), 14830)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||
|
||||
def test_from_identifier_from_model_type(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
model = AutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
||||
self.assertIsInstance(model, RobertaForMaskedLM)
|
||||
self.assertEqual(model.num_parameters(), 14830)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||
|
|
|
@ -19,7 +19,7 @@ import unittest
|
|||
|
||||
from transformers import is_tf_available
|
||||
|
||||
from .utils import SMALL_MODEL_IDENTIFIER, require_tf, slow
|
||||
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
|
@ -30,6 +30,7 @@ if is_tf_available():
|
|||
TFBertModel,
|
||||
TFAutoModelWithLMHead,
|
||||
TFBertForMaskedLM,
|
||||
TFRobertaForMaskedLM,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFBertForSequenceClassification,
|
||||
TFAutoModelForQuestionAnswering,
|
||||
|
@ -101,3 +102,10 @@ class TFAutoModelTest(unittest.TestCase):
|
|||
self.assertIsInstance(model, TFBertForMaskedLM)
|
||||
self.assertEqual(model.num_parameters(), 14830)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||
|
||||
def test_from_identifier_from_model_type(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
||||
self.assertIsInstance(model, TFRobertaForMaskedLM)
|
||||
self.assertEqual(model.num_parameters(), 14830)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||
|
|
|
@ -23,9 +23,10 @@ from transformers import (
|
|||
AutoTokenizer,
|
||||
BertTokenizer,
|
||||
GPT2Tokenizer,
|
||||
RobertaTokenizer,
|
||||
)
|
||||
|
||||
from .utils import SMALL_MODEL_IDENTIFIER, slow # noqa: F401
|
||||
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, slow # noqa: F401
|
||||
|
||||
|
||||
class AutoTokenizerTest(unittest.TestCase):
|
||||
|
@ -49,3 +50,9 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||
self.assertIsInstance(tokenizer, BertTokenizer)
|
||||
self.assertEqual(len(tokenizer), 12)
|
||||
|
||||
def test_tokenizer_from_model_type(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
||||
self.assertIsInstance(tokenizer, RobertaTokenizer)
|
||||
self.assertEqual(len(tokenizer), 20)
|
||||
|
|
|
@ -9,6 +9,8 @@ from transformers.file_utils import _tf_available, _torch_available
|
|||
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")
|
||||
|
||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
|
||||
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
|
||||
|
||||
|
||||
def parse_flag_from_env(key, default=False):
|
||||
|
|
Loading…
Reference in New Issue