[test] add test for --config_overrides (#14466)
- add test for `--config_overrides`
- remove unneeded parts of the test
This commit is contained in:
parent
e0e2da1194
commit
11f65d4158
|
@ -324,6 +324,7 @@ def main():
|
|||
if model_args.config_overrides is not None:
|
||||
logger.info(f"Overriding config: {model_args.config_overrides}")
|
||||
config.update_from_string(model_args.config_overrides)
|
||||
logger.info(f"New config: {config}")
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
|
|
|
@ -326,6 +326,7 @@ def main():
|
|||
if model_args.config_overrides is not None:
|
||||
logger.info(f"Overriding config: {model_args.config_overrides}")
|
||||
config.update_from_string(model_args.config_overrides)
|
||||
logger.info(f"New config: {config}")
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
|
|
|
@ -318,6 +318,7 @@ def main():
|
|||
if model_args.config_overrides is not None:
|
||||
logger.info(f"Overriding config: {model_args.config_overrides}")
|
||||
config.update_from_string(model_args.config_overrides)
|
||||
logger.info(f"New config: {config}")
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
|
|
|
@ -25,7 +25,7 @@ import torch
|
|||
|
||||
from transformers import Wav2Vec2ForPreTraining
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
|
||||
from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device
|
||||
|
||||
|
||||
SRC_DIRS = [
|
||||
|
@ -157,6 +157,31 @@ class ExamplesTests(TestCasePlus):
|
|||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["perplexity"], 100)
|
||||
|
||||
def test_run_clm_config_overrides(self):
    """Verify that ``--config_overrides`` really updates the model config.

    The tokenizer code path dumps the default (un-updated) config to the
    logs, which is misleading; this test asserts that the *overridden*
    values show up in the captured log output instead.
    """
    tmp_dir = self.get_auto_remove_tmp_dir()
    cli_args = [
        "run_clm.py",
        "--model_type", "gpt2",
        "--tokenizer_name", "gpt2",
        "--train_file", "./tests/fixtures/sample_text.txt",
        "--output_dir", tmp_dir,
        "--config_overrides", "n_embd=10,n_head=2",
    ]

    # CPU-only environments must opt out of CUDA explicitly.
    if torch_device != "cuda":
        cli_args.append("--no_cuda")

    # Run the script's main() under a fake argv while capturing its logger.
    with patch.object(sys, "argv", cli_args), CaptureLogger(run_clm.logger) as cap:
        run_clm.main()

    # The updated (not default) config values must appear in the logs.
    self.assertIn('"n_embd": 10', cap.out)
    self.assertIn('"n_head": 2', cap.out)
||||
def test_run_mlm(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
|
Loading…
Reference in New Issue