Merge pull request #113 from hzhwcmhf/master

fix compatibility with python 3.5.2
This commit is contained in:
Thomas Wolf 2018-12-13 12:15:15 +01:00 committed by GitHub
commit 32a227f507
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 3 deletions

View File

@ -45,13 +45,15 @@ def url_to_filename(url: str, etag: str = None) -> str:
return filename return filename
def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]: def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
""" """
Return the url and etag (which may be ``None``) stored for `filename`. Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename) cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path): if not os.path.exists(cache_path):
@ -69,7 +71,7 @@ def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
return url, etag return url, etag
def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str: def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
""" """
Given something that might be a URL (or might be a local path), Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and determine which. If it's a URL, download the file and cache it, and
@ -80,6 +82,8 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(url_or_filename, Path): if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename) url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename) parsed = urlparse(url_or_filename)
@ -158,13 +162,15 @@ def http_get(url: str, temp_file: IO) -> None:
progress.close() progress.close()
def get_from_cache(url: str, cache_dir: str = None) -> str: def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
""" """
Given a URL, look for the corresponding dataset in the local cache. Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file. If it's not there, download it. Then return the path to the cached file.
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)