TEST: Add llama logits tests (#30835)

* add llama logits test

* fix

* fix tests
"

"

* fix for a10

* format

* format

* fix

* [run-slow] remove fmt: skip

* Your commit message

* test commit

* Revert "test commit"

This reverts commit b66e01e55f.

* [run-slow]llama

* Update tests/models/llama/test_modeling_llama.py

* [run-slow]llama

* empty commit
Younes Belkada 2024-05-17 12:23:00 +02:00 committed by GitHub
parent 15c74a2829
commit 3d7d3a87a0
1 changed file with 66 additions and 148 deletions
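
The updated logits tests key their expected values on the GPU's CUDA compute capability major version (8 for A100/A10, 7 for T4), as the diff below shows. A minimal sketch of that selection pattern, with illustrative helper names that are not part of the commit:

import torch

def expected_for_device(expectations: dict[int, torch.Tensor]) -> torch.Tensor:
    # pick the expected tensor for the current GPU generation; default to 8 (A100/A10) off-GPU
    major = torch.cuda.get_device_capability()[0] if torch.cuda.is_available() else 8
    return expectations[major]

def logits_close(logits: torch.Tensor, expectations: dict[int, torch.Tensor],
                 atol: float = 1e-3, rtol: float = 1e-3) -> bool:
    # mirrors the torch.allclose comparisons used in the tests below
    expected = expected_for_device(expectations).to(logits.device)
    return torch.allclose(expected, logits, atol=atol, rtol=rtol)

The tests below inline this pattern directly via cls.cuda_compute_capability_major_version.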


@@ -28,7 +28,6 @@ from transformers.testing_utils import (
require_flash_attn,
require_read_token,
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_torch_sdpa,
slow,
@@ -45,7 +44,6 @@ if is_torch_available():
import torch
from transformers import (
CodeLlamaTokenizer,
LlamaForCausalLM,
LlamaForQuestionAnswering,
LlamaForSequenceClassification,
@@ -608,76 +606,79 @@ class LlamaIntegrationTest(unittest.TestCase):
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
@slow
@require_read_token
def test_model_7b_logits_bf16(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
)
with torch.no_grad():
out = model(torch.tensor([input_ids]).to(torch_device))
# Expected mean on dim = -1
# fmt: off
EXPECTED_MEAN = {
7: torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]),
8: torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]])
}
self.assertTrue(torch.allclose(EXPECTED_MEAN[self.cuda_compute_capability_major_version].to(torch_device), out.logits.mean(-1), atol=1e-2, rtol=1e-2))
# slicing logits[0, 0, 0:15]
EXPECTED_SLICE = {
7: torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]),
8: torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]])
}
# fmt: on
self.assertTrue(
torch.allclose(
EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device),
out.logits[0, 0, :15],
atol=1e-3,
rtol=1e-3,
)
)
@slow
@require_read_token
def test_model_7b_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
out = model(torch.tensor([input_ids]))
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[-6.6550, -4.1227, -4.9859, -3.2406, 0.8262, -3.0033, 1.2964, -3.3699]])
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([-12.8281, -7.4453, -0.4639, -8.0625, -7.2500, -8.0000, -6.4883, -7.7695, -7.8438, -7.0312, -6.2188, -7.1328, -1.8496, 1.9961, -8.6250, -6.7227, -12.8281, -6.9492, -7.0742, -7.7852, -7.5820, -7.9062, -6.9375, -7.9805, -8.3438, -8.1562, -8.0469, -7.6250, -7.7422, -7.3398,]) # fmt: skip
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
@slow
def test_model_13b_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf", device_map="auto")
out = model(torch.tensor(input_ids))
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[-2.0622, -1.2794, -1.1638, -0.9788, -1.4603, -1.0238, -1.7893, -1.4411]])
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([-8.1406, -8.0547, 2.7461, -1.2344, -0.1448, -1.8262, -1.0020, -1.8154, -1.6895, -1.8516, -2.3574, -0.9277, 3.7598, 6.5742, -1.2998, -0.1177, -8.1406, -2.9688, -2.9199, -3.1699, -3.5254, -2.3555, -2.7988, -3.4141, -2.8262, -4.5195, -3.3379, -3.3164, -2.7832, -3.0273]) # fmt: skip
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
@slow
def test_model_13bf_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-13b-chat-hf", device_map="auto")
out = model(torch.tensor(input_ids))
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[-0.8562, -1.8520, -0.7551, -0.4162, -1.5161, -1.2038, -2.4823, -2.3254]])
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([-2.2227, 4.8828, 0.9023, -0.4578, -0.7871, -0.1033, -0.6221, -0.5786, -0.7803, -1.0674, -1.2920, -0.1570, 0.8008, 2.0723, -0.9497, 0.2771, -2.2227, -0.7612, -1.4346, -1.2061, -1.6426, -0.3000, -0.7139, -1.1934, -1.8691, -1.6973, -1.5947, -1.2705, -0.3523, -0.5513]) # fmt: skip
torch.testing.assert_close(out.mean(-1), EXPECTED_SLICE, atol=1e-2, rtol=1e-2)
@unittest.skip(
"Logits are not exactly the same; once we fix the instabilities, we will update! This will also be a `too_slow` test."
)
@slow
def test_model_70b_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-70b-hf", device_map="auto")
out = model(torch.tensor(input_ids))
EXPECTED_MEAN = torch.tensor(
[[-4.2327, -3.3360, -4.6665, -4.7631, -1.8180, -3.4170, -1.4211, -3.1810]], dtype=torch.float32
)
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
EXPECTED_SLICE = torch.tensor([-9.4922, -3.9551, 1.7998, -5.6758, -5.1055, -5.8984, -4.8320, -6.8086, -6.5391, -5.6172, -5.5820, -5.5352, 1.7881, 3.6289, -6.5117, -3.4785, -9.5000, -6.0352, -6.8125, -6.0195, -6.6836, -5.4727, -6.2812, -6.0391, -7.3398, -7.4297, -7.4844, -6.5820, -5.8789, -5.5312]) # fmt: skip
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
@unittest.skip("Model is curently gated")
@slow
def test_model_13b_greedy_generation(self):
EXPECTED_TEXT_COMPLETION = """Simply put, the theory of relativity states that 1) the laws of physics are the same everywhere in the universe and 2) the passage of time and the length of objects can vary depending on the observer\'s frame of reference.\n\nThe first part of the theory, that the laws of physics are the same everywhere, is known as the "princi"""
prompt = "Simply put, the theory of relativity states that "
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf")
input_ids = tokenizer.encode(prompt, return_tensors="pt")
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-13b-chat-hf", device_map="sequential", use_safetensors=False
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
)
# greedy generation outputs
generated_ids = model.generate(input_ids, max_new_tokens=64, top_p=None, temperature=1, do_sample=False)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
with torch.no_grad():
out = model(torch.tensor([input_ids]).to(torch_device))
# fmt: off
# Expected mean on dim = -1
EXPECTED_MEAN = {
7: torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]),
8: torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]])
}
self.assertTrue(torch.allclose(EXPECTED_MEAN[self.cuda_compute_capability_major_version].to(torch_device), out.logits.mean(-1), atol=1e-2, rtol=1e-2))
# slicing logits[0, 0, 0:15]
EXPECTED_SLICE = {
7: torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]),
8: torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328])
}
# fmt: on
self.assertTrue(
torch.allclose(
EXPECTED_SLICE[self.cuda_compute_capability_major_version].to(torch_device),
out.logits[0, 0, :15],
atol=1e-3,
rtol=1e-3,
)
)
@slow
@require_torch_gpu
@@ -739,89 +740,6 @@ class LlamaIntegrationTest(unittest.TestCase):
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
@require_torch
class CodeLlamaIntegrationTest(unittest.TestCase):
PROMPTS = [
'''def remove_non_ascii(s: str) -> str:
""" <FILL_ME>
return result
''',
"""# Installation instructions:
```bash
<FILL_ME>
```
This downloads the LLaMA inference code and installs the repository as a local pip package.
""",
"""class InterfaceManagerFactory(AbstractManagerFactory):
def __init__(<FILL_ME>
def main():
factory = InterfaceManagerFactory(start=datetime.now())
managers = []
for i in range(10):
managers.append(factory.build(id=i))
""",
"""/-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/
theorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :
π₁ P = 0 <FILL_ME> = 0 :=
begin
split,
{ intros h f,
rw pi_1_etalisation at h,
simp [h],
refl
},
{ intro h,
have := @quasi_adjoint C D P,
simp [pi_1_etalisation, this, h],
refl
}
end
""",
]
@require_torch_accelerator
@slow
@unittest.skip("Model is too large")
def test_model_7b_logits(self):
model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf").to(torch_device)
tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
# Tokenize and prepare a list of sequences (or a list of pairs of sequences) for the model,
# meaning that by default this supports passing already-split lists of inputs.
processed_text = tokenizer.batch_decode(tokenizer(self.PROMPTS)["input_ids"], add_special_tokens=False)
# fmt: off
EXPECTED_TEXT = [
'<s> <PRE> def remove_non_ascii(s: str) -> str:\n """ <SUF>\n return result\n <MID>',
'<s> <PRE> # Installation instructions:\n ```bash\n <SUF>\n ```\nThis downloads the LLaMA inference code and installs the repository as a local pip package.\n <MID>',
'<s> <PRE> class InterfaceManagerFactory(AbstractManagerFactory):\n def __init__( <SUF>\ndef main():\n factory = InterfaceManagerFactory(start=datetime.now())\n managers = []\n for i in range(10):\n managers.append(factory.build(id=i))\n <MID>',
'<s> <PRE> /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ <SUF> = 0 :=\nbegin\nsplit,\n{ intros h f,\n rw pi_1_etalisation at h,\n simp [h],\n refl\n},\n{ intro h,\n have := @quasi_adjoint C D P,\n simp [←pi_1_etalisation, this, h],\n refl\n}\nend\n <MID>'
]
# fmt: on
self.assertEqual(processed_text, EXPECTED_TEXT)
processed_text_suffix_first = tokenizer.batch_decode(
tokenizer(self.PROMPTS, suffix_first=True, add_special_tokens=False)["input_ids"]
)
# fmt: off
EXPECTED_TEXT = [
'<PRE> <SUF>\n return result\n <MID> def remove_non_ascii(s: str) -> str:\n """ ',
'<PRE> <SUF>\n ```\nThis downloads the LLaMA inference code and installs the repository as a local pip package.\n <MID> # Installation instructions:\n ```bash\n',
'<PRE> <SUF>\ndef main():\n factory = InterfaceManagerFactory(start=datetime.now())\n managers = []\n for i in range(10):\n managers.append(factory.build(id=i))\n <MID> class InterfaceManagerFactory(AbstractManagerFactory):\n def __init__(',
'<PRE> <SUF> = 0 :=\nbegin\nsplit,\n{ intros h f,\n rw pi_1_etalisation at h,\n simp [h],\n refl\n},\n{ intro h,\n have := @quasi_adjoint C D P,\n simp [←pi_1_etalisation, this, h],\n refl\n}\nend\n <MID> /-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/\ntheorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :\nπ₁ P = 0 ↔ '
]
EXPECTED_IDS = torch.tensor([[1, 32007, 822, 3349, 29918, 5464, 29918, 294, 18869, 29898, 29879, 29901, 851, 29897, 1599, 851, 29901, 13, 1678, 9995, 29871, 32008, 13, 1678, 736, 1121, 13, 32009, 15941, 1661, 29899, 28599, 2687, 4890, 515, 263, 1347, 29889, 13, 13, 1678, 826, 3174, 29901, 13, 4706, 269, 29901, 450, 1347, 304, 3349, 1661, 29899, 28599, 2687, 4890, 515, 29889, 13, 13, 1678, 16969, 29901, 13, 4706, 450, 1347, 411, 1661, 29899, 28599, 2687, 4890, 6206, 29889, 13, 1678, 9995, 13, 1678, 1121, 353, 5124, 13, 1678, 363, 274, 297, 269, 29901, 13, 4706, 565, 4356, 29898, 29883, 29897, 529, 29871, 29896, 29906, 29947, 29901, 13, 9651, 1121, 4619, 274, 32010, 2]])
# fmt: on
self.assertEqual(processed_text_suffix_first, EXPECTED_TEXT)
input_ids = tokenizer(self.PROMPTS[0], return_tensors="pt")["input_ids"]
generated_ids = model.generate(input_ids.to(torch_device), max_new_tokens=128)
torch.testing.assert_close(generated_ids, EXPECTED_IDS)
EXPECTED_INFILLING = [
'<s> <PRE> def remove_non_ascii(s: str) -> str:\n """ <SUF>\n return result\n <MID>Remove non-ASCII characters from a string.\n\n Args:\n s: The string to remove non-ASCII characters from.\n\n Returns:\n The string with non-ASCII characters removed.\n """\n result = ""\n for c in s:\n if ord(c) < 128:\n result += c <EOT></s>'
]
infilling = tokenizer.batch_decode(generated_ids)
self.assertEqual(infilling, EXPECTED_INFILLING)
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
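
The tests touched here are gated behind @slow and @require_read_token, so they only run when slow tests are enabled and a Hugging Face read token with access to the gated Llama weights is configured. One possible way to run just the llama logits tests locally, sketched under those assumptions:

import os
import pytest

# transformers skips @slow-decorated tests unless RUN_SLOW is set before the test module is imported
os.environ["RUN_SLOW"] = "1"
# select only the llama logits tests from the file changed in this commit
raise SystemExit(pytest.main(["tests/models/llama/test_modeling_llama.py", "-k", "logits"]))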