diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 55dabe7cbe..165ef5a054 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -276,7 +276,7 @@ class GemmaAttention(nn.Module):
         attn_output = attn_output.transpose(1, 2).contiguous()
 
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)
 
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index 1c368a02bc..670519d2a1 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -26,6 +26,7 @@ from transformers.testing_utils import (
     require_flash_attn,
     require_torch,
     require_torch_gpu,
+    require_torch_sdpa,
     slow,
     torch_device,
 )
@@ -460,6 +461,71 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_flash_attn_2_inference_padding_right(self):
         self.skipTest("Gemma flash attention does not support right padding")
 
+    @require_torch_sdpa
+    @require_torch_gpu
+    @slow
+    def test_sdpa_equivalence(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_sdpa:
+                return
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_sdpa = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa"
+                )
+                model_sdpa.to(torch_device)
+
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
+                model.to(torch_device)
+
+                dummy_input = inputs_dict[model_class.main_input_name]
+                dummy_input = dummy_input.to(torch_device)
+                outputs = model(dummy_input, output_hidden_states=True)
+                outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True)
+
+                logits = outputs.hidden_states[-1]
+                logits_sdpa = outputs_sdpa.hidden_states[-1]
+
+                # gemma sdpa needs a high tolerance
+                assert torch.allclose(logits_sdpa, logits, atol=3e-3)
+
+    @require_flash_attn
+    @require_torch_gpu
+    @pytest.mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_equivalence(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_flash_attn_2:
+                return
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_fa = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
+                )
+                model_fa.to(torch_device)
+
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
+                model.to(torch_device)
+
+                dummy_input = inputs_dict[model_class.main_input_name]
+                dummy_input = dummy_input.to(torch_device)
+                outputs = model(dummy_input, output_hidden_states=True)
+                outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+                logits = outputs.hidden_states[-1]
+                logits_fa = outputs_fa.hidden_states[-1]
+
+                # gemma flash attention 2 needs a high tolerance
+                assert torch.allclose(logits_fa, logits, atol=3e-3)
+
     @require_torch_gpu
     @slow
@@ -542,6 +608,69 @@ class GemmaIntegrationTest(unittest.TestCase):
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
+    def test_model_2b_eager(self):
+        model_id = "google/gemma-2b"
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1990s and I am looking for some information on the ",
+            "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat",
+        ]
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
+        )
+        model.to(torch_device)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+        self.assertEqual(output_text, EXPECTED_TEXTS)
+
+    @require_torch_sdpa
+    def test_model_2b_sdpa(self):
+        model_id = "google/gemma-2b"
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1990s and I need to know what the most popular music",
+            "Hi today I am going to share with you a very easy and simple recipe of Khichdi",
+        ]
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
+        )
+        model.to(torch_device)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+        self.assertEqual(output_text, EXPECTED_TEXTS)
+
+    @pytest.mark.flash_attn_test
+    @require_flash_attn
+    def test_model_2b_flash_attn(self):
+        model_id = "google/gemma-2b"
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1990s and I need to know what the most popular music",
+            "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat",
+        ]
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+        )
+        model.to(torch_device)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+        self.assertEqual(output_text, EXPECTED_TEXTS)
+
     @require_bitsandbytes
     def test_model_2b_4bit(self):
         model_id = "google/gemma-2b"