Fix slow tests for important models to be compatible with A10 runners (#29905)
* fix mistral and mixtral * add pdb * fix mixtral tesst * fix * fix mistral ? * add fix gemma * fix mistral * fix * test * anoter test * fix * fix * fix mistral tests * fix them again * final fixes for mistral * fix padding right * fix whipser fa2 * fix * fix * fix gemma * test * fix llama * fix * fix * fix llama gemma * add class attribute * fix CI * clarify whisper * compute_capability * rename names in some comments * Add # fmt: skip * make style * Update tests/models/mistral/test_modeling_mistral.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * update * update --------- Co-authored-by: Younes Belkada <younesbelkada@gmail.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
e9c23fa056
commit
08a194fcd6
|
@ -21,6 +21,7 @@ from parameterized import parameterized
|
|||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
|
@ -379,40 +380,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
import torch
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
|
||||
|
||||
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
low_cpu_mem_usage=True,
|
||||
).to(torch_device)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_ = model.generate(
|
||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
|
||||
)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
@ -500,6 +467,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
@is_flaky
|
||||
@slow
|
||||
def test_flash_attn_2_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
|
@ -531,12 +499,21 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
assert torch.allclose(logits_fa, logits, atol=3e-3)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
@require_read_token
|
||||
@require_torch_gpu
|
||||
class GemmaIntegrationTest(unittest.TestCase):
|
||||
input_text = ["Hello I am doing", "Hi today"]
|
||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
cuda_compute_capability_major_version = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if is_torch_available() and torch.cuda.is_available():
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_fp32(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
|
@ -554,6 +531,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_fp16(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
|
@ -573,6 +551,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_fp16_static_cache(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
|
@ -594,12 +573,19 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_bf16(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
||||
]
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
||||
],
|
||||
8: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
|
@ -611,14 +597,21 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_eager(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project on the 1990s and I am looking for some information on the ",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
]
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"Hello I am doing a project on the 1990s and I am looking for some information on the ",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
8: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
|
||||
|
@ -631,15 +624,22 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@require_torch_sdpa
|
||||
@require_read_token
|
||||
def test_model_2b_sdpa(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
||||
]
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
||||
],
|
||||
8: [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
|
||||
|
@ -652,10 +652,11 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@pytest.mark.flash_attn_test
|
||||
@require_flash_attn
|
||||
@require_read_token
|
||||
def test_model_2b_flash_attn(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
|
@ -677,6 +678,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_model_2b_4bit(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
|
@ -695,6 +697,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@unittest.skip("The test will not fit our CI runners")
|
||||
@require_read_token
|
||||
def test_model_7b_fp32(self):
|
||||
model_id = "google/gemma-7b"
|
||||
EXPECTED_TEXTS = [
|
||||
|
@ -712,6 +715,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_read_token
|
||||
def test_model_7b_fp16(self):
|
||||
model_id = "google/gemma-7b"
|
||||
EXPECTED_TEXTS = [
|
||||
|
@ -731,12 +735,19 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_read_token
|
||||
def test_model_7b_bf16(self):
|
||||
model_id = "google/gemma-7b"
|
||||
EXPECTED_TEXTS = [
|
||||
"""Hello I am doing a project on a 1991 240sx and I am trying to find""",
|
||||
"Hi today I am going to show you how to make a very simple and easy to make a very simple and",
|
||||
]
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"""Hello I am doing a project on a 1991 240sx and I am trying to find""",
|
||||
"Hi today I am going to show you how to make a very simple and easy to make a very simple and",
|
||||
],
|
||||
8: [
|
||||
"Hello I am doing a project for my school and I am trying to make a program that will read a .txt file",
|
||||
"Hi today I am going to show you how to make a very simple and easy to make a very simple and",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
|
@ -748,8 +759,9 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@require_read_token
|
||||
def test_model_7b_fp16_static_cache(self):
|
||||
model_id = "google/gemma-7b"
|
||||
EXPECTED_TEXTS = [
|
||||
|
@ -772,12 +784,19 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_read_token
|
||||
def test_model_7b_4bit(self):
|
||||
model_id = "google/gemma-7b"
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
|
||||
"""Hi today I am going to talk about the new update for the game called "The new update" and I""",
|
||||
]
|
||||
EXPECTED_TEXTS = {
|
||||
7: [
|
||||
"Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
|
||||
"""Hi today I am going to talk about the new update for the game called "The new update" and I""",
|
||||
],
|
||||
8: [
|
||||
"Hello I am doing a project for my school and I am trying to make a program that will take a number and then",
|
||||
"Hi today I am going to talk about the best way to get rid of acne. miniaturing is a very",
|
||||
],
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, load_in_4bit=True)
|
||||
|
||||
|
@ -787,4 +806,4 @@ class GemmaIntegrationTest(unittest.TestCase):
|
|||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
|
|
@ -597,8 +597,18 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
class LlamaIntegrationTest(unittest.TestCase):
|
||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
cuda_compute_capability_major_version = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if is_torch_available() and torch.cuda.is_available():
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
|
||||
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
|
||||
@slow
|
||||
def test_model_7b_logits(self):
|
||||
|
@ -675,16 +685,25 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||
@require_read_token
|
||||
def test_compile_static_cache(self):
|
||||
NUM_TOKENS_TO_GENERATE = 40
|
||||
EXPECTED_TEXT_COMPLETION = [
|
||||
"Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
|
||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||
]
|
||||
EXPECTED_TEXT_COMPLETION = {
|
||||
7: [
|
||||
"Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
|
||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||
],
|
||||
8: [
|
||||
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
|
||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||
],
|
||||
}
|
||||
|
||||
prompts = [
|
||||
"Simply put, the theory of relativity states that ",
|
||||
"My favorite all time favorite condiment is ketchup.",
|
||||
]
|
||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
|
||||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
|
||||
)
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
def decode_one_tokens(model, cur_token, input_pos, cache_position):
|
||||
|
@ -718,7 +737,7 @@ class LlamaIntegrationTest(unittest.TestCase):
|
|||
cache_position += 1
|
||||
|
||||
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
|
||||
|
||||
|
||||
@require_torch
|
||||
|
@ -763,6 +782,7 @@ end
|
|||
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
@unittest.skip("Model is too large")
|
||||
def test_model_7b_logits(self):
|
||||
model = LlamaForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf").to(torch_device)
|
||||
tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
|
||||
|
|
|
@ -470,39 +470,68 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||
self.skipTest("Mistral flash attention does not support right padding")
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
class MistralIntegrationTest(unittest.TestCase):
|
||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
cuda_compute_capability_major_version = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if is_torch_available() and torch.cuda.is_available():
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
|
||||
def tearDown(self):
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
def test_model_7b_logits(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||
with torch.no_grad():
|
||||
out = model(input_ids).logits.cpu()
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
|
||||
# slicing logits[0, 0, 0:30]
|
||||
EXPECTED_SLICE = torch.tensor([-5.8781, -5.8616, -0.1052, -4.7200, -5.8781, -5.8774, -5.8773, -5.8777, -5.8781, -5.8780, -5.8781, -5.8779, -1.0787, 1.7583, -5.8779, -5.8780, -5.8783, -5.8778, -5.8776, -5.8781, -5.8784, -5.8778, -5.8778, -5.8777, -5.8779, -5.8778, -5.8776, -5.8780, -5.8779, -5.8781]) # fmt: skip
|
||||
|
||||
EXPECTED_SLICE = {
|
||||
7: torch.tensor([-5.8781, -5.8616, -0.1052, -4.7200, -5.8781, -5.8774, -5.8773, -5.8777, -5.8781, -5.8780, -5.8781, -5.8779, -1.0787, 1.7583, -5.8779, -5.8780, -5.8783, -5.8778, -5.8776, -5.8781, -5.8784, -5.8778, -5.8778, -5.8777, -5.8779, -5.8778, -5.8776, -5.8780, -5.8779, -5.8781]),
|
||||
8: torch.tensor([-5.8711, -5.8555, -0.1050, -4.7148, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -1.0781, 1.7568, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711]),
|
||||
} # fmt: skip
|
||||
|
||||
print(out[0, 0, :30])
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4)
|
||||
torch.testing.assert_close(
|
||||
out[0, 0, :30], EXPECTED_SLICE[self.cuda_compute_capability_major_version], atol=1e-4, rtol=1e-4
|
||||
)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@slow
|
||||
@require_bitsandbytes
|
||||
def test_model_7b_generation(self):
|
||||
EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"""
|
||||
EXPECTED_TEXT_COMPLETION = {
|
||||
7: "My favourite condiment is 100% ketchup. I love it on everything. I'm not a big",
|
||||
8: "My favourite condiment is 100% ketchup. I’m not a fan of mustard, mayo,",
|
||||
}
|
||||
|
||||
prompt = "My favourite condiment is "
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto")
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1", device_map={"": torch_device}, load_in_4bit=True
|
||||
)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
|
@ -517,7 +546,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||
input_ids = [1] + [306, 338] * 2048
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
device_map="auto",
|
||||
device_map={"": torch_device},
|
||||
load_in_4bit=True,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
|
@ -544,9 +573,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||
# An input with 4097 tokens that is above the size of the sliding window
|
||||
input_ids = [1] + [306, 338] * 2048
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa",
|
||||
"mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.float16
|
||||
)
|
||||
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
|
||||
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
|
||||
|
@ -577,9 +604,10 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_speculative_generation(self):
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs"
|
||||
)
|
||||
EXPECTED_TEXT_COMPLETION = {
|
||||
7: "My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs",
|
||||
8: "My favourite condiment is 100% Sriracha. I love the heat, the sweetness, the tang",
|
||||
}
|
||||
prompt = "My favourite condiment is "
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
|
@ -593,7 +621,7 @@ class MistralIntegrationTest(unittest.TestCase):
|
|||
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
|
||||
)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
|
|
|
@ -507,6 +507,16 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||
|
||||
@require_torch
|
||||
class MixtralIntegrationTest(unittest.TestCase):
|
||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
cuda_compute_capability_major_version = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if is_torch_available() and torch.cuda.is_available():
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_small_model_logits(self):
|
||||
|
@ -518,18 +528,26 @@ class MixtralIntegrationTest(unittest.TestCase):
|
|||
)
|
||||
# TODO: might need to tweak it in case the logits do not match on our daily runners
|
||||
# these logits have been obtained with the original megablocks impelmentation.
|
||||
EXPECTED_LOGITS = torch.Tensor(
|
||||
[[0.1670, 0.1620, 0.6094], [-0.8906, -0.1588, -0.6060], [0.1572, 0.1290, 0.7246]]
|
||||
).to(torch_device)
|
||||
|
||||
EXPECTED_LOGITS = {
|
||||
7: torch.Tensor([[0.1670, 0.1620, 0.6094], [-0.8906, -0.1588, -0.6060], [0.1572, 0.1290, 0.7246]]).to(
|
||||
torch_device
|
||||
),
|
||||
8: torch.Tensor([[0.1631, 0.1621, 0.6094], [-0.8906, -0.1621, -0.6094], [0.1572, 0.1270, 0.7227]]).to(
|
||||
torch_device
|
||||
),
|
||||
}
|
||||
with torch.no_grad():
|
||||
logits = model(dummy_input).logits
|
||||
|
||||
torch.testing.assert_close(logits[0, :3, :3].half(), EXPECTED_LOGITS, atol=1e-3, rtol=1e-3)
|
||||
torch.testing.assert_close(logits[1, :3, :3].half(), EXPECTED_LOGITS, atol=1e-3, rtol=1e-3)
|
||||
torch.testing.assert_close(
|
||||
logits[0, :3, :3], EXPECTED_LOGITS[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
logits[1, :3, :3], EXPECTED_LOGITS[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3
|
||||
)
|
||||
|
||||
@slow
|
||||
# @require_torch_gpu
|
||||
@require_torch_gpu
|
||||
def test_small_model_logits_batched(self):
|
||||
model_id = "hf-internal-testing/Mixtral-tiny"
|
||||
dummy_input = torch.LongTensor([[0, 0, 0, 0, 0, 0, 1, 2, 3], [1, 1, 2, 3, 4, 5, 6, 7, 8]]).to(torch_device)
|
||||
|
@ -540,23 +558,48 @@ class MixtralIntegrationTest(unittest.TestCase):
|
|||
)
|
||||
|
||||
# TODO: might need to tweak it in case the logits do not match on our daily runners
|
||||
EXPECTED_LOGITS_LEFT = torch.Tensor(
|
||||
[[0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007]],
|
||||
)
|
||||
EXPECTED_LOGITS_LEFT = {
|
||||
7: torch.Tensor(
|
||||
[[0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007]],
|
||||
).to(torch_device),
|
||||
8: torch.Tensor([[0.1914, 0.0508, 0.7188], [0.1953, 0.0510, 0.7227], [0.1973, 0.0562, 0.7148]]).to(
|
||||
torch_device
|
||||
),
|
||||
}
|
||||
|
||||
# logits[0, -3:, -3:].half()
|
||||
EXPECTED_LOGITS_LEFT_UNPADDED = torch.Tensor(
|
||||
[[0.2212, 0.5200, -0.3816], [0.8213, -0.2313, 0.6069], [0.2664, -0.7090, 0.2468]],
|
||||
)
|
||||
EXPECTED_LOGITS_LEFT_UNPADDED = {
|
||||
7: torch.Tensor(
|
||||
[[0.2212, 0.5200, -0.3816], [0.8213, -0.2313, 0.6069], [0.2664, -0.7090, 0.2468]],
|
||||
).to(torch_device),
|
||||
8: torch.Tensor([[0.2217, 0.5195, -0.3828], [0.8203, -0.2295, 0.6055], [0.2676, -0.7109, 0.2461]]).to(
|
||||
torch_device
|
||||
),
|
||||
}
|
||||
|
||||
# logits[1, -3:, -3:].half()
|
||||
EXPECTED_LOGITS_RIGHT_UNPADDED = torch.Tensor(
|
||||
[[0.2205, 0.1232, -0.1611], [-0.3484, 0.3030, -1.0312], [0.0742, 0.7930, 0.7969]]
|
||||
)
|
||||
EXPECTED_LOGITS_RIGHT_UNPADDED = {
|
||||
7: torch.Tensor([[0.2205, 0.1232, -0.1611], [-0.3484, 0.3030, -1.0312], [0.0742, 0.7930, 0.7969]]).to(
|
||||
torch_device
|
||||
),
|
||||
8: torch.Tensor([[0.2178, 0.1260, -0.1621], [-0.3496, 0.2988, -1.0312], [0.0693, 0.7930, 0.8008]]).to(
|
||||
torch_device
|
||||
),
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(dummy_input, attention_mask=attention_mask).logits
|
||||
|
||||
torch.testing.assert_close(logits[0, :3, :3].half(), EXPECTED_LOGITS_LEFT, atol=1e-3, rtol=1e-3)
|
||||
torch.testing.assert_close(logits[0, -3:, -3:].half(), EXPECTED_LOGITS_LEFT_UNPADDED, atol=1e-3, rtol=1e-3)
|
||||
torch.testing.assert_close(logits[1, -3:, -3:].half(), EXPECTED_LOGITS_RIGHT_UNPADDED, atol=1e-3, rtol=1e-3)
|
||||
torch.testing.assert_close(
|
||||
logits[0, :3, :3], EXPECTED_LOGITS_LEFT[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
logits[0, -3:, -3:],
|
||||
EXPECTED_LOGITS_LEFT_UNPADDED[self.cuda_compute_capability_major_version],
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
logits[1, -3:, -3:],
|
||||
EXPECTED_LOGITS_RIGHT_UNPADDED[self.cuda_compute_capability_major_version],
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
)
|
||||
|
|
|
@ -3339,3 +3339,21 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
|
|||
@unittest.skip("The model doesn't support fast init from base")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
|
||||
)
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
|
||||
)
|
||||
def test_flash_attn_2_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Duplicated test with WhisperModelTest + the FA2 testing suite needs to be refactored to be compatible with WhisperDecoder for that test"
|
||||
)
|
||||
def test_flash_attn_2_inference_padding_right(self):
|
||||
pass
|
||||
|
|
|
@ -3245,6 +3245,7 @@ class ModelTesterMixin:
|
|||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky
|
||||
def test_flash_attn_2_inference_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
|
@ -3338,6 +3339,7 @@ class ModelTesterMixin:
|
|||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
|
@ -3427,6 +3429,7 @@ class ModelTesterMixin:
|
|||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky
|
||||
def test_flash_attn_2_generate_left_padding(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
|
@ -3470,6 +3473,7 @@ class ModelTesterMixin:
|
|||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@is_flaky
|
||||
@slow
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
@ -3888,19 +3892,20 @@ class ModelTesterMixin:
|
|||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
dummy_input = inputs_dict[model.main_input_name]
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
|
||||
batch_size = dummy_attention_mask.shape[0]
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
|
||||
is_padding_right = dummy_attention_mask[:, -1].sum().item() != batch_size
|
||||
|
||||
# To avoid errors with padding_side=="right"
|
||||
if is_padding_right:
|
||||
dummy_attention_mask = torch.ones_like(dummy_input)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
|
@ -3916,6 +3921,9 @@ class ModelTesterMixin:
|
|||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
|
||||
|
||||
_ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
|
||||
# with attention mask
|
||||
_ = model(
|
||||
|
|
Loading…
Reference in New Issue