# coding=utf-8
# Copyright 2023 Toshiyuki Sakamoto(tanreinama) and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers import (
    GPTSanJapaneseConfig,
    GPTSanJapaneseForConditionalGeneration,
    GPTSanJapaneseModel,
    GPTSanJapaneseTokenizer,
    is_torch_available,
)
from transformers.generation import GenerationConfig
from transformers.testing_utils import require_torch, slow, tooslow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


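# Builds a tiny GPTSanJapaneseConfig and random input ids shared by both test
# classes below; the sizes are deliberately small so the common tests stay fast.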
class GPTSanJapaneseTester:
    def __init__(
        self,
        parent,
        vocab_size=99,
        batch_size=13,
        num_contexts=7,
        # For common tests
        is_training=True,
        hidden_size=32,
        ext_size=42,
        num_hidden_layers=2,
        num_ext_layers=2,
        num_attention_heads=4,
        num_experts=2,
        d_ff=32,
        d_ext=80,
        d_spout=33,
        dropout_rate=0.0,
        layer_norm_epsilon=1e-6,
        expert_capacity=100,
        router_jitter_noise=0.0,
    ):
        self.vocab_size = vocab_size
        self.parent = parent
        self.batch_size = batch_size
        self.num_contexts = num_contexts
        # For common tests
        self.seq_length = self.num_contexts
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_ext_layers = num_ext_layers
        self.ext_size = ext_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_experts = num_experts
        self.d_ff = d_ff
        self.d_ext = d_ext
        self.d_spout = d_spout
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.expert_capacity = expert_capacity
        self.router_jitter_noise = router_jitter_noise

    def get_large_model_config(self):
        return GPTSanJapaneseConfig.from_pretrained("Tanrei/GPTSAN-japanese")

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = self.get_config()

        return (config, input_ids)

    def prepare_config_and_inputs_for_common(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = self.get_config()

        return (config, {"input_ids": input_ids})

    def get_config(self):
        return GPTSanJapaneseConfig(
            vocab_size=self.vocab_size,
            num_contexts=self.seq_length,
            d_model=self.hidden_size,
            d_ff=self.d_ff,
            d_ext=self.d_ext,
            d_spout=self.d_spout,
            num_switch_layers=self.num_hidden_layers - self.num_ext_layers,
            num_ext_layers=self.num_ext_layers,
            num_heads=self.num_attention_heads,
            num_experts=self.num_experts,
            expert_capacity=self.expert_capacity,
            dropout_rate=self.dropout_rate,
            layer_norm_epsilon=self.layer_norm_epsilon,
            router_jitter_noise=self.router_jitter_noise,
        )

    def create_and_check_model(
        self,
        config,
        input_ids,
    ):
        model = GPTSanJapaneseForConditionalGeneration(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids=input_ids,
        )
        self.parent.assertIsNotNone(result)


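# Runs the common model and pipeline test suites against the bare GPTSanJapaneseModel.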
@require_torch
class GPTSanJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (GPTSanJapaneseModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "conversational": GPTSanJapaneseForConditionalGeneration,
            "feature-extraction": GPTSanJapaneseForConditionalGeneration,
            "summarization": GPTSanJapaneseForConditionalGeneration,
            "text2text-generation": GPTSanJapaneseForConditionalGeneration,
            "translation": GPTSanJapaneseForConditionalGeneration,
        }
        if is_torch_available()
        else {}
    )
    fx_compatible = False
    is_encoder_decoder = False
    test_pruning = False
    test_headmasking = False
    test_cpu_offload = False
    test_disk_offload = False
    test_save_load_fast_init_to_base = False
    test_training = False
    # The small GPTSAN_JAPANESE model needs higher percentages for CPU/MP tests
    model_split_percents = [0.8, 0.9]

    # TODO: Fix the failed tests when this model gets more usage
    def is_pipeline_test_to_skip(
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
    ):
        if pipeline_test_casse_name == "SummarizationPipelineTests":
            # TODO: fix `_reorder_cache`, which is not implemented for this model
            return True
        elif pipeline_test_casse_name == "Text2TextGenerationPipelineTests":
            # TODO: check this.
            return True

        return False

    def setUp(self):
        self.model_tester = GPTSanJapaneseTester(self)
        self.config_tester = ConfigTester(self, config_class=GPTSanJapaneseConfig, d_model=37)

    def test_config(self):
        # Smoke test: the config can be instantiated with its default values.
        GPTSanJapaneseConfig()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(
        reason="skip for now as the computed `max_memory` by `model_split_percents` in the test method will be changed inside `from_pretrained`"
    )
    def test_model_parallelism(self):
        super().test_model_parallelism()


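# Generation-focused tests for GPTSanJapaneseForConditionalGeneration; the slow
# tests compare against outputs of the original mesh-tensorflow model.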
@require_torch
class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (GPTSanJapaneseForConditionalGeneration,) if is_torch_available() else ()
    fx_compatible = False
    is_encoder_decoder = False
    test_pruning = False
    test_headmasking = False
    test_cpu_offload = False
    test_disk_offload = False
    # The small GPTSAN_JAPANESE model needs higher percentages for CPU/MP tests
    model_split_percents = [0.8, 0.9]

    def setUp(self):
        self.model_tester = GPTSanJapaneseTester(self)
        self.config_tester = ConfigTester(self, config_class=GPTSanJapaneseConfig, d_model=37)

    def test_config(self):
        GPTSanJapaneseConfig()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(
        reason="skip for now as the computed `max_memory` by `model_split_percents` in the test method will be changed inside `from_pretrained`"
    )
    def test_model_parallelism(self):
        super().test_model_parallelism()

    @slow
    def test_logits(self):
        model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
        tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        input_ids = tokenizer.encode("武田信玄は", return_tensors="pt")
        outputs = model(input_ids)
        output_logits = outputs.logits.detach().cpu().numpy()
        # Output of the original model created with mesh-tensorflow
        # fmt: off
        target = [
            [-12.037839889526367, -12.433061599731445, -14.333840370178223, -12.450345993041992, -11.1661376953125,
             -11.930137634277344, -10.659740447998047, -12.909574508666992, -13.241043090820312, -13.398579597473145,
             -11.107524871826172, -12.3685941696167, -22.97943115234375, -10.481067657470703, -12.484030723571777,
             -12.807360649108887, -14.769700050354004, -12.233579635620117, -13.428145408630371, -22.624177932739258],
            [-7.511149883270264, -8.281851768493652, -7.943127155303955, -7.55021333694458, -6.49869966506958,
             -7.586796283721924, -6.978085994720459, -7.839145183563232, -8.21964168548584, -8.695091247558594,
             -6.706910610198975, -6.6585798263549805, -19.565698623657227, -5.353842735290527, -8.350686073303223,
             -8.039388656616211, -10.856569290161133, -7.75154447555542, -8.819022178649902, -19.51532745361328],
            [-9.73066234588623, -10.223922729492188, -9.932981491088867, -11.857836723327637, -7.662626266479492,
             -11.13529109954834, -7.765097618103027, -11.472923278808594, -9.543149948120117, -11.905633926391602,
             -9.366164207458496, -11.5734281539917, -23.699003219604492, -9.429590225219727, -10.42839241027832,
             -10.585240364074707, -10.94771957397461, -11.095416069030762, -10.390240669250488, -23.769372940063477],
            [-9.728265762329102, -9.859712600708008, -10.09729290008545, -9.678522109985352, -6.879519939422607,
             -9.68487548828125, -4.2803425788879395, -10.018914222717285, -9.308445930480957, -10.63394546508789,
             -8.083646774291992, -9.06301498413086, -21.904266357421875, -8.90160846710205, -8.841876029968262,
             -11.856719970703125, -12.079398155212402, -11.233753204345703, -10.177338600158691, -21.87256622314453],
            [-9.669764518737793, -9.614198684692383, -9.814510345458984, -9.996501922607422, -11.375690460205078,
             -10.113405227661133, -10.546867370605469, -10.04369068145752, -10.907809257507324, -10.504216194152832,
             -11.129199028015137, -10.151124000549316, -21.96586799621582, -9.086349487304688, -11.730339050292969,
             -10.460667610168457, -10.298049926757812, -10.784148216247559, -10.840693473815918, -22.03152847290039],
        ]
        # fmt: on
        target = np.array(target).flatten()
        predict = output_logits[0, :, :20].flatten()

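        # Accept each logit if it is within 0.05% relative error of the reference value.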
        def check(a, b, epsilon=5e-4):
            return abs(a - b) < epsilon * max(abs(a), abs(b))

        self.assertTrue(np.all([check(target[i], predict[i]) for i in range(len(target))]))

    @slow
    def test_batch_generation(self):
        model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
        tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        model.to(torch_device)

        # Make generation deterministic by restricting sampling to the top candidate
        generation_config = GenerationConfig.from_pretrained("Tanrei/GPTSAN-japanese")
        generation_config.top_k = 1

        # use sentences of different lengths to test batching
        sentences = [
            "甲斐なら武田と言うほど",
            "織田信長は、",
        ]

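        # Left padding keeps each prompt's last real token adjacent to the newly
        # generated tokens when prompts of different lengths are batched.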
        tokenizer.padding_side = "left"
        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(torch_device)

        self.assertNotEqual(inputs["attention_mask"][0].numpy().tolist(), inputs["attention_mask"][1].numpy().tolist())

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"].to(torch_device),
            max_new_tokens=3,
            generation_config=generation_config,
        )

        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
        output_non_padded = model.generate(
            input_ids=inputs_non_padded, max_new_tokens=3, generation_config=generation_config
        )

        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
        output_padded = model.generate(input_ids=inputs_padded, max_new_tokens=3, generation_config=generation_config)

        self.assertNotEqual(inputs_non_padded.shape, inputs_padded.shape)

        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)

        expected_output_sentence = [
            "甲斐なら武田と言うほど甲斐の武田",
            "織田信長は、このような",
        ]
        self.assertListEqual(expected_output_sentence, batch_out_sentence)
        self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])

    @tooslow
    def test_sample(self):
        model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
        tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        # Greedy next-token ids from the original model created with mesh-tensorflow
        target = [
            ("武田信玄は", 35675),
            ("武田信玄は、", 45),
            ("武田信玄は、この", 29),
            ("武田信玄は、このよう", 30642),
            ("武田信玄は、このような", 35680),
            ("武田信玄は、このような「", 8640),
            ("武田信玄は、このような「武田", 31617),
            ("武田信玄は、このような「武田家", 30646),
            ("武田信玄は、このような「武田家の", 31617),
            ("武田信玄は、このような「武田家の家", 31381),
        ]
        for input_text, expected_token_id in target:
            input_ids = tokenizer.encode(input_text, return_tensors="pt")
            outputs = model(input_ids)
            output_logits = outputs.logits.detach().cpu().numpy()[0]
            output_id = np.argmax(output_logits[-1])
            self.assertEqual(output_id, expected_token_id)

    @slow
    def test_spout_generation(self):
        model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
        tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        model.to(torch_device)

        # Make generation deterministic by restricting sampling to the top candidate
        generation_config = GenerationConfig.from_pretrained("Tanrei/GPTSAN-japanese")
        generation_config.top_k = 1

        input_text = "武田信玄は、"
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(torch_device)
        input_ids_batch = tokenizer([input_text, input_text], return_tensors="pt").input_ids.to(torch_device)

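        # `spout` is GPTSAN's extra conditioning input: a float vector of length
        # `d_spout` that is projected into the model to steer generation. Two fixed
        # vectors keep this test reproducible.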
        # spout from uniform and one-hot
        spouts = [
            [
                0.87882208, 0.38426396, 0.33220248, 0.43890406, 0.16562252,
                0.04803985, 0.211572, 0.23188473, 0.37153068, 0.7836377,
                0.02160172, 0.38761719, 0.75290772, 0.90198857, 0.34365777,
                0.64168169, 0.44318471, 0.14575746, 0.92562881, 0.40812148,
                0.29019122, 0.88861599, 0.65524846, 0.43563456, 0.38177187,
                0.70832965, 0.81527892, 0.68832812, 0.38833192, 0.4561522,
                0.14828817, 0.47248213, 0.54357335, 0.82009566, 0.1338884,
                0.02755417, 0.19764677, 0.2422084, 0.04757674, 0.65409606,
                0.0824589, 0.03304383, 0.94387689, 0.98764509, 0.82433901,
                0.27646741, 0.64907493, 0.76009406, 0.30087915, 0.17904689,
                0.41601714, 0.67046398, 0.10422822, 0.08447374, 0.07354344,
                0.61423565, 0.70284866, 0.7532333, 0.1972038, 0.29575659,
                0.90583886, 0.29265307, 0.50000175, 0.70407655, 0.889363,
                0.81904418, 0.66829128, 0.64468815, 0.56563723, 0.85601875,
                0.94924672, 0.00166762, 0.25220643, 0.74540219, 0.67993247,
                0.1549675, 0.39385352, 0.92153607, 0.63745931, 0.27759043,
                0.84702295, 0.65904271, 0.58676614, 0.8666936, 0.39607438,
                0.79954983, 0.42220697, 0.39650381, 0.7849864, 0.56150201,
                0.15678925, 0.14746032, 0.34542114, 0.47026783, 0.11956489,
                0.25421435, 0.33788901, 0.68934842, 0.36424685, 0.71737898,
                0.38983449, 0.94393779, 0.39575588, 0.36616553, 0.87104665,
                0.64630203, 0.22516905, 0.88270804, 0.15031338, 0.75144345,
                0.46459025, 0.85396454, 0.86355643, 0.65139851, 0.70266061,
                0.30241389, 0.81056497, 0.88865969, 0.38773807, 0.70635849,
                0.90718459, 0.43245789, 0.28000654, 0.45935562, 0.08773519,
                0.9552151, 0.93901511, 0.22489288,
            ],  # uniform
            [1.0] + [0.0] * 127,  # one-hot
        ]  # fmt: skip

        output1 = model.generate(
            input_ids=input_ids,
            spout=spouts[0],
            max_new_tokens=20,
            generation_config=generation_config,
        )

        output2 = model.generate(
            input_ids=input_ids,
            spout=spouts[1],
            max_new_tokens=20,
            generation_config=generation_config,
        )

        output3 = model.generate(
            input_ids=input_ids_batch,
            spout=spouts,
            max_new_tokens=20,
            generation_config=generation_config,
        )

        out1_sentence = tokenizer.decode(output1[0])
        out2_sentence = tokenizer.decode(output2[0])
        batch_out_sentence = tokenizer.batch_decode(output3)

        expected_output_sentence = [
            "武田信玄は、武田氏の滅亡後、武田氏の居城であった甲斐武田氏の居城である",
            "武田信玄は、武田家の滅亡を防ぐため、武田家の家臣である武田信虎を討",
        ]
        self.assertListEqual(expected_output_sentence, batch_out_sentence)
        self.assertListEqual(batch_out_sentence, [out1_sentence, out2_sentence])

    @slow
    def test_prefix_lm_generation(self):
        model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
        tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        model.to(torch_device)

        # Make generation deterministic by restricting sampling to the top candidate
        generation_config = GenerationConfig.from_pretrained("Tanrei/GPTSAN-japanese")
        generation_config.top_k = 1

        prefix_text_1 = "武田信玄"
        prefix_text_2 = "織田信長"
        input_text_1 = "は、"
        input_text_2 = "が、"
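        # In prefix-LM mode the tokenizer marks the prefix part via `token_type_ids`;
        # prefix tokens are attended to bidirectionally while the rest stays causal.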
        input_tok_1 = tokenizer(input_text_1, prefix_text=prefix_text_1, return_tensors="pt")
        input_tok_2 = tokenizer(input_text_2, prefix_text=prefix_text_2, return_tensors="pt")
        input_tok_3 = tokenizer([[prefix_text_1, input_text_1], [prefix_text_2, input_text_2]], return_tensors="pt")

        output1 = model.generate(
            input_ids=input_tok_1.input_ids.to(torch_device),
            token_type_ids=input_tok_1.token_type_ids.to(torch_device),
            max_new_tokens=20,
            generation_config=generation_config,
        )

        output2 = model.generate(
            input_ids=input_tok_2.input_ids.to(torch_device),
            token_type_ids=input_tok_2.token_type_ids.to(torch_device),
            max_new_tokens=20,
            generation_config=generation_config,
        )

        output3 = model.generate(
            input_ids=input_tok_3.input_ids.to(torch_device),
            token_type_ids=input_tok_3.token_type_ids.to(torch_device),
            attention_mask=input_tok_3.attention_mask.to(torch_device),
            max_new_tokens=20,
            generation_config=generation_config,
        )

        out1_sentence = tokenizer.decode(output1[0])
        out2_sentence = tokenizer.decode(output2[0])
        batch_out_sentence = tokenizer.batch_decode(output3)

        expected_output_sentence = [
            "武田信玄は、武田氏の祖である武田信虎を、その子・武田信友を擁して",
            "織田信長が、織田信長の妻・お市の方を妻として迎えたという逸話が残",
        ]
        self.assertListEqual(expected_output_sentence, batch_out_sentence)
        self.assertListEqual(batch_out_sentence, [out1_sentence, out2_sentence])