# Listing metadata: 991 lines, 41 KiB, Python
# coding=utf-8
|
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Testing suite for the PyTorch LLaMA model."""
|
|
|
|
import gc
|
|
import tempfile
|
|
import unittest
|
|
|
|
import pytest
|
|
from packaging import version
|
|
from parameterized import parameterized
|
|
|
|
from transformers import LlamaConfig, StaticCache, is_torch_available, set_seed
|
|
from transformers.testing_utils import (
|
|
require_bitsandbytes,
|
|
require_flash_attn,
|
|
require_read_token,
|
|
require_torch,
|
|
require_torch_gpu,
|
|
require_torch_sdpa,
|
|
slow,
|
|
torch_device,
|
|
)
|
|
|
|
from ...generation.test_utils import GenerationTesterMixin
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import (
|
|
LlamaForCausalLM,
|
|
LlamaForQuestionAnswering,
|
|
LlamaForSequenceClassification,
|
|
LlamaModel,
|
|
LlamaTokenizer,
|
|
)
|
|
from transformers.models.llama.modeling_llama import (
|
|
LlamaDynamicNTKScalingRotaryEmbedding,
|
|
LlamaLinearScalingRotaryEmbedding,
|
|
LlamaRotaryEmbedding,
|
|
)
|
|
|
|
|
|
class LlamaModelTester:
    """Holds tiny-model hyper-parameters and runs shape/consistency checks on Llama models.

    Every constructor argument is kept as an attribute of the same name. The tester builds a
    `LlamaConfig` plus random inputs from those attributes and reports check results through the
    assertion methods of `parent`.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        # Object whose `assert*` methods are used by the `create_and_check_*` helpers.
        self.parent = parent
        self.scope = scope

        # Input geometry and which optional inputs/labels get generated.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels

        # Values forwarded to `LlamaConfig` by `get_config`.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.pad_token_id = pad_token_id

        # Sizes of the label spaces used when `use_labels` is set.
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices

    def prepare_config_and_inputs(self):
        """Create a config and a set of random inputs/labels.

        Returns:
            A 7-tuple `(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels,
            choice_labels)`. `input_mask`, `token_type_ids` and the three label tensors are `None` when
            `use_input_mask`, `use_token_type_ids` or `use_labels` respectively are falsy.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        # Lower-triangular matrix of ones with shape (batch_size, seq_length).
        input_mask = (
            torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
            if self.use_input_mask
            else None
        )

        token_type_ids = (
            ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            if self.use_token_type_ids
            else None
        )

        sequence_labels = token_labels = choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(self):
        """Return a `LlamaConfig` built from the hyper-parameters stored on this tester."""
        config_kwargs = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "num_hidden_layers": self.num_hidden_layers,
            "num_attention_heads": self.num_attention_heads,
            "intermediate_size": self.intermediate_size,
            "hidden_act": self.hidden_act,
            "hidden_dropout_prob": self.hidden_dropout_prob,
            "attention_probs_dropout_prob": self.attention_probs_dropout_prob,
            "max_position_embeddings": self.max_position_embeddings,
            "type_vocab_size": self.type_vocab_size,
            "is_decoder": False,
            "initializer_range": self.initializer_range,
            "pad_token_id": self.pad_token_id,
        }
        return LlamaConfig(**config_kwargs)

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run `LlamaModel` with and without an attention mask and check the hidden-state shape."""
        model = LlamaModel(config=config).to(torch_device).eval()

        # The masked call only has to run; the shape check uses the unmasked call.
        model(input_ids, attention_mask=input_mask)
        result = model(input_ids)

        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run `LlamaModel` with `add_cross_attention` enabled and check the hidden-state shape.

        The model is called three times (with encoder states and mask, with encoder states only,
        and with neither); only the last output is shape-checked.
        """
        config.add_cross_attention = True
        model = LlamaModel(config).to(torch_device).eval()

        model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask)

        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run `LlamaForCausalLM` with labels and check the logits shape."""
        model = LlamaForCausalLM(config=config).to(torch_device).eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)

        expected_shape = (self.batch_size, self.seq_length, self.vocab_size)
        self.parent.assertEqual(result.logits.shape, expected_shape)

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Check that feeding three new tokens with a cache matches a full forward pass without one."""
        config.is_decoder = True
        config.add_cross_attention = True
        model = LlamaForCausalLM(config=config).to(torch_device).eval()

        # Encoder-side keyword arguments shared by every forward call below.
        encoder_inputs = {
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
        }

        # first forward pass
        outputs = model(input_ids, attention_mask=input_mask, use_cache=True, **encoder_inputs)
        past_key_values = outputs.past_key_values

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append the new tokens / mask entries to the original inputs
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            output_hidden_states=True,
            **encoder_inputs,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
            **encoder_inputs,
        )["hidden_states"][0]

        # select random slice along the last dimension
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` where `inputs_dict` holds `input_ids` and `attention_mask`."""
        config, input_ids, _, input_mask, _, _, _ = self.prepare_config_and_inputs()
        return config, {"input_ids": input_ids, "attention_mask": input_mask}
|
|
|
|
|
|
@require_torch
class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Llama-specific tests, run on top of the shared model/generation/pipeline tester mixins."""

    all_model_classes = (
        (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForQuestionAnswering)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": LlamaModel,
            "text-classification": LlamaForSequenceClassification,
            "text-generation": LlamaForCausalLM,
            "zero-shot": LlamaForSequenceClassification,
            "question-answering": LlamaForQuestionAnswering,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    fx_compatible = True

    # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
    # This is because we are hitting edge cases with the causal_mask buffer
    model_split_percents = [0.5, 0.7, 0.8]

    # used in `test_torch_compile`
    _torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"

    def setUp(self):
        """Create the model tester and the config tester used by the tests below."""
        self.model_tester = LlamaModelTester(self)
        self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)

    def test_config(self):
        """Run the shared configuration checks for `LlamaConfig`."""
        self.config_tester.run_common_tests()

    def test_model(self):
        """Check the base model output shape on freshly prepared inputs."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        """Repeat the base model check with several `position_embedding_type` values on the config."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # `embedding_type` rather than `type`, so the builtin is not shadowed.
        for embedding_type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = embedding_type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_llama_sequence_classification_model(self):
        """Check the logits shape of `LlamaForSequenceClassification` with three labels."""
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        # NOTE(review): masks token id 1 although the tester's `pad_token_id` is 0 - confirm this is intended.
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = LlamaForSequenceClassification(config).to(torch_device).eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_llama_sequence_classification_model_for_single_label(self):
        """Same as the plain classification test, with `problem_type="single_label_classification"`."""
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "single_label_classification"
        input_ids = input_dict["input_ids"]
        # NOTE(review): masks token id 1 although the tester's `pad_token_id` is 0 - confirm this is intended.
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = LlamaForSequenceClassification(config).to(torch_device).eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_llama_sequence_classification_model_for_multi_label(self):
        """Same as the plain classification test, with float labels of shape (batch_size, num_labels)."""
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "multi_label_classification"
        input_ids = input_dict["input_ids"]
        # NOTE(review): masks token id 1 although the tester's `pad_token_id` is 0 - confirm this is intended.
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor(
            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
        ).to(torch.float)
        model = LlamaForSequenceClassification(config).to(torch_device).eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    @unittest.skip("Llama buffers include complex numbers, which breaks this test")
    def test_save_load_fast_init_from_base(self):
        pass

    @parameterized.expand([("linear",), ("dynamic",)])
    def test_model_rope_scaling_from_config(self, scaling_type):
        """Compare a model without RoPE scaling against one with `rope_scaling` set on the config."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        short_input = ids_tensor([1, 10], config.vocab_size)
        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        original_model = LlamaModel(config).to(torch_device).eval()
        original_short_output = original_model(short_input).last_hidden_state
        original_long_output = original_model(long_input).last_hidden_state

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
        scaled_model = LlamaModel(config).to(torch_device).eval()
        scaled_short_output = scaled_model(short_input).last_hidden_state
        scaled_long_output = scaled_model(long_input).last_hidden_state

        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
        # maximum sequence length, so the outputs for the short input should match.
        short_outputs_match = torch.allclose(original_short_output, scaled_short_output, atol=1e-5)
        if scaling_type == "dynamic":
            self.assertTrue(short_outputs_match)
        else:
            self.assertFalse(short_outputs_match)

        # The output should be different for long inputs
        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

    def test_model_rope_scaling(self):
        """Sanity-check the plain, linear-scaling and dynamic-NTK rotary embedding modules directly."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        head_dim = config.hidden_size // config.num_attention_heads
        scaling_factor = 10
        short_input_length = 10
        long_input_length = int(config.max_position_embeddings * 1.5)

        # Inputs
        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device).unsqueeze(0)
        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device).unsqueeze(0)

        # Sanity check original RoPE
        original_rope = LlamaRotaryEmbedding(
            head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        ).to(torch_device)
        original_cos_short, original_sin_short = original_rope(x, position_ids_short)
        original_cos_long, original_sin_long = original_rope(x, position_ids_long)
        torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
        torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])

        # Sanity check linear RoPE scaling
        # New position "x" should match original position with index "x/scaling_factor"
        linear_scaling_rope = LlamaLinearScalingRotaryEmbedding(
            head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
            scaling_factor=scaling_factor,
        ).to(torch_device)
        linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
        linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
        torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
        torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
        for new_position in range(0, long_input_length, scaling_factor):
            # Both operands are ints, so floor division already yields an int index.
            original_position = new_position // scaling_factor
            torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
            torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])

        # Sanity check Dynamic NTK RoPE scaling
        # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
        # with scaling_factor (or that `inv_freq` decreases)
        ntk_scaling_rope = LlamaDynamicNTKScalingRotaryEmbedding(
            head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
            scaling_factor=scaling_factor,
        ).to(torch_device)
        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
        torch.testing.assert_close(ntk_cos_short, original_cos_short)
        torch.testing.assert_close(ntk_sin_short, original_sin_short)
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(ntk_cos_long, original_cos_long)
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(ntk_sin_long, original_sin_long)
        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())

    @require_flash_attn
    @require_torch_gpu
    @require_bitsandbytes
    @pytest.mark.flash_attn_test
    @require_read_token
    @slow
    def test_flash_attn_2_generate_padding_right(self):
        """
        Overwriting the common test as the test is flaky on tiny models
        """
        model = LlamaForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf",
            load_in_4bit=True,
            device_map={"": 0},
        )

        tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        texts = ["hi", "Hello this is a very long sentence"]

        tokenizer.padding_side = "right"
        tokenizer.pad_token = tokenizer.eos_token

        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)

        output_native = tokenizer.batch_decode(model.generate(**inputs, max_new_tokens=20, do_sample=False))

        # Reload the same checkpoint with the flash-attention implementation and generate again.
        model = LlamaForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2"
        )

        output_fa_2 = tokenizer.batch_decode(model.generate(**inputs, max_new_tokens=20, do_sample=False))

        self.assertListEqual(output_native, output_fa_2)

    @require_flash_attn
    @require_torch_gpu
    @slow
    def test_use_flash_attention_2_true(self):
        """
        NOTE: this is the only test testing that the legacy `use_flash_attention_2=True` argument still works as
        intended.
        """
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            with tempfile.TemporaryDirectory() as tmp_dir:
                model_class(config).save_pretrained(tmp_dir)

                # NOTE(review): every saved class is reloaded as `LlamaForCausalLM` - confirm this is intended.
                new_model = LlamaForCausalLM.from_pretrained(
                    tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
                ).to("cuda")

                self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")

                has_flash = any(
                    "FlashAttention" in submodule.__class__.__name__ for _, submodule in new_model.named_modules()
                )
                if not has_flash:
                    raise ValueError("The flash model should have flash attention layers")

    @require_torch_sdpa
    @slow
    def test_eager_matches_sdpa_generate(self):
        """
        Overwriting the common test as the test is flaky on tiny models
        """
        max_new_tokens = 30

        tokenizer = LlamaTokenizer.from_pretrained("saibo/llama-1B")

        model_sdpa = LlamaForCausalLM.from_pretrained(
            "saibo/llama-1B",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to(torch_device)

        self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

        model_eager = LlamaForCausalLM.from_pretrained(
            "saibo/llama-1B",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            attn_implementation="eager",
        ).to(torch_device)

        self.assertTrue(model_eager.config._attn_implementation == "eager")

        # The module class names must reflect the requested attention implementation.
        if any("SdpaAttention" in submodule.__class__.__name__ for _, submodule in model_eager.named_modules()):
            raise ValueError("The eager model should not have SDPA attention layers")

        if not any("SdpaAttention" in submodule.__class__.__name__ for _, submodule in model_sdpa.named_modules()):
            raise ValueError("The SDPA model should have SDPA attention layers")

        texts = [
            "hi here's a longer context, getting longer and",
            "Hello this is a very long sentence my friend, very long for real",
            "Today I am in Paris and",
        ]

        for padding_side in ["left", "right"]:
            tokenizer.padding_side = padding_side
            tokenizer.pad_token = tokenizer.eos_token

            inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)

            res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
            res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

            with self.subTest(f"{padding_side}"):
                torch.testing.assert_close(
                    res_eager,
                    res_sdpa,
                    msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
                )
|
|
|
|
|
|
@require_torch_gpu
class LlamaIntegrationTest(unittest.TestCase):
    """Slow GPU integration tests that compare Llama-2-7b logits/generations with recorded values."""

    # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
    # Depending on the hardware we get different logits / generations
    cuda_compute_capability_major_version = None

    @classmethod
    def setUpClass(cls):
        """Record the major CUDA compute capability once for the whole class, when CUDA is available."""
        if not is_torch_available() or not torch.cuda.is_available():
            return
        # 8 is for A100 / A10 and 7 for T4
        major, _minor = torch.cuda.get_device_capability()
        cls.cuda_compute_capability_major_version = major

    @slow
    @require_read_token
    def test_model_7b_logits_bf16(self):
        """Compare bfloat16 eager-attention logits with the values recorded per compute capability."""
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]

        model = LlamaForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
        )

        with torch.no_grad():
            out = model(torch.tensor([input_ids]).to(torch_device))

        # Key into the expected-value tables below (7 and 8 are the only recorded entries).
        gpu_major = self.cuda_compute_capability_major_version

        # fmt: off
        # Expected mean on dim = -1
        EXPECTED_MEAN = {
            7: torch.tensor([[-6.5061, -4.1147, -4.9669, -3.2038, 0.8069, -2.9694, 1.2864, -3.3786]]),
            8: torch.tensor([[-6.5208, -4.1218, -4.9377, -3.2536, 0.8127, -2.9811, 1.2918, -3.3848]])
        }

        expected_mean = EXPECTED_MEAN[gpu_major].to(torch_device)
        self.assertTrue(torch.allclose(expected_mean, out.logits.mean(-1), atol=1e-2, rtol=1e-2))

        # slicing logits[0, 0, 0:15]
        EXPECTED_SLICE = {
            7: torch.tensor([[-12.5000, -7.0625, -0.6289, -7.8750, -6.9688, -7.8125, -6.4688, -7.4375, -7.6875, -6.9375, -6.0312, -7.0000, -1.8594, 1.8438, -8.5000]]),
            8: torch.tensor([[-12.5625, -7.1250, -0.6289, -7.8750, -6.9688, -7.8125, -6.5000, -7.4375, -7.6562, -6.9688, -6.0312, -7.0312, -1.8203, 1.8750, -8.5000]])
        }
        # fmt: on

        expected_slice = EXPECTED_SLICE[gpu_major].to(torch_device)
        self.assertTrue(torch.allclose(expected_slice, out.logits[0, 0, :15], atol=1e-3, rtol=1e-3))

    @slow
    @require_read_token
    def test_model_7b_logits(self):
        """Compare float16 logits with the values recorded per compute capability."""
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]

        model = LlamaForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
        )

        with torch.no_grad():
            out = model(torch.tensor([input_ids]).to(torch_device))

        # Key into the expected-value tables below (7 and 8 are the only recorded entries).
        gpu_major = self.cuda_compute_capability_major_version

        # fmt: off
        # Expected mean on dim = -1
        EXPECTED_MEAN = {
            7: torch.tensor([[-6.6420, -4.1227, -4.9809, -3.2041, 0.8261, -3.0052, 1.2957, -3.3648]]),
            8: torch.tensor([[-6.6544, -4.1259, -4.9840, -3.2456, 0.8261, -3.0124, 1.2971, -3.3641]])
        }

        expected_mean = EXPECTED_MEAN[gpu_major].to(torch_device)
        self.assertTrue(torch.allclose(expected_mean, out.logits.mean(-1), atol=1e-2, rtol=1e-2))

        # slicing logits[0, 0, 0:15]
        EXPECTED_SLICE = {
            7: torch.tensor([-12.8125, -7.3359, -0.4846, -8.0234, -7.2383, -7.9922, -6.4805, -7.7344, -7.8125, -7.0078, -6.1797, -7.1094, -1.8633, 1.9736, -8.6016]),
            8: torch.tensor([-12.8281, -7.4609, -0.4668, -8.0703, -7.2539, -8.0078, -6.4961, -7.7734, -7.8516, -7.0352, -6.2188, -7.1367, -1.8564, 1.9922, -8.6328])
        }
        # fmt: on

        expected_slice = EXPECTED_SLICE[gpu_major].to(torch_device)
        self.assertTrue(torch.allclose(expected_slice, out.logits[0, 0, :15], atol=1e-3, rtol=1e-3))

    @slow
    @require_torch_gpu
    @require_read_token
    def test_compile_static_cache(self):
        """Generate with the dynamic cache, the static cache, and the static cache plus `torch.compile`."""
        # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
        # work as intended. See https://github.com/pytorch/pytorch/issues/121943
        if version.parse(torch.__version__) < version.parse("2.3.0"):
            self.skipTest("This test requires torch >= 2.3 to run.")

        NUM_TOKENS_TO_GENERATE = 40
        # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
        # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
        EXPECTED_TEXT_COMPLETION = {
            8: [
                "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
                "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
                "theory of relativ",
                "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
                "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
            ],
            7: [
                "Simply put, the theory of relativity states that 1. surely nothing is faster than light.\nThe theory "
                "goes that nothing travels faster than light, but the faster you go, the slower everything else will "
                "be.\nThe theory of relativity",
                "My favorite all time favorite condiment is ketchup. I love it on hamburgers, hot dogs, fries, eggs, "
                "and even on a good old fashioned cheeseburger. I love it on everything. I love it so",
            ],
        }

        prompts = [
            "Simply put, the theory of relativity states that ",
            "My favorite all time favorite condiment is ketchup.",
        ]
        tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
        model = LlamaForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
        )
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

        def generate_text(**generate_kwargs):
            # Greedy generation followed by decoding; extra keyword arguments select the cache type.
            generated_ids = model.generate(
                **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, **generate_kwargs
            )
            return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        expected_for_this_gpu = EXPECTED_TEXT_COMPLETION

        # Dynamic Cache
        dynamic_text = generate_text()
        self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text)  # Both GPU architectures have the same output

        # Static Cache
        static_text = generate_text(cache_implementation="static")
        self.assertEqual(expected_for_this_gpu[self.cuda_compute_capability_major_version], static_text)

        # Static Cache + compile
        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
        static_compiled_text = generate_text(cache_implementation="static")
        self.assertEqual(expected_for_this_gpu[self.cuda_compute_capability_major_version], static_compiled_text)
|
|
|
|
|
|
@slow
|
|
@require_torch_gpu
|
|
class Mask4DTestHard(unittest.TestCase):
|
|
def tearDown(self):
    """Run the garbage collector and empty the CUDA cache after each test."""
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
def setUp(self):
    """Load the TinyLlama tokenizer and a float32 causal-LM model onto `torch_device`."""
    checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    self.model_dtype = torch.float32
    self.tokenizer = LlamaTokenizer.from_pretrained(checkpoint)
    loaded_model = LlamaForCausalLM.from_pretrained(checkpoint, torch_dtype=self.model_dtype)
    self.model = loaded_model.to(torch_device)
|
|
|
|
def get_test_data(self):
    """Build the inputs shared by the 4D-mask tests.

    Returns:
        A 5-tuple `(input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix,
        position_ids_shared_prefix)`: the three prompts tokenized as separate rows with their
        positions, then the same prompts packed into one row together with its additive 4D mask
        and per-token positions.
    """
    template = "my favorite {}"
    items = ("pet is a", "artist plays a", "name is L")  # same number of tokens in each item

    batch_separate = [template.format(item) for item in items]  # 3 separate lines
    batch_shared_prefix = template.format(" ".join(items))  # 1 line with options concatenated

    input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
    input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)

    # 1 = query token (row) may attend to key token (column). The first six rows are plainly
    # causal; each later group of three attends to the first three columns plus its own group.
    visibility_rows = [
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
    ]
    mask_shared_prefix = torch.tensor([[visibility_rows]], device=torch_device)

    position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)

    # building custom positions ids based on custom mask
    # effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
    visible_counts = mask_shared_prefix.sum(dim=-1)
    position_ids_shared_prefix = (visible_counts - 1).reshape(1, -1)

    # inverting the mask: visible entries become 0, hidden entries the dtype's most negative value
    min_dtype = torch.finfo(self.model_dtype).min
    mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype

    return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
|
|
|
|
def test_stacked_causal_mask(self):
    """Check that one packed forward pass with a custom 4D mask decodes the same next tokens as the batch."""
    (
        input_ids,
        position_ids,
        input_ids_shared_prefix,
        mask_shared_prefix,
        position_ids_shared_prefix,
    ) = self.get_test_data()

    # regular batch
    batch_logits = self.model.forward(input_ids, position_ids=position_ids).logits
    last_token_logits = batch_logits[:, -1, :]  # last tokens in each batch line
    decoded = [self.tokenizer.decode(token) for token in last_token_logits.argmax(dim=-1)]

    # single forward run with 4D custom mask
    packed_logits = self.model.forward(
        input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
    ).logits
    # Columns whose position id equals the maximum, i.e. the last three tokens.
    final_columns = torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1]
    packed_last_logits = packed_logits[0, final_columns, :]
    decoded_shared_prefix = [self.tokenizer.decode(token) for token in packed_last_logits.argmax(dim=-1)]

    self.assertEqual(decoded, decoded_shared_prefix)
|
|
|
|
def test_partial_stacked_causal_mask(self):
    """Feed the stacked input in two chunks and check that partial 4D attention masks are accepted.

    Same comparison as the test above, except that the stacked row is split at a fixed point: the
    first chunk produces `past_key_values`, the second chunk is run on top of them.
    """
    batch_ids, batch_positions, stacked_ids, stacked_mask, stacked_positions = self.get_test_data()

    # Reference run on the regular batch; keep the prediction of the last position of every row.
    reference_logits = self.model.forward(batch_ids, position_ids=batch_positions).logits
    reference_tokens = reference_logits[:, -1, :].argmax(dim=-1)
    expected = [self.tokenizer.decode(token) for token in reference_tokens]

    split = 3  # the stacked row is cut into [:split] and [split:]

    # First chunk: square slice of the mask covering only the first `split` tokens.
    first_out = self.model.forward(
        stacked_ids[:, :split],
        attention_mask=stacked_mask[:, :, :split, :split],
        position_ids=stacked_positions[:, :split],
    )
    cached_kv = first_out["past_key_values"]

    # Second chunk: the mask keeps the full key axis, i.e. it is shaped [..., seq_len, full_len].
    second_out = self.model.forward(
        stacked_ids[:, split:],
        attention_mask=stacked_mask[:, :, split:, :],
        position_ids=stacked_positions[:, split:],
        past_key_values=cached_kv,
    )

    # Columns with the maximal position id, shifted so they index into the second chunk.
    final_columns = torch.where(stacked_positions == stacked_positions.max())[1] - split
    second_tokens = second_out.logits.argmax(-1)[0, final_columns]
    actual = [self.tokenizer.decode(token) for token in second_tokens]

    self.assertEqual(expected, actual)
|
|
|
|
def test_stacked_causal_mask_static_cache(self):
    """Run the stacked 4D-mask comparison with a `StaticCache` that is longer than the mask.

    The custom mask is right-padded along its last axis up to the cache length with the dtype minimum
    before being handed to the model.
    """
    batch_ids, batch_positions, stacked_ids, stacked_mask, stacked_positions = self.get_test_data()

    # Reference run on the regular batch; keep the prediction of the last position of every row.
    reference_logits = self.model.forward(batch_ids, position_ids=batch_positions).logits
    reference_tokens = reference_logits[:, -1, :].argmax(dim=-1)
    expected = [self.tokenizer.decode(token) for token in reference_tokens]

    # The cache is deliberately longer than the last axis of the attention mask.
    max_cache_len = 16
    static_cache = StaticCache(
        config=self.model.config,
        max_batch_size=1,
        max_cache_len=max_cache_len,
        device=torch_device,
        dtype=self.model.dtype,
    )

    # Extend the mask's last axis to the cache length, filling the new columns with the dtype minimum.
    extra_columns = max_cache_len - stacked_mask.shape[-1]
    fill_value = torch.finfo(self.model_dtype).min
    padded_attention_mask = torch.nn.functional.pad(
        input=stacked_mask,
        pad=(0, extra_columns),
        mode="constant",
        value=fill_value,
    )

    # Single forward run over the whole stacked row, writing into the static cache.
    cache_position = torch.arange(stacked_ids.shape[-1], device=torch_device)
    stacked_logits = self.model.forward(
        stacked_ids,
        attention_mask=padded_attention_mask,
        position_ids=stacked_positions,
        cache_position=cache_position,
        past_key_values=static_cache,
    ).logits

    # Columns whose position id equals the maximum (the "last three tokens" of the stacked row).
    final_columns = torch.where(stacked_positions == stacked_positions.max())[1]
    stacked_tokens = stacked_logits[0, final_columns, :].argmax(dim=-1)
    actual = [self.tokenizer.decode(token) for token in stacked_tokens]

    self.assertEqual(expected, actual)
|
|
|
|
def test_partial_stacked_causal_mask_static_cache(self):
    """Check partial 4D attention masks with a `StaticCache` when the input is passed in two groups.

    Same as the static-cache test above, but the stacked input is split at `part_a`. Each forward call
    receives a 4D attention mask shaped [..., seq_len, full_static_cache_len], obtained by right-padding
    the corresponding slice of the custom mask up to `max_cache_len`.
    """
    (
        input_ids,
        position_ids,
        input_ids_shared_prefix,
        mask_shared_prefix,
        position_ids_shared_prefix,
    ) = self.get_test_data()

    # regular batch
    logits = self.model.forward(input_ids, position_ids=position_ids).logits
    logits_last = logits[:, -1, :]  # last tokens in each batch line
    decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]

    # upgrade the model with StaticCache
    max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
    past_key_values = StaticCache(
        config=self.model.config,
        max_batch_size=1,
        max_cache_len=max_cache_len,
        device=torch_device,
        dtype=self.model.dtype,
    )

    # The masks are in inverted form (0 = attend, dtype minimum = masked), so padding columns must be
    # filled with the dtype minimum in BOTH calls: the padded columns are cache slots that no
    # `cache_position` used in this test ever writes to, and padding with 0 would let queries attend to them.
    min_dtype = torch.finfo(self.model_dtype).min

    # forward run for the first part of input
    part_a = 3  # split point

    input_1a = input_ids_shared_prefix[:, :part_a]
    position_ids_1a = position_ids_shared_prefix[:, :part_a]
    mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]

    padded_mask_1a = torch.nn.functional.pad(
        input=mask_1a,
        pad=(0, max_cache_len - mask_1a.shape[-1]),
        mode="constant",
        value=min_dtype,
    )

    # Only the side effect matters here: this call fills the first `part_a` slots of the static cache.
    _ = self.model.forward(
        input_1a,
        attention_mask=padded_mask_1a,
        position_ids=position_ids_1a,
        cache_position=torch.arange(part_a, device=torch_device),
        past_key_values=past_key_values,
    )

    # forward run for the second part of input
    input_1b = input_ids_shared_prefix[:, part_a:]
    position_ids_1b = position_ids_shared_prefix[:, part_a:]
    mask_1b = mask_shared_prefix[:, :, part_a:, :]

    padded_mask_1b = torch.nn.functional.pad(
        input=mask_1b,
        pad=(0, max_cache_len - mask_1b.shape[-1]),
        mode="constant",
        value=min_dtype,  # was 0, which left the padded cache slots unmasked
    )

    outs_1b = self.model.forward(
        input_1b,
        attention_mask=padded_mask_1b,
        position_ids=position_ids_1b,
        cache_position=torch.arange(
            part_a,
            input_ids_shared_prefix.shape[-1],
            device=torch_device,
        ),
        past_key_values=past_key_values,
    )

    # Columns with the maximal position id, shifted by `part_a` so they index into the second chunk.
    last_token_columns = torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
    decoded_1b = [self.tokenizer.decode(t) for t in outs_1b.logits.argmax(-1)[0, last_token_columns]]
    self.assertEqual(decoded, decoded_1b)
|