Fix `SeamlessM4Tv2ModelIntegrationTest` (#27911)
change dtype of some integration tests
This commit is contained in:
parent
e96c1de191
commit
5e620a92cf
|
@ -1014,8 +1014,9 @@ class SeamlessM4Tv2ModelIntegrationTest(unittest.TestCase):
|
|||
)
|
||||
|
||||
def factory_test_task(self, class1, class2, inputs, class1_kwargs, class2_kwargs):
|
||||
model1 = class1.from_pretrained(self.repo_id).to(torch_device)
|
||||
model2 = class2.from_pretrained(self.repo_id).to(torch_device)
|
||||
# half-precision loading to limit GPU usage
|
||||
model1 = class1.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device)
|
||||
model2 = class2.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device)
|
||||
|
||||
set_seed(0)
|
||||
output_1 = model1.generate(**inputs, **class1_kwargs)
|
||||
|
|
Loading…
Reference in New Issue