From b6d4284b26c0ab5e736cb7838b27b720225feeb7 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Fri, 13 Dec 2019 22:43:15 -0500 Subject: [PATCH] [cli] Uploads: fix + test edge case --- transformers/hf_api.py | 3 +- transformers/tests/fixtures/empty.txt | 0 transformers/tests/hf_api_test.py | 44 +++++++++++++++++++-------- 3 files changed, 33 insertions(+), 14 deletions(-) create mode 100644 transformers/tests/fixtures/empty.txt diff --git a/transformers/hf_api.py b/transformers/hf_api.py index 3bbb6c567a..170732339a 100644 --- a/transformers/hf_api.py +++ b/transformers/hf_api.py @@ -131,8 +131,9 @@ class HfApi: # the client still has to specify it when uploading the file. with open(filepath, "rb") as f: pf = TqdmProgressFileReader(f) + data = f if pf.total_size > 0 else "" - r = requests.put(urls.write, data=f, headers={ + r = requests.put(urls.write, data=data, headers={ "content-type": urls.type, }) r.raise_for_status() diff --git a/transformers/tests/fixtures/empty.txt b/transformers/tests/fixtures/empty.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/transformers/tests/hf_api_test.py b/transformers/tests/hf_api_test.py index 92d41b6dff..b45f5aceed 100644 --- a/transformers/tests/hf_api_test.py +++ b/transformers/tests/hf_api_test.py @@ -15,18 +15,30 @@ from __future__ import absolute_import, division, print_function import os -import six import time import unittest -from transformers.hf_api import HfApi, S3Obj, PresignedUrl, HfFolder, HTTPError +import requests +import six + +from transformers.hf_api import HfApi, HfFolder, HTTPError, PresignedUrl, S3Obj USER = "__DUMMY_TRANSFORMERS_USER__" PASS = "__DUMMY_TRANSFORMERS_PASS__" -FILE_KEY = "Test-{}.txt".format(int(time.time())) -FILE_PATH = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "fixtures/input.txt" -) +FILES = [ + ( + "Test-{}.txt".format(int(time.time())), + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "fixtures/input.txt" + ) + ), + ( + "yoyo {}.txt".format(int(time.time())), # space is intentional + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt" + ) + ), +] @@ -57,15 +69,21 @@ class HfApiEndpointsTest(HfApiCommonTest): self.assertEqual(user, USER) def test_presign(self): - urls = self._api.presign(token=self._token, filename=FILE_KEY) - self.assertIsInstance(urls, PresignedUrl) - self.assertEqual(urls.type, "text/plain") + for FILE_KEY, FILE_PATH in FILES: + urls = self._api.presign(token=self._token, filename=FILE_KEY) + self.assertIsInstance(urls, PresignedUrl) + self.assertEqual(urls.type, "text/plain") def test_presign_and_upload(self): - access_url = self._api.presign_and_upload( - token=self._token, filename=FILE_KEY, filepath=FILE_PATH - ) - self.assertIsInstance(access_url, six.string_types) + for FILE_KEY, FILE_PATH in FILES: + access_url = self._api.presign_and_upload( + token=self._token, filename=FILE_KEY, filepath=FILE_PATH + ) + self.assertIsInstance(access_url, six.string_types) + with open(FILE_PATH, 'r') as f: + body = f.read() + r = requests.get(access_url) + self.assertEqual(r.text, body) def test_list_objs(self): objs = self._api.list_objs(token=self._token)