535 lines
23 KiB
Python
535 lines
23 KiB
Python
# coding=utf-8
|
|
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Testing suite for the PyTorch Persimmon model."""
|
|
|
|
import gc
|
|
import unittest
|
|
|
|
from parameterized import parameterized
|
|
|
|
from transformers import PersimmonConfig, is_torch_available, set_seed
|
|
from transformers.testing_utils import (
|
|
backend_empty_cache,
|
|
require_bitsandbytes,
|
|
require_torch,
|
|
require_torch_accelerator,
|
|
require_torch_fp16,
|
|
slow,
|
|
torch_device,
|
|
)
|
|
|
|
from ...generation.test_utils import GenerationTesterMixin
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
PersimmonForCausalLM,
|
|
PersimmonForSequenceClassification,
|
|
PersimmonForTokenClassification,
|
|
PersimmonModel,
|
|
)
|
|
from transformers.models.persimmon.modeling_persimmon import (
|
|
PersimmonDynamicNTKScalingRotaryEmbedding,
|
|
PersimmonLinearScalingRotaryEmbedding,
|
|
PersimmonRotaryEmbedding,
|
|
)
|
|
|
|
|
|
# Adapted from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon
class PersimmonModelTester:
    """Helper that builds a small `PersimmonConfig` plus random inputs and runs shape / cache checks.

    Assertions are reported through `parent`, the test case that owns this tester.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        # Test case whose assert* methods are used by the create_and_check_* helpers.
        self.parent = parent
        self.scope = scope

        # Shape of the generated batches.
        self.batch_size = batch_size
        self.seq_length = seq_length

        # Switches selecting which optional inputs / labels get generated.
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels

        # Values forwarded to `PersimmonConfig` by `get_config`.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.pad_token_id = pad_token_id

        # Upper bounds (exclusive) used when sampling random labels.
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices

    @staticmethod
    def _eval_model(model_class, config):
        """Instantiate `model_class` from `config`, move it to `torch_device` and put it in eval mode."""
        model = model_class(config)
        model.to(torch_device)
        model.eval()
        return model

    def prepare_config_and_inputs(self):
        """Return `(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels)`.

        Entries whose `use_*` switch is off are `None`.
        """
        batch_shape = [self.batch_size, self.seq_length]

        # The order of the `ids_tensor` calls is kept fixed so the random draws stay the same.
        input_ids = ids_tensor(batch_shape, self.vocab_size)

        # Lower-triangular matrix of ones used as the attention mask.
        input_mask = (
            torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) if self.use_input_mask else None
        )

        token_type_ids = ids_tensor(batch_shape, self.type_vocab_size) if self.use_token_type_ids else None

        sequence_labels = token_labels = choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor(batch_shape, self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        return (
            self.get_config(),
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(self):
        """Build a `PersimmonConfig` from the tester's hyper-parameters, with `is_decoder=False`."""
        mirrored_attributes = (
            "vocab_size",
            "hidden_size",
            "num_hidden_layers",
            "num_attention_heads",
            "intermediate_size",
            "hidden_act",
            "hidden_dropout_prob",
            "attention_probs_dropout_prob",
            "max_position_embeddings",
            "type_vocab_size",
            "initializer_range",
            "pad_token_id",
        )
        config_kwargs = {name: getattr(self, name) for name in mirrored_attributes}
        return PersimmonConfig(is_decoder=False, **config_kwargs)

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run `PersimmonModel` with and without an attention mask and check the hidden-state shape."""
        model = self._eval_model(PersimmonModel, config)

        # Both call signatures have to run; only the output of the mask-free call is shape-checked.
        model(input_ids, attention_mask=input_mask)
        outputs = model(input_ids)

        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(outputs.last_hidden_state.shape, expected_shape)

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run `PersimmonModel` with `add_cross_attention=True` under three argument combinations."""
        config.add_cross_attention = True
        model = self._eval_model(PersimmonModel, config)

        # Encoder states + encoder mask, encoder states only, then neither; the last output is checked.
        model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        model(input_ids, attention_mask=input_mask, encoder_hidden_states=encoder_hidden_states)
        outputs = model(input_ids, attention_mask=input_mask)

        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(outputs.last_hidden_state.shape, expected_shape)

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run `PersimmonForCausalLM` with labels and check the logits shape."""
        model = self._eval_model(PersimmonForCausalLM, config)
        outputs = model(input_ids, attention_mask=input_mask, labels=token_labels)

        expected_shape = (self.batch_size, self.seq_length, self.vocab_size)
        self.parent.assertEqual(outputs.logits.shape, expected_shape)

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Check that feeding three extra tokens through the cache matches a full forward pass."""
        config.is_decoder = True
        config.add_cross_attention = True
        model = self._eval_model(PersimmonForCausalLM, config)

        encoder_kwargs = {
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
        }

        # First forward pass, only used to obtain the cache.
        past_key_values = model(
            input_ids, attention_mask=input_mask, use_cache=True, **encoder_kwargs
        ).past_key_values

        # Hypothetical continuation: three new tokens per sequence with a random 0/1 mask.
        num_new_tokens = 3
        next_tokens = ids_tensor((self.batch_size, num_new_tokens), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, num_new_tokens), vocab_size=2)

        # Extended inputs for the cache-free reference pass.
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            output_hidden_states=True,
            **encoder_kwargs,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
            **encoder_kwargs,
        )["hidden_states"][0]

        # Compare a single randomly chosen hidden dimension over the new positions.
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        no_past_slice = output_from_no_past[:, -num_new_tokens:, random_slice_idx].detach()
        past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(past_slice.shape[1] == next_tokens.shape[1])
        self.parent.assertTrue(torch.allclose(past_slice, no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` where `inputs_dict` holds `input_ids` and `attention_mask`."""
        config, input_ids, _token_type_ids, input_mask, *_labels = self.prepare_config_and_inputs()
        return config, {"input_ids": input_ids, "attention_mask": input_mask}
|
|
|
|
|
|
@require_torch
class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Modeling tests for Persimmon run on small models built by `PersimmonModelTester`.

    Methods still marked "Copied from" are kept verbatim; their explanatory comments sit above the marker.
    """

    all_model_classes = (
        (PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification, PersimmonForTokenClassification)
        if is_torch_available()
        else ()
    )
    pipeline_model_mapping = (
        {
            "feature-extraction": PersimmonModel,
            "text-classification": PersimmonForSequenceClassification,
            "token-classification": PersimmonForTokenClassification,
            # TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
            # "text-generation": PersimmonForCausalLM,
            # "zero-shot": PersimmonForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )

    all_generative_model_classes = (PersimmonForCausalLM,) if is_torch_available() else ()
    test_headmasking = False
    test_pruning = False

    # Creates the model tester and the config tester used by the tests below.
    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Persimmon
    def setUp(self):
        self.model_tester = PersimmonModelTester(self)
        self.config_tester = ConfigTester(self, config_class=PersimmonConfig, hidden_size=37)

    # Runs the shared configuration checks.
    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
    def test_config(self):
        self.config_tester.run_common_tests()

    # Forward pass of the base model with an output-shape check.
    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    # Adapted from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings
    def test_model_various_embeddings(self):
        """Run the base-model check once per `position_embedding_type` value."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # The loop variable no longer shadows the builtin `type`; each variant is reported as its own sub-test.
        for embedding_type in ("absolute", "relative_key", "relative_key_query"):
            with self.subTest(position_embedding_type=embedding_type):
                config_and_inputs[0].position_embedding_type = embedding_type
                self.model_tester.create_and_check_model(*config_and_inputs)

    def _check_sequence_classification(self, problem_type=None):
        """Run `PersimmonForSequenceClassification` with 3 labels and check the logits shape.

        Args:
            problem_type: value assigned to `config.problem_type`; when `None` the config is left untouched.
                `"multi_label_classification"` switches the labels to a float `(batch_size, num_labels)` tensor.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        if problem_type is not None:
            config.problem_type = problem_type

        input_ids = input_dict["input_ids"]
        # NOTE(review): this masks token id 1 although the tester's `pad_token_id` is 0; kept as in the
        # Llama test this was derived from -- confirm whether that is intended.
        attention_mask = input_ids.ne(1).to(torch_device)

        batch_size = self.model_tester.batch_size
        label_range = self.model_tester.type_sequence_label_size
        if problem_type == "multi_label_classification":
            sequence_labels = ids_tensor([batch_size, config.num_labels], label_range).to(torch.float)
        else:
            sequence_labels = ids_tensor([batch_size], label_range)

        model = PersimmonForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (batch_size, self.model_tester.num_labels))

    # Adapted from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model
    def test_persimmon_sequence_classification_model(self):
        """Sequence classification with the config's default `problem_type`."""
        self._check_sequence_classification()

    # Adapted from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label
    def test_persimmon_sequence_classification_model_for_single_label(self):
        """Sequence classification with `problem_type="single_label_classification"`."""
        self._check_sequence_classification(problem_type="single_label_classification")

    # Adapted from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label
    def test_persimmon_sequence_classification_model_for_multi_label(self):
        """Sequence classification with `problem_type="multi_label_classification"` and float labels."""
        self._check_sequence_classification(problem_type="multi_label_classification")

    # Token classification with 3 labels; checks the (batch, seq, num_labels) logits shape.
    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Persimmon,llama->persimmon
    def test_persimmon_token_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
        model = PersimmonForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
        self.assertEqual(
            result.logits.shape,
            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
        )

    @unittest.skip("Persimmon buffers include complex numbers, which breaks this test")
    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base
    def test_save_load_fast_init_from_base(self):
        pass

    # Compares an unscaled model with one built with `rope_scaling` on a short and an over-length input.
    @parameterized.expand([("linear",), ("dynamic",)])
    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Persimmon
    def test_model_rope_scaling_from_config(self, scaling_type):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        short_input = ids_tensor([1, 10], config.vocab_size)
        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        original_model = PersimmonModel(config)
        original_model.to(torch_device)
        original_model.eval()
        original_short_output = original_model(short_input).last_hidden_state
        original_long_output = original_model(long_input).last_hidden_state

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
        scaled_model = PersimmonModel(config)
        scaled_model.to(torch_device)
        scaled_model.eval()
        scaled_short_output = scaled_model(short_input).last_hidden_state
        scaled_long_output = scaled_model(long_input).last_hidden_state

        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
        # maximum sequence length, so the outputs for the short input should match.
        if scaling_type == "dynamic":
            self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
        else:
            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))

        # The output should be different for long inputs
        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

    # Sanity checks of the three rotary-embedding classes in isolation (no model forward pass).
    # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Persimmon
    def test_model_rope_scaling(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = hidden_size // num_heads
        scaling_factor = 10
        short_input_length = 10
        long_input_length = int(config.max_position_embeddings * 1.5)

        # Inputs
        x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exlusively to get the dtype and the device

        # Sanity check original RoPE
        original_rope = PersimmonRotaryEmbedding(
            head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        ).to(torch_device)
        original_cos_short, original_sin_short = original_rope(x, short_input_length)
        original_cos_long, original_sin_long = original_rope(x, long_input_length)
        torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
        torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])

        # Sanity check linear RoPE scaling
        # New position "x" should match original position with index "x/scaling_factor"
        linear_scaling_rope = PersimmonLinearScalingRotaryEmbedding(
            head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
            scaling_factor=scaling_factor,
        ).to(torch_device)
        linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
        linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
        torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
        torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
        for new_position in range(0, long_input_length, scaling_factor):
            original_position = int(new_position // scaling_factor)
            torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
            torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])

        # Sanity check Dynamic NTK RoPE scaling
        # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
        # with scaling_factor (or that `inv_freq` decreases)
        ntk_scaling_rope = PersimmonDynamicNTKScalingRotaryEmbedding(
            head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
            scaling_factor=scaling_factor,
        ).to(torch_device)
        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
        torch.testing.assert_close(ntk_cos_short, original_cos_short)
        torch.testing.assert_close(ntk_sin_short, original_sin_short)
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(ntk_cos_long, original_cos_long)
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(ntk_sin_long, original_sin_long)
        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
|
|
|
|
|
|
@require_torch
class PersimmonIntegrationTest(unittest.TestCase):
    """Slow tests comparing the 8-bit `adept/persimmon-8b-chat` checkpoint with recorded reference values.

    Each test frees its model in a `finally` block so accelerator memory is released even when an
    assertion fails.
    """

    @slow
    @require_torch_accelerator
    @require_bitsandbytes
    def test_model_8b_chat_logits(self):
        """Check the per-position logit means and the first 30 logits of position 0 for a fixed prompt."""
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
        model = PersimmonForCausalLM.from_pretrained(
            "adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
        )
        try:
            # Inference only: with autograd disabled the logits carry no graph that would keep
            # intermediate tensors alive after the model is deleted.
            with torch.no_grad():
                out = model(torch.tensor([input_ids], device=torch_device)).logits

            EXPECTED_MEAN = torch.tensor(
                [[-11.4726, -11.1495, -11.2694, -11.2223, -10.9452, -11.0663, -11.0031, -11.1028]]
            )
            # change dtype to `torch.float32` before calling `mean` to avoid `nan` values
            torch.testing.assert_close(out.cpu().to(torch.float32).mean(-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4)
            # fmt: off
            EXPECTED_SLICE = torch.tensor(
                [-16.9062, -16.9062, -16.9062, -16.9062, -16.8906, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9062, -16.9531, -16.9062, -16.9531, -16.9062, -16.9062],
                dtype=torch.float16
            )
            # fmt: on
            torch.testing.assert_close(out.cpu()[0, 0, :30], EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
        finally:
            # Drop the model first, then collect, then flush the allocator cache: emptying the cache
            # while the model is still referenced cannot release the memory it occupies.
            del model
            gc.collect()
            backend_empty_cache(torch_device)

    @slow
    @require_torch_accelerator
    @require_torch_fp16
    @require_bitsandbytes
    def test_model_8b_chat_greedy_generation(self):
        """Check that generating up to 64 new tokens reproduces the recorded completion."""
        EXPECTED_TEXT_COMPLETION = """human: Simply put, the theory of relativity states that?\n\nadept: The theory of relativity states that the laws of physics are the same for all observers, regardless of their relative motion."""
        prompt = "human: Simply put, the theory of relativity states that?\n\nadept:"
        tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-chat", use_fast=False)
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(torch_device)
        model = PersimmonForCausalLM.from_pretrained(
            "adept/persimmon-8b-chat", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
        )
        try:
            # greedy generation outputs
            generated_ids = model.generate(input_ids, max_new_tokens=64)
            text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
        finally:
            # Same release order as above; runs even when the comparison fails.
            del model
            gc.collect()
            backend_empty_cache(torch_device)
|