Fix `PersimmonIntegrationTest` OOM (#26750)
* fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
ab0ddc99e8
commit
72256bc72a
|
@ -15,6 +15,7 @@
|
|||
""" Testing suite for the PyTorch Persimmon model. """
|
||||
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
@ -395,19 +396,27 @@ class PersimmonIntegrationTest(unittest.TestCase):
|
|||
def test_model_8b_chat_logits(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
model = PersimmonForCausalLM.from_pretrained(
|
||||
"adept/persimmon-8b-chat", device_map="auto", torch_dtype=torch.float16
|
||||
"adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
|
||||
)
|
||||
out = model(torch.tensor([input_ids])).logits
|
||||
out = model(torch.tensor([input_ids], device=torch_device)).logits
|
||||
|
||||
EXPECTED_MEAN = torch.tensor(
|
||||
[[-11.2879, -11.2628, -11.2498, -11.2534, -11.2676, -11.2638, -11.2501, -11.2431]], dtype=torch.float16
|
||||
[[-11.4726, -11.1495, -11.2694, -11.2223, -10.9452, -11.0663, -11.0031, -11.1028]]
|
||||
)
|
||||
torch.testing.assert_close(out.cpu().mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4)
|
||||
# change dtype to `torch.float32` before calling `mean` to avoid `nan` values
|
||||
torch.testing.assert_close(out.cpu().to(torch.float32).mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4)
|
||||
# fmt: off
|
||||
EXPECTED_SLICE = torch.tensor([-16.9670, -16.9647, -16.9649, -16.9630, -16.9577, -16.9623, -17.0164, -16.9673, -16.9648, -16.9668, -17.0160, -16.9651, -17.0156, -16.9668, -16.9655, -16.9653, -16.9665, -16.9682, -17.0112, -16.9667, -16.9717, -16.9654, -16.9650, -16.9701, -16.9657, -17.0160, -16.9676, -17.0138, -16.9610, -16.9695])
|
||||
EXPECTED_SLICE = torch.tensor(
|
||||
[-16.9062, -16.9062, -16.9062, -16.9062, -16.8906, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062],
|
||||
dtype=torch.float16
|
||||
)
|
||||
# fmt: on
|
||||
torch.testing.assert_close(out.cpu()[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
del model
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_model_8b_chat_greedy_generation(self):
|
||||
|
@ -415,11 +424,15 @@ class PersimmonIntegrationTest(unittest.TestCase):
|
|||
prompt = "human: Simply put, the theory of relativity states that?\n\nadept:"
|
||||
tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-chat", use_fast=False)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(torch_device)
|
||||
model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-chat", torch_dtype=torch.float16).to(
|
||||
torch_device
|
||||
model = PersimmonForCausalLM.from_pretrained(
|
||||
"adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=64)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
del model
|
||||
gc.collect()
|
||||
|
|
Loading…
Reference in New Issue