Fix `SeamlessM4Tv2ModelIntegrationTest` (#27911)
change dtype of some integration tests
This commit is contained in:
parent
e96c1de191
commit
5e620a92cf
|
@ -1014,8 +1014,9 @@ class SeamlessM4Tv2ModelIntegrationTest(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def factory_test_task(self, class1, class2, inputs, class1_kwargs, class2_kwargs):
|
def factory_test_task(self, class1, class2, inputs, class1_kwargs, class2_kwargs):
|
||||||
model1 = class1.from_pretrained(self.repo_id).to(torch_device)
|
# half-precision loading to limit GPU usage
|
||||||
model2 = class2.from_pretrained(self.repo_id).to(torch_device)
|
model1 = class1.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device)
|
||||||
|
model2 = class2.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device)
|
||||||
|
|
||||||
set_seed(0)
|
set_seed(0)
|
||||||
output_1 = model1.generate(**inputs, **class1_kwargs)
|
output_1 = model1.generate(**inputs, **class1_kwargs)
|
||||||
|
|
Loading…
Reference in New Issue