635 lines
26 KiB
Python
635 lines
26 KiB
Python
# coding=utf-8
|
||
# Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
"""Testing suite for the PyTorch Qwen2 model."""
|
||
|
||
import gc
|
||
import tempfile
|
||
import unittest
|
||
|
||
import pytest
|
||
|
||
from transformers import AutoTokenizer, Qwen2Config, is_torch_available, set_seed
|
||
from transformers.testing_utils import (
|
||
backend_empty_cache,
|
||
require_bitsandbytes,
|
||
require_flash_attn,
|
||
require_torch,
|
||
require_torch_gpu,
|
||
require_torch_sdpa,
|
||
slow,
|
||
torch_device,
|
||
)
|
||
|
||
from ...generation.test_utils import GenerationTesterMixin
|
||
from ...test_configuration_common import ConfigTester
|
||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||
|
||
|
||
if is_torch_available():
|
||
import torch
|
||
|
||
from transformers import (
|
||
Qwen2ForCausalLM,
|
||
Qwen2ForSequenceClassification,
|
||
Qwen2ForTokenClassification,
|
||
Qwen2Model,
|
||
)
|
||
|
||
|
||
class Qwen2ModelTester:
    """Build tiny Qwen2 configs and random inputs for the shared model tests.

    ``parent`` is the ``unittest.TestCase`` that owns this tester; the
    ``create_and_check_*`` helpers report failures through its assertion methods.
    Every constructor argument is stored verbatim as an attribute and read back by
    ``prepare_config_and_inputs`` / ``get_config``.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        max_window_layers=3,
        use_sliding_window=True,
        sliding_window=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        bos_token_id=1,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.max_window_layers = max_window_layers
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.scope = scope

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """Return a ``Qwen2Config`` built from this tester's attributes (with ``is_decoder=False``)."""
        return Qwen2Config(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            max_window_layers=self.max_window_layers,
            use_sliding_window=self.use_sliding_window,
            sliding_window=self.sliding_window,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
            bos_token_id=self.bos_token_id,
        )

    # Adapted from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen2
    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run ``Qwen2Model`` with and without ``input_mask`` and check the hidden-state shape of both passes."""
        model = Qwen2Model(config=config)
        model.to(torch_device)
        model.eval()
        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)

        # Assert after each pass so the masked forward is actually checked, not just overwritten.
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

    # NOTE(review): kept for parity with the Llama tester; it appears unused in this file and it is
    # not shown here whether Qwen2Model accepts `encoder_hidden_states` — confirm before relying on it.
    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Qwen2
    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = Qwen2Model(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Qwen2
    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = Qwen2ForCausalLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Qwen2
    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.is_decoder = True
        config.add_cross_attention = True
        model = Qwen2ForCausalLM(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical multiple next token and extent to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict
|
||
|
||
|
||
# Runs the shared ModelTesterMixin / GenerationTesterMixin / PipelineTesterMixin suites for Qwen2,
# plus Qwen2-specific classification tests, skips, and flash-attention generation checks.
@require_torch
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2
class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (Qwen2Model, Qwen2ForCausalLM, Qwen2ForSequenceClassification, Qwen2ForTokenClassification)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (Qwen2ForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": Qwen2Model,
            "text-classification": Qwen2ForSequenceClassification,
            "token-classification": Qwen2ForTokenClassification,
            "text-generation": Qwen2ForCausalLM,
            "zero-shot": Qwen2ForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    fx_compatible = True

    # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
    def is_pipeline_test_to_skip(
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
    ):
        return True

    # Ignore copy
    # TODO: @Fxmarty
    @require_torch_sdpa
    @slow
    @unittest.skip(reason="Currently failing.")
    def test_eager_matches_sdpa_generate(self):
        super().test_eager_matches_sdpa_generate()

    def setUp(self):
        self.model_tester = Qwen2ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=Qwen2Config, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    # Ignore copy
    def test_model_various_embeddings(self):
        """Re-run the base model check with each ``position_embedding_type`` value set on the config."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # NOTE(review): whether Qwen2Config/Qwen2Model read `position_embedding_type` is not shown
        # here; if they ignore it, this simply repeats the same check for each value.
        for embedding_type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = embedding_type
            self.model_tester.create_and_check_model(*config_and_inputs)

    # Ignore copy
    def test_Qwen2_sequence_classification_model(self):
        """Check sequence-classification logits have shape ``(batch_size, num_labels)``."""
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = Qwen2ForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_Qwen2_sequence_classification_model_for_single_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "single_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = Qwen2ForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_Qwen2_sequence_classification_model_for_multi_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "multi_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor(
            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
        ).to(torch.float)
        model = Qwen2ForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2,llama->Qwen2
    def test_Qwen2_token_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
        model = Qwen2ForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
        self.assertEqual(
            result.logits.shape,
            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
        )

    @unittest.skip("Qwen2 buffers include complex numbers, which breaks this test")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip("Qwen2 uses GQA on all models so the KV cache is a non standard format")
    def test_past_key_values_format(self):
        pass

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_generate_padding_right(self):
        import torch

        for model_class in self.all_generative_model_classes:
            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
                    torch_device
                )

                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)

                model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                with self.assertRaises(ValueError):
                    _ = model.generate(
                        dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
                    )

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_generate_use_cache(self):
        import torch

        max_new_tokens = 30

        for model_class in self.all_generative_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            dummy_input = inputs_dict[model_class.main_input_name]
            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                dummy_input = dummy_input.to(torch.float16)

            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
                # NOTE: Qwen2 apparently does not support right padding + use_cache with FA2.
                dummy_attention_mask[:, -1] = 1

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                # Just test that a large cache works as expected
                _ = model.generate(
                    dummy_input,
                    attention_mask=dummy_attention_mask,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    use_cache=True,
                )

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        self.skipTest("Qwen2 flash attention does not support right padding")
|
||
|
||
|
||
@require_torch
class Qwen2IntegrationTest(unittest.TestCase):
    """Slow end-to-end checks that load the ``Qwen/Qwen2-450m-beta`` checkpoint.

    Each test deletes the model(s) it loaded, empties the backend cache and runs
    ``gc.collect()`` before returning.
    """

    @slow
    def test_model_450m_logits(self):
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
        model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-450m-beta", device_map="auto")
        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
        with torch.no_grad():
            out = model(input_ids).logits.cpu()
        # Expected mean on dim = -1
        EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]])
        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
        # slicing logits[0, 0, 0:30]
        EXPECTED_SLICE = torch.tensor([-5.8781, -5.8616, -0.1052, -4.7200, -5.8781, -5.8774, -5.8773, -5.8777, -5.8781, -5.8780, -5.8781, -5.8779, -1.0787, 1.7583, -5.8779, -5.8780, -5.8783, -5.8778, -5.8776, -5.8781, -5.8784, -5.8778, -5.8778, -5.8777, -5.8779, -5.8778, -5.8776, -5.8780, -5.8779, -5.8781])  # fmt: skip
        torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4)

        del model
        backend_empty_cache(torch_device)
        gc.collect()

    @slow
    def test_model_450m_generation(self):
        EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"""
        prompt = "My favourite condiment is "
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-450m-beta", use_fast=False)
        model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-450m-beta", device_map="auto")
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

        # greedy generation outputs
        generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

        del model
        backend_empty_cache(torch_device)
        gc.collect()

    @require_bitsandbytes
    @slow
    @require_flash_attn
    def test_model_450m_long_prompt(self):
        EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
        # An input with 4097 tokens that is above the size of the sliding window
        input_ids = [1] + [306, 338] * 2048
        model = Qwen2ForCausalLM.from_pretrained(
            "Qwen/Qwen2-450m-beta",
            device_map="auto",
            load_in_4bit=True,
            attn_implementation="flash_attention_2",
        )
        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())

        # Assisted generation: the model drafts for itself. `assistant_model` must be passed to
        # `generate`; without it this would just repeat the plain greedy call above.
        assistant_model = model
        assistant_model.generation_config.num_assistant_tokens = 2
        assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0, assistant_model=assistant_model)
        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())

        del assistant_model
        del model
        backend_empty_cache(torch_device)
        gc.collect()

    @slow
    @require_torch_sdpa
    def test_model_450m_long_prompt_sdpa(self):
        EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
        # An input with 4097 tokens that is above the size of the sliding window
        input_ids = [1] + [306, 338] * 2048
        model = Qwen2ForCausalLM.from_pretrained(
            "Qwen/Qwen2-450m-beta",
            device_map="auto",
            attn_implementation="sdpa",
        )
        input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())

        # Assisted generation: the model drafts for itself. `assistant_model` must be passed to
        # `generate`; without it this would just repeat the plain greedy call above.
        assistant_model = model
        assistant_model.generation_config.num_assistant_tokens = 2
        assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0, assistant_model=assistant_model)
        self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())

        EXPECTED_TEXT_COMPLETION = """My favourite condiment is 100% ketchup. I love it on everything. I’m not a big"""
        prompt = "My favourite condiment is "
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-450m-beta", use_fast=False)

        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

        # greedy generation outputs
        generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

        # Free memory only once the model is no longer used; `assistant_model` is just an alias of `model`.
        del assistant_model
        del model
        backend_empty_cache(torch_device)
        gc.collect()

    @slow
    def test_speculative_generation(self):
        EXPECTED_TEXT_COMPLETION = (
            "My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs"
        )
        prompt = "My favourite condiment is "
        # NOTE(review): the tokenizer comes from the 7B repo while both models are 450m; presumably
        # they share a vocabulary — confirm.
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-beta", use_fast=False)
        model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-450m-beta", device_map="auto", torch_dtype=torch.float16)
        assistant_model = Qwen2ForCausalLM.from_pretrained(
            "Qwen/Qwen2-450m-beta", device_map="auto", torch_dtype=torch.float16
        )
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

        # seeded sampling (do_sample=True) with speculative decoding via `assistant_model`
        set_seed(0)
        generated_ids = model.generate(
            input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=assistant_model
        )
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

        del assistant_model
        del model
        backend_empty_cache(torch_device)
        gc.collect()
|