# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch NLLB-MoE model. """


import copy
import tempfile
import unittest

from transformers import NllbMoeConfig, is_torch_available, set_seed
from transformers.testing_utils import (
    require_sentencepiece,
    require_tokenizers,
    require_torch,
    require_torch_fp16,
    slow,
    torch_device,
)
from transformers.utils import cached_property

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import NllbMoeForConditionalGeneration, NllbMoeModel, NllbTokenizer
    from transformers.models.nllb_moe.modeling_nllb_moe import NllbMoeDecoder, NllbMoeEncoder, NllbMoeTop2Router


class NllbMoeModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="relu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        max_position_embeddings=20,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
        num_experts=4,
        encoder_sparse_step=2,
        decoder_sparse_step=1,
        expert_capacity=100,
        router_jitter_noise=0.0,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.encoder_sparse_step = encoder_sparse_step
        self.decoder_sparse_step = decoder_sparse_step
        self.expert_capacity = expert_capacity
        self.router_jitter_noise = router_jitter_noise
        self.num_experts = num_experts

    def prepare_nllb_moe_inputs_dict(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
    ):
        if attention_mask is None:
            attention_mask = input_ids.ne(config.pad_token_id)
        if decoder_attention_mask is None:
            decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
        if head_mask is None:
            head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
        if decoder_head_mask is None:
            decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
        if cross_attn_head_mask is None:
            cross_attn_head_mask = torch.ones(
                config.decoder_layers, config.decoder_attention_heads, device=torch_device
            )
        return {
            "input_ids": input_ids,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
"decoder_attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
        }

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_ids[:, -1] = self.eos_token_id  # Eos Token
        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        # we need to clamp the input ids here to avoid having pad tokens in between
        # this is because for NllbMoe the position_ids are prepared such that
        # all pad tokens have pos id = 2 and the rest are between 2..seq_length
        # and the seq_length here is seq_length - num_pad_tokens
        # but when using past, there is no way of knowing if the past input ids had
        # pad tokens in them, which results in an incorrect seq_length, which in turn results in
        # position_ids being off by num_pad_tokens in the past input
        input_ids = input_ids.clamp(self.pad_token_id + 1)
        decoder_input_ids = decoder_input_ids.clamp(self.pad_token_id + 1)

        config = self.get_config()
        inputs_dict = self.prepare_nllb_moe_inputs_dict(config, input_ids, decoder_input_ids)
        return config, inputs_dict

    def get_config(self):
        return NllbMoeConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            encoder_layerdrop=self.encoder_layerdrop,
            decoder_layerdrop=self.decoder_layerdrop,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            expert_capacity=self.expert_capacity,
            router_jitter_noise=self.router_jitter_noise,
            decoder_sparse_step=self.decoder_sparse_step,
            encoder_sparse_step=self.encoder_sparse_step,
            num_experts=self.num_experts,
        )

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    @require_torch
    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        model = NllbMoeModel(config=config).get_decoder().to(torch_device).eval()
        input_ids = inputs_dict["input_ids"]
        attention_mask = inputs_dict["attention_mask"]
        head_mask = inputs_dict["head_mask"]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next tokens and extend next_input_ids with them
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append to next input_ids and attention_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
        model = NllbMoeModel(config=config).to(torch_device).eval()
        outputs = model(**inputs_dict)

        encoder_last_hidden_state = outputs.encoder_last_hidden_state
        last_hidden_state = outputs.last_hidden_state

        with tempfile.TemporaryDirectory() as tmpdirname:
            encoder = model.get_encoder()
            encoder.save_pretrained(tmpdirname)
            encoder = NllbMoeEncoder.from_pretrained(tmpdirname).to(torch_device)

        encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
            0
        ]

        self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)

        with tempfile.TemporaryDirectory() as tmpdirname:
            decoder = model.get_decoder()
            decoder.save_pretrained(tmpdirname)
            decoder = NllbMoeDecoder.from_pretrained(tmpdirname).to(torch_device)

        last_hidden_state_2 = decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=encoder_last_hidden_state,
            encoder_attention_mask=inputs_dict["attention_mask"],
        )[0]

        self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)


@require_torch
class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (NllbMoeModel, NllbMoeForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (NllbMoeForConditionalGeneration,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "conversational": NllbMoeForConditionalGeneration,
            "feature-extraction": NllbMoeModel,
            "summarization": NllbMoeForConditionalGeneration,
            "text2text-generation": NllbMoeForConditionalGeneration,
            "translation": NllbMoeForConditionalGeneration,
        }
        if is_torch_available()
        else {}
    )
    is_encoder_decoder = True
    fx_compatible = False
    test_pruning = False
    test_missing_keys = True
    test_torchscript = False

    # TODO: Fix the failed tests when this model gets more usage
    def is_pipeline_test_to_skip(
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
    ):
        # Saving the slow tokenizer after saving the fast tokenizer causes the loading of the latter to hang forever.
        return True

    def setUp(self):
        self.model_tester = NllbMoeModelTester(self)
        self.config_tester = ConfigTester(self, config_class=NllbMoeConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_decoder_model_past_with_large_inputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        config.decoder_sparse_step = 0
        self.model_tester.create_and_check_decoder_model_past_large_inputs(config, inputs_dict)

    def test_encoder_decoder_model_standalone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)

    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in (NllbMoeModel, NllbMoeForConditionalGeneration):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            if not self.is_encoder_decoder:
                input_ids = inputs["input_ids"]
                del inputs["input_ids"]
            else:
                encoder_input_ids = inputs["input_ids"]
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                inputs["inputs_embeds"] = wte(input_ids)
            else:
                inputs["inputs_embeds"] = wte(encoder_input_ids)
                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)

            with torch.no_grad():
                model(**inputs)[0]

    @require_torch_fp16
    def test_generate_fp16(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        model = NllbMoeForConditionalGeneration(config).eval().to(torch_device)
        model.half()
        model.generate(input_ids, attention_mask=attention_mask)
        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

    def test_get_loss(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        input_dict["output_router_logits"] = True
        input_dict["labels"] = input_dict["input_ids"]
        model = NllbMoeForConditionalGeneration(config).eval().to(torch_device)
        out = model(**input_dict)
        self.assertIsNotNone(out.loss)
        self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
        self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])


@require_torch
@require_sentencepiece
@require_tokenizers
@slow
class NllbMoeModelIntegrationTests(unittest.TestCase):
    @require_torch
    @cached_property
    def model_inputs(self):
        return {
            "input_ids": torch.LongTensor(
                [
                    [28768, 248, 6399, 9, 65972, 452, 1925, 629, 123543, 248075, 2, 256047],
                    [117, 7027, 7195, 202, 44778, 248075, 2, 256047, 1, 1, 1, 1],
                ]
            ),
            "attention_mask": torch.Tensor(
                [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]
            ),
            "decoder_input_ids": torch.LongTensor([[2, 256057], [2, 256057]]),
        }

    @cached_property
    def tokenizer(self):
        return NllbTokenizer.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts")

    @cached_property
    def big_model(self):
        return NllbMoeForConditionalGeneration.from_pretrained("facebook/nllb-moe-54b")

    def inference_no_head(self):
        model = NllbMoeModel.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts").eval()
        with torch.no_grad():
            output = model(**self.model_inputs)
        # fmt: off
EXPECTED_ENCODER_STATE = torch.Tensor([ 0.3920, -0.1974, -0.0279, 0.3463, -0.8306, -1.0629, -0.4643, 2.0563, 1.1123, 0.3566, -0.9291, -0.3840, -0.2527, -0.9858, 1.5185, -1.1346, 0.0323, -0.9103, -0.3647, -0.4462, -0.9720, -0.3541, 0.1777, -0.4647, 1.6970, -0.9062, 0.2727, -1.0737, 0.8785, 0.4324])
EXPECTED_DECODER_STATE = torch.Tensor([-6.0425e-02, -2.0015e-01, 6.0575e-02, -8.6366e-01, -1.1310e+00, 6.8369e-01, 7.5615e-01, 7.3555e-01, 2.3071e-01, 1.5954e+00, -7.0728e-01, -2.2647e-01, -1.3292e+00, 4.8246e-01, -6.9153e-01, -1.8199e-02, -7.3664e-01, 1.5902e-03, 1.0760e-01, 1.0298e-01, -9.3933e-01, -4.6567e-01, 8.0417e-01, 1.5243e+00, 5.5844e-01, -9.9239e-02, 1.4885e+00, 7.1527e-02, -5.2612e-01, 9.4435e-02])
        # fmt: on

        torch.testing.assert_close(
            output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3
        )
        torch.testing.assert_close(output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3)

    def test_inference_logits(self):
        r"""
        Logits testing to check implementation consistency between the `fairseq` implementation
        and the `transformers` implementation of the NLLB-MoE model. We only check the logits
        of the second sample of the batch, as it is padded.
        """
        model = NllbMoeForConditionalGeneration.from_pretrained("hf-internal-testing/random-nllb-moe-2-experts").eval()
        with torch.no_grad():
            output = model(**self.model_inputs)

EXPECTED_LOGTIS = torch.Tensor([-0.3059, 0.0000, 9.3029, 0.6456, -0.9148, 1.7836, 0.6478, 0.9438, -0.5272, -0.6617, -1.2717, 0.4564, 0.1345, -0.2301, -1.0140, 1.1427, -1.5535, 0.1337, 0.2082, -0.8112, -0.3842, -0.3377, 0.1256, 0.6450, -0.0452, 0.0219, 1.4274, -0.4991, -0.2063, -0.4409,]) # fmt: skip
        torch.testing.assert_close(output.logits[1, 0, :30], EXPECTED_LOGTIS, rtol=6e-3, atol=9e-3)

@unittest.skip("This requires 300GB of RAM")
|
|
def test_large_logits(self):
|
|
model = self.big_model
|
|
with torch.no_grad():
|
|
output = model(**self.model_inputs)
|
|
|
|
# fmt: off
|
|
EXPECTED_ENCODER_STATE = torch.Tensor([ 0.1696, -0.0059, 0.0489, 0.0479, -0.4222, -0.2178, -0.1372, -0.0860, -0.4249, -0.0081, -0.1186, 0.6678, 0.0160, 0.4140, 0.1799, 0.0672, -0.4941, 0.0173, -0.0740, 0.0845, -0.2197, 0.4465, 0.2268, -0.1752, -0.0562, 0.1033, -0.0869, -0.5490, 0.0582, 0.2165])
EXPECTED_DECODER_STATE = torch.Tensor([ 0.0374, -0.1055, -0.1060, -0.1711, -0.0540, -0.1183, -0.0779, 0.0610, -0.0279, -0.0848, 0.0222, 0.0372, -0.0298, -0.0861, -0.0354, -0.0103, 0.0538, -0.0148, -0.0105, 0.0224, 0.0629, -0.0291, -0.0671, 0.0173, -0.0066, -0.0245, -0.0499, 0.0760, -0.0067, 0.0086])
EXPECTED_LOGTIS = torch.Tensor([ 0.3834, 0.2057, 4.5399, 0.8301, 0.4810, 0.9325, 0.9928, 0.9574, 0.5517, 0.9156, 0.2698, 0.6728, 0.7121, 0.3080, 0.4693, 0.5756, 1.0407, 0.2219, 0.3714, 0.5699, 0.5547, 0.8472, 0.3178, 0.1286, 0.1791, 0.9391, 0.5153, -0.2146, 0.1689, 0.6816])
        # fmt: on

        torch.testing.assert_close(
            output.encoder_last_hidden_state[1, 0, :30], EXPECTED_ENCODER_STATE, rtol=6e-3, atol=9e-3
        )
        torch.testing.assert_close(output.last_hidden_state[1, 0, :30], EXPECTED_DECODER_STATE, rtol=6e-3, atol=9e-3)
        torch.testing.assert_close(output.logits[1, 0, :30], EXPECTED_LOGTIS, rtol=6e-3, atol=9e-3)

@unittest.skip("This requires 300GB of RAM")
|
|
def test_seq_to_seq_generation(self):
|
|
model = self.big_model
|
|
tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-moe-54b")
|
|
|
|
# first 6 samples of load_dataset("facebook/flores", "eng_Latn-fra_Latn"), devtest. Truth are very similar to the fairseq translation files
|
|
FIRST_6_FLORES_200 = [
|
|
'We now have 4-month-old mice that are non-diabetic that used to be diabetic," he added.',
|
|
"Dr. Ehud Ur, professor of medicine at Dalhousie University in Halifax, Nova Scotia and chair of the clinical and scientific division of the Canadian Diabetes Association cautioned that the research is still in its early days.",
|
|
"Like some other experts, he is skeptical about whether diabetes can be cured, noting that these findings have no relevance to people who already have Type 1 diabetes.",
|
|
"On Monday, Sara Danius, permanent secretary of the Nobel Committee for Literature at the Swedish Academy, publicly announced during a radio program on Sveriges Radio in Sweden the committee, unable to reach Bob Dylan directly about winning the 2016 Nobel Prize in Literature, had abandoned its efforts to reach him.",
|
|
'Danius said, "Right now we are doing nothing. I have called and sent emails to his closest collaborator and received very friendly replies. For now, that is certainly enough."',
|
|
"Previously, Ring's CEO, Jamie Siminoff, remarked the company started when his doorbell wasn't audible from his shop in his garage.",
|
|
]
|
|
inputs = tokenizer(FIRST_6_FLORES_200, padding=True, return_tensors="pt").to(torch_device)
|
|
batch_translation = model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id["fra_Latn"])
|
|
|
|
EXPECTED_FAIRSEQ_TRANSLATION = [
|
|
'"Nous avons maintenant des souris de 4 mois non diabétiques qui étaient diabétiques", a-t-il ajouté.',
|
|
"Le docteur Ehud Ur, professeur de médecine à l'université Dalhousie, à Halifax, en Nouvelle-Écosse, et président de la division clinique et scientifique de l'Association canadienne du diabète, prévient que la recherche n'en est qu'à ses débuts.",
|
|
"Comme d'autres spécialistes, il est sceptique quant à la guérison du diabète.",
|
|
"Lundi, Sara Danius, secrétaire permanente du Comité Nobel de littérature à l'Académie suédoise, a annoncé publiquement lors d'une émission de radio sur Sveriges Radio en Suède que le comité, incapable de joindre Bob Dylan directement pour lui annoncer le prix Nobel de littérature 2016, avait abandonné ses efforts pour le joindre.",
|
|
"Danius a déclaré: \"Pour l'instant, nous ne faisons rien. J'ai appelé et envoyé des courriels à son plus proche collaborateur et j'ai reçu des réponses très amicales. Pour l'instant, c'est certainement suffisant\".",
|
|
"Auparavant, le PDG de Ring, Jamie Siminoff, a fait remarquer que la société avait commencé lorsque sa sonnette n'était pas audible depuis son magasin dans son garage.",
|
|
]
|
|
|
|
translation = tokenizer.batch_decode(
|
|
batch_translation.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
|
|
)
|
|
assert translation == EXPECTED_FAIRSEQ_TRANSLATION
|
|
|
|
|
|
@require_torch
class NllbMoeRouterTest(unittest.TestCase):
    r"""
    Switch Transformers has different blocks from classic transformer based models.
    The Switch MLP contains a Router class that has to be tested to check that it is correctly implemented.

    Original implementation of the routers here:

    """

    config = NllbMoeConfig(
        num_experts=4,
        hidden_size=32,
        d_ff=16,
        expert_capacity=4,
    )
    batch_size = 2
    sequence_length = 20

    def test_top_2_routing(self):
        # test routing with minimal reproduction
        mask = torch.ones((self.batch_size, self.sequence_length), dtype=torch.bool)
        mask[0][0] = False
        mask[1][0] = False
        mask = mask.reshape(-1)
        set_seed(0)
        hidden_states = torch.rand((self.batch_size, self.sequence_length, self.config.hidden_size))
        classifier = torch.nn.Linear(self.config.hidden_size, self.config.num_experts)
        hf_router = NllbMoeTop2Router(self.config)

        _, _, hidden_dim = hidden_states.shape
        logits = classifier(hidden_states.reshape((self.batch_size * self.sequence_length), hidden_dim))
        top_1_mask, router_probs = hf_router.route_tokens(logits, padding_mask=mask)
        torch.argmax(top_1_mask, dim=-1)
        router_mask = router_probs.bool()
        set_seed(0)
        experts = [
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.Linear(hidden_dim, hidden_dim),
        ]
        hidden_states = hidden_states.reshape((self.batch_size * self.sequence_length), hidden_dim)
        masked_hidden_states = torch.einsum("bm,be->ebm", hidden_states, router_mask)
        for idx, expert in enumerate(experts):
            token_indices = router_mask[:, idx]
            combining_weights = router_probs[token_indices, idx]
            expert_output = expert(masked_hidden_states[idx, token_indices])
            expert_output *= 1 - self.config.moe_token_dropout
            masked_hidden_states[idx, token_indices] = torch.einsum("b,be->be", combining_weights, expert_output)
        hidden_states = masked_hidden_states.sum(dim=0).reshape(self.batch_size, self.sequence_length, hidden_dim)

EXPECTED_MEAN_FAIRSEQ_HIDDEN_STATES = torch.Tensor([[ 7.0340e-04, 2.7997e-03, -1.3351e-02, -7.6705e-03, -3.5089e-03,3.9773e-03, 7.4593e-03, 1.2566e-02, 3.5860e-03, -2.7448e-02,-1.3731e-02, -1.0534e-02, -1.3606e-02, -1.5048e-02, -2.8914e-03,-5.0371e-03, -1.3963e-03, 6.0076e-03, -1.1380e-02, -1.4620e-02, 5.2401e-03, 8.4660e-04, -1.5319e-03, -1.6735e-02, 1.1302e-02, 3.6119e-03, 4.6084e-03, -1.3458e-02, 7.7792e-05, 1.4312e-02, 4.9107e-03, -5.0936e-03], [-4.4538e-03, 3.1026e-03, 1.4121e-04, -4.8121e-03, -5.6279e-03, 7.2493e-03, 3.9769e-03, 1.1114e-02, -1.5666e-03, -2.3477e-02, 8.7268e-03, 1.3446e-02, -2.8845e-05, -1.7287e-02, 8.7619e-03, -4.5316e-03, -1.2164e-02, 5.7461e-03, -4.5861e-03, -9.3907e-03, 2.9808e-02, 8.9206e-04, -7.6232e-04, -1.4173e-02, 3.0208e-03, 1.5310e-02, 9.7717e-03, 3.1014e-03, 7.8042e-03, 8.0197e-03, 3.4784e-03, -7.1728e-03]]) # fmt: skip
        self.assertTrue(torch.allclose(hidden_states.mean(1), EXPECTED_MEAN_FAIRSEQ_HIDDEN_STATES, 1e-4))

    def test_batch_prioritized_routing(self):
        set_seed(0)
        config = NllbMoeConfig(
            num_experts=4, hidden_size=32, d_ff=16, expert_capacity=4, second_expert_policy="random"
        )
        mask = torch.zeros((self.batch_size * self.sequence_length), dtype=torch.bool)
        logits = torch.rand((self.batch_size * self.sequence_length, 4))
        config.batch_prioritized_routing = True
        router = NllbMoeTop2Router(config)
        top_1_mask, _ = router.route_tokens(logits, padding_mask=mask)
        # check that the routing is batch first. One of the last tokens is routed even though the expert
        # capacity is very small: this means that it had a greater probability of being routed
        assert top_1_mask[-1, 0] == 1

    def test_second_expert_policy(self):
        config = NllbMoeConfig(
            num_experts=4,
            hidden_size=32,
            d_ff=16,
            expert_capacity=40,
        )
        set_seed(0)
        mask = torch.zeros((self.batch_size * self.sequence_length), dtype=torch.bool)
        logits = torch.rand((self.batch_size * self.sequence_length, 4))

        set_seed(0)
        config.second_expert_policy = "random"
        router = NllbMoeTop2Router(config)
        top_1_mask, router_probs = router.route_tokens(logits, padding_mask=mask)

        set_seed(0)
        config.second_expert_policy = "sampling"
        router = NllbMoeTop2Router(config)
        top_1_mask_sp, router_probs_sp = router.route_tokens(logits, padding_mask=mask)

        set_seed(0)
        config.second_expert_policy = "all"
        router = NllbMoeTop2Router(config)
        top_1_mask_all, router_probs_all = router.route_tokens(logits, padding_mask=mask)

        # fmt: off
EXPECTED_ROUTER_ALL = torch.tensor([[0.3902, 0.0000, 0.0000, 0.6098], [0.0000, 0.0000, 0.7770, 0.2230], [0.0000, 0.0000, 0.2726, 0.7274], [0.4221, 0.0000, 0.5779, 0.0000], [0.0000, 0.0000, 0.7810, 0.2190], [0.5518, 0.4482, 0.0000, 0.0000], [0.0000, 0.4060, 0.5940, 0.0000], [0.7340, 0.0000, 0.0000, 0.2660], [0.4778, 0.5222, 0.0000, 0.0000], [0.0000, 0.3984, 0.0000, 0.6016], [0.0000, 0.0548, 0.9452, 0.0000], [0.6796, 0.0000, 0.0000, 0.3204], [0.0700, 0.0000, 0.9300, 0.0000], [0.1854, 0.0000, 0.8146, 0.0000], [0.6775, 0.3225, 0.0000, 0.0000], [0.0000, 0.0000, 0.5027, 0.4973], [0.0000, 0.6577, 0.0000, 0.3423], [0.0000, 0.7767, 0.0000, 0.2233], [0.1944, 0.8056, 0.0000, 0.0000], [0.0000, 0.3073, 0.0000, 0.6927], [0.0000, 0.5655, 0.4345, 0.0000], [0.5791, 0.0000, 0.0000, 0.4209], [0.0440, 0.0000, 0.9560, 0.0000], [0.0083, 0.9917, 0.0000, 0.0000], [0.0000, 0.8395, 0.0000, 0.1605], [0.0000, 0.1458, 0.0000, 0.8542], [0.0000, 0.8534, 0.1466, 0.0000], [0.4938, 0.0000, 0.0000, 0.5062], [0.1329, 0.8671, 0.0000, 0.0000], [0.3058, 0.0000, 0.6942, 0.0000], [0.4458, 0.0000, 0.0000, 0.5542], [0.9053, 0.0947, 0.0000, 0.0000], [0.0000, 0.7563, 0.2437, 0.0000], [0.0000, 0.0000, 0.4096, 0.5904], [0.4551, 0.0000, 0.0000, 0.5449], [0.8502, 0.1498, 0.0000, 0.0000], [0.0000, 0.6312, 0.3688, 0.0000], [0.8920, 0.0000, 0.0000, 0.1080], [0.1913, 0.0000, 0.0000, 0.8087], [0.2491, 0.7509, 0.0000, 0.0000]])
EXPECTED_ROUTER_SP = torch.tensor([[0.0000, 0.6539, 0.0000, 0.3461], [0.0000, 0.0000, 0.3998, 0.6002], [0.0000, 0.5574, 0.0000, 0.4426], [0.0000, 0.0000, 0.4441, 0.5559], [0.0000, 0.6545, 0.3455, 0.0000], [0.4419, 0.5581, 0.0000, 0.0000], [0.0000, 0.4014, 0.5986, 0.0000], [0.3215, 0.0000, 0.0000, 0.6785], [0.4765, 0.5235, 0.0000, 0.0000], [0.0000, 0.5467, 0.0000, 0.4533], [0.0000, 0.4156, 0.5844, 0.0000], [0.3370, 0.0000, 0.6630, 0.0000], [0.0000, 0.0000, 0.4558, 0.5442], [0.4659, 0.0000, 0.5341, 0.0000], [0.6179, 0.3821, 0.0000, 0.0000], [0.6277, 0.0000, 0.3723, 0.0000], [0.5836, 0.4164, 0.0000, 0.0000], [0.0000, 0.6600, 0.0000, 0.3400], [0.0000, 0.4933, 0.0000, 0.5067], [0.6016, 0.0000, 0.0000, 0.3984], [0.0000, 0.5160, 0.4840, 0.0000], [0.5799, 0.0000, 0.0000, 0.4201], [0.0000, 0.0000, 0.4826, 0.5174], [0.5426, 0.4574, 0.0000, 0.0000], [0.5362, 0.4638, 0.0000, 0.0000], [0.6448, 0.0000, 0.0000, 0.3552], [0.0000, 0.5909, 0.4091, 0.0000], [0.4196, 0.0000, 0.0000, 0.5804], [0.3191, 0.6809, 0.0000, 0.0000], [0.0000, 0.0000, 0.4886, 0.5114], [0.4899, 0.0000, 0.0000, 0.5101], [0.4123, 0.0000, 0.5877, 0.0000], [0.0000, 0.3736, 0.0000, 0.6264], [0.0000, 0.0000, 0.6009, 0.3991], [0.4246, 0.0000, 0.0000, 0.5754], [0.4997, 0.0000, 0.5003, 0.0000], [0.0000, 0.3595, 0.6405, 0.0000], [0.5433, 0.0000, 0.0000, 0.4567], [0.0000, 0.6806, 0.0000, 0.3194], [0.6689, 0.3311, 0.0000, 0.0000]])
EXPECTED_ROUTER = torch.tensor([[0.4324, 0.5676, 0.0000, 0.0000], [0.0000, 0.4348, 0.0000, 0.5652], [0.4559, 0.5441, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.4744, 0.5256, 0.0000, 0.0000], [0.0000, 0.5103, 0.0000, 0.4897], [0.0000, 0.0000, 1.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.0000, 1.0000, 0.0000, 0.0000], [0.0000, 0.5467, 0.0000, 0.4533], [0.0000, 0.0000, 1.0000, 0.0000], [0.0000, 0.0000, 1.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.0000, 0.0000, 1.0000, 0.0000], [1.0000, 0.0000, 0.0000, 0.0000], [0.5063, 0.4937, 0.0000, 0.0000], [0.5396, 0.0000, 0.0000, 0.4604], [0.4576, 0.5424, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 1.0000], [0.5134, 0.0000, 0.4866, 0.0000], [0.0000, 0.5160, 0.4840, 0.0000], [0.5439, 0.0000, 0.4561, 0.0000], [0.4849, 0.0000, 0.0000, 0.5151], [0.5426, 0.4574, 0.0000, 0.0000], [0.5362, 0.4638, 0.0000, 0.0000], [1.0000, 0.0000, 0.0000, 0.0000], [0.0000, 1.0000, 0.0000, 0.0000], [0.0000, 0.4448, 0.0000, 0.5552], [0.0000, 1.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.4886, 0.5114], [0.4899, 0.0000, 0.0000, 0.5101], [0.0000, 0.0000, 0.5296, 0.4704], [0.0000, 0.0000, 0.4469, 0.5531], [0.0000, 0.4053, 0.5947, 0.0000], [0.0000, 0.0000, 0.4460, 0.5540], [0.4997, 0.0000, 0.5003, 0.0000], [0.0000, 0.0000, 0.5851, 0.4149], [1.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.5010, 0.4990, 0.0000], [1.0000, 0.0000, 0.0000, 0.0000]])
EXPECTED_TOP_1_ALL = torch.LongTensor([[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0]])
EXPECTED_TOP_1_SP = torch.LongTensor([[0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [1, 0, 0, 0]])
        # `sampling` and `random` do not affect the mask of the top_1 router
        # fmt: on

        torch.testing.assert_close(router_probs_all, EXPECTED_ROUTER_ALL, rtol=1e-4, atol=1e-4)
        torch.testing.assert_close(router_probs_sp, EXPECTED_ROUTER_SP, rtol=1e-4, atol=1e-4)
        torch.testing.assert_close(router_probs, EXPECTED_ROUTER, rtol=1e-4, atol=1e-4)

        torch.testing.assert_close(top_1_mask_all, EXPECTED_TOP_1_ALL, rtol=1e-4, atol=1e-4)
        torch.testing.assert_close(top_1_mask_sp, EXPECTED_TOP_1_SP, rtol=1e-4, atol=1e-4)
        torch.testing.assert_close(top_1_mask, EXPECTED_TOP_1_SP, rtol=1e-4, atol=1e-4)