transformers/tests/models/mistral/test_modeling_mistral.py

851 lines
36 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
# Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Mistral model."""
import gc
import tempfile
import unittest
import pytest
from packaging import version
from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
from transformers.testing_utils import (
backend_empty_cache,
is_flaky,
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_sdpa,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
MistralForCausalLM,
MistralForSequenceClassification,
MistralForTokenClassification,
MistralModel,
)
class MistralModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return MistralConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mistral
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = MistralModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Mistral
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = MistralModel(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Mistral
def create_and_check_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = MistralForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Mistral
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = MistralForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical multiple next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(MistralModel, MistralForCausalLM, MistralForSequenceClassification, MistralForTokenClassification)
if is_torch_available()
else ()
)
all_generative_model_classes = (MistralForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": MistralModel,
"text-classification": MistralForSequenceClassification,
"token-classification": MistralForTokenClassification,
"text-generation": MistralForCausalLM,
"zero-shot": MistralForSequenceClassification,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = True
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
def is_pipeline_test_to_skip(
self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
):
return True
# TODO: @Fxmarty
@is_flaky(max_attempts=3, description="flaky on some models.")
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_generate(self):
super().test_eager_matches_sdpa_generate()
def setUp(self):
self.model_tester = MistralModelTester(self)
self.config_tester = ConfigTester(self, config_class=MistralConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
def test_Mistral_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = MistralForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
def test_Mistral_sequence_classification_model_for_single_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.problem_type = "single_label_classification"
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = MistralForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
def test_Mistral_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
config.problem_type = "multi_label_classification"
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor(
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
).to(torch.float)
model = MistralForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mistral,llama->Mistral
def test_Mistral_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = MistralForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip("Mistral buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip("Mistral uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
torch_device
)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
with self.assertRaises(ValueError):
_ = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_generate_use_cache(self):
import torch
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Mistral apparently does not support right padding + use_cache with FA2.
dummy_attention_mask[:, -1] = 1
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
).to(torch_device)
# Just test that a large cache works as expected
_ = model.generate(
dummy_input,
attention_mask=dummy_attention_mask,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("Mistral flash attention does not support right padding")
@require_torch_gpu
class MistralIntegrationTest(unittest.TestCase):
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None
@classmethod
def setUpClass(cls):
if is_torch_available() and torch.cuda.is_available():
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
def tearDown(self):
torch.cuda.empty_cache()
gc.collect()
@slow
def test_model_7b_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
out = model(input_ids).logits.cpu()
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
#
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
# considering differences in hardware processing and potential deviations in output.
EXPECTED_SLICE = {
7: torch.tensor([-5.8781, -5.8616, -0.1052, -4.7200, -5.8781, -5.8774, -5.8773, -5.8777, -5.8781, -5.8780, -5.8781, -5.8779, -1.0787, 1.7583, -5.8779, -5.8780, -5.8783, -5.8778, -5.8776, -5.8781, -5.8784, -5.8778, -5.8778, -5.8777, -5.8779, -5.8778, -5.8776, -5.8780, -5.8779, -5.8781]),
8: torch.tensor([-5.8711, -5.8555, -0.1050, -4.7148, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -1.0781, 1.7568, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711]),
9: torch.tensor([-5.8750, -5.8594, -0.1047, -4.7188, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -1.0781, 1.7578, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750]),
} # fmt: skip
torch.testing.assert_close(
out[0, 0, :30], EXPECTED_SLICE[self.cuda_compute_capability_major_version], atol=1e-4, rtol=1e-4
)
del model
backend_empty_cache(torch_device)
gc.collect()
@slow
@require_bitsandbytes
def test_model_7b_generation(self):
EXPECTED_TEXT_COMPLETION = {
7: "My favourite condiment is 100% ketchup. I love it on everything. I'm not a big",
8: "My favourite condiment is 100% ketchup. Im not a fan of mustard, mayo,",
}
prompt = "My favourite condiment is "
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map={"": torch_device}, load_in_4bit=True
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
# greedy generation outputs
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
del model
backend_empty_cache(torch_device)
gc.collect()
@require_bitsandbytes
@slow
@require_flash_attn
def test_model_7b_long_prompt(self):
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
# An input with 4097 tokens that is above the size of the sliding window
input_ids = [1] + [306, 338] * 2048
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1",
device_map={"": torch_device},
load_in_4bit=True,
attn_implementation="flash_attention_2",
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
# Assisted generation
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
del assistant_model
del model
backend_empty_cache(torch_device)
gc.collect()
@slow
@require_torch_sdpa
def test_model_7b_long_prompt_sdpa(self):
EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
# An input with 4097 tokens that is above the size of the sliding window
input_ids = [1] + [306, 338] * 2048
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.float16
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
# Assisted generation
assistant_model = model
assistant_model.generation_config.num_assistant_tokens = 2
assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
del assistant_model
backend_empty_cache(torch_device)
gc.collect()
EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. Im not a big"""
prompt = "My favourite condiment is "
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
# greedy generation outputs
generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
def test_speculative_generation(self):
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
#
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
# considering differences in hardware processing and potential deviations in generated text.
EXPECTED_TEXT_COMPLETION = {
7: "My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs",
8: "My favourite condiment is 100% ketchup. I love it on everything. Im not a big",
9: "My favourite condiment is 100% ketchup. I love it on everything. Im not a big",
}
prompt = "My favourite condiment is "
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
# greedy generation outputs
set_seed(0)
generated_ids = model.generate(
input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
)
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
del model
backend_empty_cache(torch_device)
gc.collect()
@slow
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest("This test requires torch >= 2.3 to run.")
NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = {
8: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"Im not a big fan of mustard, mayo, or relish. Im not a fan of pickles"
],
7: [
"My favourite condiment is 100% ketchup. I love it on everything. "
"Im not a big fan of mustard, mayo, or relish. Im not a fan of pickles"
],
}
prompts = ["My favourite condiment is "]
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = MistralForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
# Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], dynamic_text)
# Static Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
# Sliding Window Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text)
# Static Cache + compile
forward_function = model.forward
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
# Sliding Window Cache + compile
torch._dynamo.reset()
model.forward = torch.compile(forward_function, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="sliding_window"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
del model
backend_empty_cache(torch_device)
gc.collect()
@slow
@require_torch_gpu
class Mask4DTestHard(unittest.TestCase):
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def setUp(self):
model_name = "mistralai/Mistral-7B-v0.1"
self.model_dtype = torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
self.model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
def get_test_data(self):
template = "my favorite {}"
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item
batch_separate = [template.format(x) for x in items] # 3 separate lines
batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated
input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
mask_shared_prefix = torch.tensor(
[
[
[
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
]
]
],
device=torch_device,
)
position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
# building custom positions ids based on custom mask
position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
# effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
# inverting the mask
min_dtype = torch.finfo(self.model_dtype).min
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
def test_stacked_causal_mask(self):
(
input_ids,
position_ids,
input_ids_shared_prefix,
mask_shared_prefix,
position_ids_shared_prefix,
) = self.get_test_data()
# regular batch
logits = self.model.forward(input_ids, position_ids=position_ids).logits
logits_last = logits[:, -1, :] # last tokens in each batch line
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
# single forward run with 4D custom mask
logits_shared_prefix = self.model.forward(
input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
).logits
logits_shared_prefix_last = logits_shared_prefix[
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
] # last three tokens
decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
self.assertEqual(decoded, decoded_shared_prefix)
def test_partial_stacked_causal_mask(self):
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
(
input_ids,
position_ids,
input_ids_shared_prefix,
mask_shared_prefix,
position_ids_shared_prefix,
) = self.get_test_data()
# regular batch
logits = self.model.forward(input_ids, position_ids=position_ids).logits
logits_last = logits[:, -1, :] # last tokens in each batch line
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
# 2 forward runs with custom 4D masks
part_a = 3 # split point
input_1a = input_ids_shared_prefix[:, :part_a]
position_ids_1a = position_ids_shared_prefix[:, :part_a]
mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
past_key_values_a = outs_1a["past_key_values"]
# Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len])
input_1b = input_ids_shared_prefix[:, part_a:]
position_ids_1b = position_ids_shared_prefix[:, part_a:]
mask_1b = mask_shared_prefix[:, :, part_a:, :]
outs_1b = self.model.forward(
input_1b, attention_mask=mask_1b, position_ids=position_ids_1b, past_key_values=past_key_values_a
)
decoded_1b = [
self.tokenizer.decode(t)
for t in outs_1b.logits.argmax(-1)[
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
]
]
self.assertEqual(decoded, decoded_1b)