Fix `PersimmonIntegrationTest` OOM (#26750)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-10-12 11:24:18 +02:00 committed by GitHub
parent ab0ddc99e8
commit 72256bc72a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 20 additions and 7 deletions

View File

@ -15,6 +15,7 @@
""" Testing suite for the PyTorch Persimmon model. """
import gc
import unittest
from parameterized import parameterized
@ -395,19 +396,27 @@ class PersimmonIntegrationTest(unittest.TestCase):
def test_model_8b_chat_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = PersimmonForCausalLM.from_pretrained(
"adept/persimmon-8b-chat", device_map="auto", torch_dtype=torch.float16
"adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
)
out = model(torch.tensor([input_ids])).logits
out = model(torch.tensor([input_ids], device=torch_device)).logits
EXPECTED_MEAN = torch.tensor(
[[-11.2879, -11.2628, -11.2498, -11.2534, -11.2676, -11.2638, -11.2501, -11.2431]], dtype=torch.float16
[[-11.4726, -11.1495, -11.2694, -11.2223, -10.9452, -11.0663, -11.0031, -11.1028]]
)
torch.testing.assert_close(out.cpu().mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4)
# change dtype to `torch.float32` before calling `mean` to avoid `nan` values
torch.testing.assert_close(out.cpu().to(torch.float32).mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4)
# fmt: off
EXPECTED_SLICE = torch.tensor([-16.9670, -16.9647, -16.9649, -16.9630, -16.9577, -16.9623, -17.0164, -16.9673, -16.9648, -16.9668, -17.0160, -16.9651, -17.0156, -16.9668, -16.9655, -16.9653, -16.9665, -16.9682, -17.0112, -16.9667, -16.9717, -16.9654, -16.9650, -16.9701, -16.9657, -17.0160, -16.9676, -17.0138, -16.9610, -16.9695])
EXPECTED_SLICE = torch.tensor(
[-16.9062, -16.9062, -16.9062, -16.9062, -16.8906, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062],
dtype=torch.float16
)
# fmt: on
torch.testing.assert_close(out.cpu()[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
torch.cuda.empty_cache()
del model
gc.collect()
@slow
@require_torch_gpu
def test_model_8b_chat_greedy_generation(self):
@ -415,11 +424,15 @@ class PersimmonIntegrationTest(unittest.TestCase):
prompt = "human: Simply put, the theory of relativity states that?\n\nadept:"
tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-chat", use_fast=False)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(torch_device)
model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-chat", torch_dtype=torch.float16).to(
torch_device
model = PersimmonForCausalLM.from_pretrained(
"adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
)
# greedy generation outputs
generated_ids = model.generate(input_ids, max_new_tokens=64)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
torch.cuda.empty_cache()
del model
gc.collect()