From 3a9476d1b412274bcc51143acaaee187e9d18120 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 2 Dec 2022 09:05:45 +0100 Subject: [PATCH] fix cuda OOM by using single Prior (#20486) * fix cuda OOM by using single Prior * only send to device when used * use custom model --- tests/models/jukebox/test_modeling_jukebox.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/models/jukebox/test_modeling_jukebox.py b/tests/models/jukebox/test_modeling_jukebox.py index 9232119432..af5f946b9f 100644 --- a/tests/models/jukebox/test_modeling_jukebox.py +++ b/tests/models/jukebox/test_modeling_jukebox.py @@ -22,7 +22,7 @@ from transformers.trainer_utils import set_seed if is_torch_available(): import torch - from transformers import JukeboxModel, JukeboxTokenizer + from transformers import JukeboxModel, JukeboxPrior, JukeboxTokenizer @require_torch @@ -312,7 +312,7 @@ class Jukebox5bModelTester(unittest.TestCase): @slow def test_slow_sampling(self): - model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval().to("cuda") + model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval() labels = [i.cuda() for i in self.prepare_inputs(self.model_id)] set_seed(0) @@ -335,10 +335,11 @@ class Jukebox5bModelTester(unittest.TestCase): @slow def test_fp16_slow_sampling(self): - model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval().half().to("cuda") - labels = [i.cuda() for i in self.prepare_inputs(self.model_id)] + prior_id = "ArthurZ/jukebox_prior_0" + model = JukeboxPrior.from_pretrained(prior_id, min_duration=0).eval().half().to("cuda") + labels = self.prepare_inputs(prior_id)[0].cuda() + metadata = model.get_metadata(labels, 0, 7680, 0) set_seed(0) - zs = [torch.zeros(1, 0, dtype=torch.long).cuda() for _ in range(3)] - zs = model._sample(zs, labels, [0], sample_length=60 * model.priors[0].raw_to_tokens, save_results=False) - torch.testing.assert_allclose(zs[0][0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_2)) + outputs = model.sample(1, metadata=metadata, sample_tokens=60) + torch.testing.assert_allclose(outputs[0].cpu(), torch.tensor(self.EXPECTED_GPU_OUTPUTS_2))