More AutoConfig tests

Julien Chaumond 2020-01-11 03:43:57 +00:00
parent 6bb3edc300
commit cf8a70bf68
5 changed files with 35 additions and 4 deletions

View File

@@ -20,6 +20,8 @@ from transformers.configuration_auto import AutoConfig
from transformers.configuration_bert import BertConfig
from transformers.configuration_roberta import RobertaConfig
from .utils import DUMMY_UNKWOWN_IDENTIFIER

SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
@@ -29,10 +31,14 @@ class AutoConfigTest(unittest.TestCase):
        config = AutoConfig.from_pretrained("bert-base-uncased")
        self.assertIsInstance(config, BertConfig)

    def test_config_from_model_type(self):
    def test_config_model_type_from_local_file(self):
        config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
        self.assertIsInstance(config, RobertaConfig)

    def test_config_model_type_from_model_identifier(self):
        config = AutoConfig.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
        self.assertIsInstance(config, RobertaConfig)

    def test_config_for_model_str(self):
        config = AutoConfig.for_model("roberta")
        self.assertIsInstance(config, RobertaConfig)
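
These tests exercise AutoConfig's model_type detection: the config class is chosen from the model_type field of the checkpoint's config.json rather than guessed from the identifier string. A minimal sketch of the same behavior, not part of the diff, assuming a transformers install from this era and network access to the julien-c/dummy-unknown entry used above:

    from transformers import AutoConfig, RobertaConfig

    # "dummy-unknown" gives no hint of the architecture; AutoConfig reads the
    # model_type field ("roberta") from the hosted config.json instead.
    config = AutoConfig.from_pretrained("julien-c/dummy-unknown")
    assert isinstance(config, RobertaConfig)

    # for_model() maps a model_type string directly to its config class.
    config = AutoConfig.for_model("roberta")
    assert isinstance(config, RobertaConfig)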

View File

@@ -19,7 +19,7 @@ import unittest
from transformers import is_torch_available
from .utils import SMALL_MODEL_IDENTIFIER, require_torch, slow
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
if is_torch_available():
@@ -30,6 +30,7 @@ if is_torch_available():
        BertModel,
        AutoModelWithLMHead,
        BertForMaskedLM,
        RobertaForMaskedLM,
        AutoModelForSequenceClassification,
        BertForSequenceClassification,
        AutoModelForQuestionAnswering,
@@ -102,3 +103,10 @@ class AutoModelTest(unittest.TestCase):
        self.assertIsInstance(model, BertForMaskedLM)
        self.assertEqual(model.num_parameters(), 14830)
        self.assertEqual(model.num_parameters(only_trainable=True), 14830)

    def test_from_identifier_from_model_type(self):
        logging.basicConfig(level=logging.INFO)
        model = AutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
        self.assertIsInstance(model, RobertaForMaskedLM)
        self.assertEqual(model.num_parameters(), 14830)
        self.assertEqual(model.num_parameters(only_trainable=True), 14830)
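
The same config-driven lookup picks the model class. A small usage sketch, not part of the diff, assuming PyTorch, a transformers version from this commit's era, and the julien-c/dummy-unknown checkpoint being reachable:

    from transformers import AutoConfig, AutoModelWithLMHead, RobertaForMaskedLM

    # Inspect the field that drives class selection.
    config = AutoConfig.from_pretrained("julien-c/dummy-unknown")
    print(config.model_type)  # "roberta"

    # The identifier name is irrelevant; the RoBERTa LM-head class is returned.
    model = AutoModelWithLMHead.from_pretrained("julien-c/dummy-unknown")
    assert isinstance(model, RobertaForMaskedLM)
    print(model.num_parameters())  # 14830 for this dummy checkpoint, per the test above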

View File

@@ -19,7 +19,7 @@ import unittest
from transformers import is_tf_available
from .utils import SMALL_MODEL_IDENTIFIER, require_tf, slow
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow
if is_tf_available():
@@ -30,6 +30,7 @@ if is_tf_available():
        TFBertModel,
        TFAutoModelWithLMHead,
        TFBertForMaskedLM,
        TFRobertaForMaskedLM,
        TFAutoModelForSequenceClassification,
        TFBertForSequenceClassification,
        TFAutoModelForQuestionAnswering,
@@ -101,3 +102,10 @@ class TFAutoModelTest(unittest.TestCase):
        self.assertIsInstance(model, TFBertForMaskedLM)
        self.assertEqual(model.num_parameters(), 14830)
        self.assertEqual(model.num_parameters(only_trainable=True), 14830)

    def test_from_identifier_from_model_type(self):
        logging.basicConfig(level=logging.INFO)
        model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
        self.assertIsInstance(model, TFRobertaForMaskedLM)
        self.assertEqual(model.num_parameters(), 14830)
        self.assertEqual(model.num_parameters(only_trainable=True), 14830)

View File

@@ -23,9 +23,10 @@ from transformers import (
    AutoTokenizer,
    BertTokenizer,
    GPT2Tokenizer,
    RobertaTokenizer,
)

from .utils import SMALL_MODEL_IDENTIFIER, slow  # noqa: F401
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, slow  # noqa: F401
class AutoTokenizerTest(unittest.TestCase):
@@ -49,3 +50,9 @@ class AutoTokenizerTest(unittest.TestCase):
        tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
        self.assertIsInstance(tokenizer, BertTokenizer)
        self.assertEqual(len(tokenizer), 12)

    def test_tokenizer_from_model_type(self):
        logging.basicConfig(level=logging.INFO)
        tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
        self.assertIsInstance(tokenizer, RobertaTokenizer)
        self.assertEqual(len(tokenizer), 20)
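
Tokenizer resolution follows the same pattern. A minimal sketch, not part of the diff; the 20-token vocabulary size is taken from the test above:

    from transformers import AutoTokenizer, RobertaTokenizer

    # The tokenizer class is chosen from the checkpoint's config, not from its name.
    tokenizer = AutoTokenizer.from_pretrained("julien-c/dummy-unknown")
    assert isinstance(tokenizer, RobertaTokenizer)
    print(len(tokenizer))  # 20 -- the dummy checkpoint ships a tiny vocabulary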

View File

@@ -9,6 +9,8 @@ from transformers.file_utils import _tf_available, _torch_available
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
def parse_flag_from_env(key, default=False):