537 lines
21 KiB
Python
537 lines
21 KiB
Python
# coding=utf-8
|
|
# Copyright 2024 JetMoe AI and The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Testing suite for the PyTorch JetMoe model."""
|
|
|
|
import gc
|
|
import tempfile
|
|
import unittest
|
|
|
|
import pytest
|
|
from parameterized import parameterized
|
|
|
|
from transformers import AutoTokenizer, JetMoeConfig, is_torch_available
|
|
from transformers.testing_utils import (
|
|
backend_empty_cache,
|
|
is_flaky,
|
|
require_flash_attn,
|
|
require_torch,
|
|
require_torch_gpu,
|
|
require_torch_sdpa,
|
|
slow,
|
|
torch_device,
|
|
)
|
|
|
|
from ...generation.test_utils import GenerationTesterMixin
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import (
|
|
JetMoeForCausalLM,
|
|
JetMoeForSequenceClassification,
|
|
JetMoeModel,
|
|
)
|
|
|
|
|
|
class JetMoeModelTester:
    """Builds a tiny JetMoe configuration plus random inputs and runs shape/consistency checks.

    Every ``create_and_check_*`` helper instantiates a model from the supplied config, moves it to
    ``torch_device``, switches it to eval mode and reports through ``parent`` (the owning test case).
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_key_value_heads=2,
        kv_channels=8,
        intermediate_size=37,
        hidden_act="silu",
        num_local_experts=4,
        num_experts_per_tok=2,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        self.parent = parent

        # Input geometry and which optional tensors prepare_config_and_inputs() builds.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels

        # Architecture hyper-parameters that get_config() forwards to JetMoeConfig
        # (hidden_act is passed as ``activation_function``).
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_key_value_heads = num_key_value_heads
        self.kv_channels = kv_channels
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.pad_token_id = pad_token_id
        # Derived from the KV-head and experts-per-token counts; get_config() does not pass it on.
        self.num_attention_heads = num_key_value_heads * num_experts_per_tok

        # Sizes of the label spaces used for the random targets.
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels)``.

        Optional entries are ``None`` when the matching ``use_*`` flag is off. The attention mask,
        when built, is all ones.
        """
        batch_by_seq = [self.batch_size, self.seq_length]
        input_ids = ids_tensor(batch_by_seq, self.vocab_size)

        input_mask = torch.ones(self.batch_size, self.seq_length).to(torch_device) if self.use_input_mask else None
        token_type_ids = ids_tensor(batch_by_seq, self.type_vocab_size) if self.use_token_type_ids else None

        if not self.use_labels:
            sequence_labels = token_labels = choice_labels = None
        else:
            # Draw order (sequence, token, choice) is kept stable for the shared RNG.
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor(batch_by_seq, self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        return (
            self.get_config(),
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(self):
        """Build a JetMoeConfig (encoder-style, ``is_decoder=False``) from this tester's hyper-parameters."""
        return JetMoeConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_key_value_heads=self.num_key_value_heads,
            kv_channels=self.kv_channels,
            intermediate_size=self.intermediate_size,
            activation_function=self.hidden_act,
            num_local_experts=self.num_local_experts,
            num_experts_per_tok=self.num_experts_per_tok,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run JetMoeModel with and without the mask; assert the last hidden state shape of the unmasked call."""
        model = JetMoeModel(config=config).to(torch_device).eval()
        model(input_ids, attention_mask=input_mask)
        outputs = model(input_ids)
        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(outputs.last_hidden_state.shape, expected_shape)

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Enable cross-attention, run three argument combinations and check the final call's output shape."""
        config.add_cross_attention = True
        model = JetMoeModel(config).to(torch_device).eval()

        model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        model(input_ids, attention_mask=input_mask, encoder_hidden_states=encoder_hidden_states)
        outputs = model(input_ids, attention_mask=input_mask)

        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(outputs.last_hidden_state.shape, expected_shape)

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run JetMoeForCausalLM with labels and assert logits are ``(batch, seq, vocab)``."""
        model = JetMoeForCausalLM(config=config).to(torch_device).eval()
        outputs = model(input_ids, attention_mask=input_mask, labels=token_labels)
        expected_shape = (self.batch_size, self.seq_length, self.vocab_size)
        self.parent.assertEqual(outputs.logits.shape, expected_shape)

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Compare a cached incremental forward over 3 new tokens against a full recomputation."""
        config.is_decoder = True
        config.add_cross_attention = True
        model = JetMoeForCausalLM(config=config).to(torch_device).eval()
        encoder_kwargs = {
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
        }

        # Prime the cache on the original prompt.
        past_key_values = model(
            input_ids, attention_mask=input_mask, use_cache=True, **encoder_kwargs
        ).past_key_values

        # Three extra tokens, each with a random 0/1 mask entry.
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        # NOTE(review): hidden_states[0] is presumably the embedding output, so this comparison may
        # not exercise the attention cache itself — confirm before relying on it.
        full_hidden = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            output_hidden_states=True,
            **encoder_kwargs,
        )["hidden_states"][0]
        cached_hidden = model(
            next_tokens,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
            **encoder_kwargs,
        )["hidden_states"][0]

        # Compare a single random channel over the three appended positions.
        channel = ids_tensor((1,), cached_hidden.shape[-1]).item()
        cached_slice = cached_hidden[:, :, channel].detach()
        full_slice = full_hidden[:, -3:, channel].detach()

        self.parent.assertTrue(cached_slice.shape[1] == next_tokens.shape[1])
        self.parent.assertTrue(torch.allclose(cached_slice, full_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, {"input_ids": ..., "attention_mask": ...})`` for the common test mixins."""
        config, input_ids, _token_type_ids, input_mask, *_labels = self.prepare_config_and_inputs()
        return config, {"input_ids": input_ids, "attention_mask": input_mask}
|
|
|
|
|
|
@require_torch
class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common model, generation and pipeline tests for the JetMoe model classes, plus JetMoe-specific overrides."""

    all_model_classes = (
        (JetMoeModel, JetMoeForCausalLM, JetMoeForSequenceClassification) if is_torch_available() else ()
    )
    all_generative_model_classes = (JetMoeForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": JetMoeModel,
            "text-classification": JetMoeForSequenceClassification,
            "text-generation": JetMoeForCausalLM,
            "zero-shot": JetMoeForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    test_mismatched_shapes = False
    test_cpu_offload = False
    test_disk_offload_bin = False
    test_disk_offload_safetensors = False

    # TODO: @Fxmarty
    @is_flaky(max_attempts=3, description="flaky on some models.")
    @require_torch_sdpa
    @slow
    def test_eager_matches_sdpa_generate(self):
        super().test_eager_matches_sdpa_generate()

    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        # This inherited check is disabled for JetMoe. Skipping (instead of an empty `pass`) makes
        # that visible in the test report rather than counting it as a pass.
        # NOTE(review): reason assumed to be the same cache-format limitation that skips
        # test_past_key_values_format below — confirm.
        self.skipTest(reason="JetMoe uses MoA on all models so the KV cache is a non standard format")

    def setUp(self):
        self.model_tester = JetMoeModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=JetMoeConfig, common_properties=["hidden_size", "num_hidden_layers"]
        )

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
    def test_config(self):
        self.config_tester.run_common_tests()

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings
    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with llama->jetmoe, Llama->JetMoe
    def test_jetmoe_sequence_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = JetMoeForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with llama->jetmoe, Llama->JetMoe
    def test_jetmoe_sequence_classification_model_for_single_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "single_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = JetMoeForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with llama->jetmoe, Llama->JetMoe
    def test_jetmoe_sequence_classification_model_for_multi_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "multi_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor(
            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
        ).to(torch.float)
        model = JetMoeForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    @unittest.skip("JetMoe buffers include complex numbers, which breaks this test")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip("JetMoe uses MoA on all models so the KV cache is a non standard format")
    def test_past_key_values_format(self):
        pass

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_generate_padding_right(self):
        """Right-padded generation runs with the default attention but must raise ValueError with flash_attention_2."""
        for model_class in self.all_generative_model_classes:
            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
                    torch_device
                )

                # The second row is right-padded (last mask entry is 0).
                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)

                model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                with self.assertRaises(ValueError):
                    _ = model.generate(
                        dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
                    )

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_generate_use_cache(self):
        """Generate ``max_new_tokens`` tokens with flash_attention_2 and ``use_cache=True`` without errors."""
        max_new_tokens = 30

        for model_class in self.all_generative_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            dummy_input = inputs_dict[model_class.main_input_name]
            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                dummy_input = dummy_input.to(torch.float16)

            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                # The tester stores attention_mask=None when use_input_mask is off, so a dict.get
                # default would never apply (the key exists); fall back to an all-ones mask explicitly.
                dummy_attention_mask = inputs_dict.get("attention_mask")
                if dummy_attention_mask is None:
                    dummy_attention_mask = torch.ones_like(dummy_input)
                # NOTE: JetMoe apparently does not support right padding + use_cache with FA2.
                dummy_attention_mask[:, -1] = 1

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                # Just test that a large cache works as expected
                _ = model.generate(
                    dummy_input,
                    attention_mask=dummy_attention_mask,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    use_cache=True,
                )

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        self.skipTest("JetMoe flash attention does not support right padding")
|
|
|
|
|
|
@require_torch
class JetMoeIntegrationTest(unittest.TestCase):
    """Slow end-to-end checks of logits and greedy generations for the ``jetmoe/jetmoe-8b`` checkpoint."""

    @staticmethod
    def _free_memory():
        """Run the garbage collector, then ask the backend to release cached device memory.

        Collecting first lets tensors kept alive only by reference cycles be freed before the cache
        is emptied.
        """
        gc.collect()
        backend_empty_cache(torch_device)

    @slow
    def test_model_8b_logits(self):
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
        try:
            input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
            with torch.no_grad():
                out = model(input_ids).logits.cpu()
            # Expected mean on dim = -1
            EXPECTED_MEAN = torch.tensor([[0.2507, -2.7073, -1.3445, -1.9363, -1.7216, -1.7370, -1.9054, -1.9792]])
            torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
            # slicing logits[0, 0, 0:30]
            EXPECTED_SLICE = torch.tensor([-3.3689, 5.9006, 5.7450, -1.7012, -4.7072, -4.7071, -4.7071, -4.7071, -4.7072, -4.7072, -4.7072, -4.7071, 3.8321, 9.1746, -4.7071, -4.7072, -4.7071, -4.7072, -4.7071, -4.7072, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071, -4.7071])  # fmt: skip
            torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4)
        finally:
            # Release the model even when an assertion fails; otherwise the failure's traceback
            # can keep this frame (and the 8B model) alive for the following tests.
            del model
            self._free_memory()

    @slow
    def test_model_8b_generation(self):
        EXPECTED_TEXT_COMPLETION = """My favourite condiment is ....\nI love ketchup. I love"""
        prompt = "My favourite condiment is "
        tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
        try:
            input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

            # greedy generation outputs
            generated_ids = model.generate(input_ids, max_new_tokens=10, temperature=0)
            text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
        finally:
            del model
            self._free_memory()

    @slow
    def test_model_8b_batched_generation(self):
        EXPECTED_TEXT_COMPLETION = [
            """My favourite condiment is ....\nI love ketchup. I love""",
            """My favourite 2018 Christmas present was a new pair""",
        ]
        prompt = [
            "My favourite condiment is ",
            "My favourite ",
        ]
        tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
        model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
        try:
            # A BatchEncoding (input_ids + attention_mask), padded to the longest prompt.
            inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.model.embed_tokens.weight.device)

            # greedy generation outputs
            generated_ids = model.generate(**inputs, max_new_tokens=10, temperature=0)
            text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
        finally:
            del model
            self._free_memory()
|