Add test when downloading from gated repo (#25039)

This commit is contained in:
Lucain 2023-07-28 14:14:27 +02:00 committed by GitHub
parent 6232c380f2
commit c1dba1111b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 0 deletions

View File

@ -36,6 +36,9 @@ RANDOM_BERT = "hf-internal-testing/tiny-random-bert"
CACHE_DIR = os.path.join(TRANSFORMERS_CACHE, "models--hf-internal-testing--tiny-random-bert")
FULL_COMMIT_HASH = "9b8c223d42b2188cb49d29af482996f9d0f3e5a6"
GATED_REPO = "hf-internal-testing/dummy-gated-model"
README_FILE = "README.md"
class GetFromCacheTests(unittest.TestCase):
def test_cached_file(self):
@ -124,3 +127,13 @@ class GetFromCacheTests(unittest.TestCase):
self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename))
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
def test_get_file_gated_repo(self):
"""Test download file from a gated repo fails with correct message when not authenticated."""
with self.assertRaisesRegex(EnvironmentError, "You are trying to access a gated repo."):
cached_file(GATED_REPO, README_FILE, use_auth_token=False)
def test_has_file_gated_repo(self):
"""Test check file existence from a gated repo fails with correct message when not authenticated."""
with self.assertRaisesRegex(EnvironmentError, "is a gated repository"):
has_file(GATED_REPO, README_FILE, use_auth_token=False)