Fix `SeamlessM4Tv2ModelIntegrationTest` (#27911)

change dtype of some integration tests
This commit is contained in:
Yoach Lacombe 2023-12-11 08:18:41 +00:00 committed by GitHub
parent e96c1de191
commit 5e620a92cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions

View File

@ -1014,8 +1014,9 @@ class SeamlessM4Tv2ModelIntegrationTest(unittest.TestCase):
) )
def factory_test_task(self, class1, class2, inputs, class1_kwargs, class2_kwargs): def factory_test_task(self, class1, class2, inputs, class1_kwargs, class2_kwargs):
model1 = class1.from_pretrained(self.repo_id).to(torch_device) # half-precision loading to limit GPU usage
model2 = class2.from_pretrained(self.repo_id).to(torch_device) model1 = class1.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device)
model2 = class2.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device)
set_seed(0) set_seed(0)
output_1 = model1.generate(**inputs, **class1_kwargs) output_1 = model1.generate(**inputs, **class1_kwargs)