transformers/tests/models/xlm_prophetnet/test_modeling_xlm_prophetne...

151 lines
7.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team, The Microsoft Research team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
if is_torch_available():
import torch
from transformers import XLMProphetNetForConditionalGeneration, XLMProphetNetTokenizer
@require_torch
class XLMProphetNetModelIntegrationTest(unittest.TestCase):
@slow
def test_pretrained_checkpoint_hidden_states(self):
model = XLMProphetNetForConditionalGeneration.from_pretrained("microsoft/xprophetnet-large-wiki100-cased")
model.to(torch_device)
# encoder-decoder outputs
encoder_ids = torch.tensor([[17, 96208, 103471, 2]]).to(torch_device)
decoder_prev_ids = torch.tensor(
[[2, 250, 9953, 34, 69489, 1620, 32, 118424, 624, 210, 105, 2913, 1032, 351]]
).to(torch_device)
output = model(
input_ids=encoder_ids, attention_mask=None, encoder_outputs=None, decoder_input_ids=decoder_prev_ids
)
output_predited_logis = output[0]
expected_shape = torch.Size((1, 14, 250012))
self.assertEqual(output_predited_logis.shape, expected_shape)
expected_slice = torch.tensor(
[[[-6.3986, -8.2391, 12.5189], [-6.3289, -8.0864, 12.6211], [-6.2418, -8.0445, 12.7968]]]
).to(torch_device)
self.assertTrue(torch.allclose(output_predited_logis[:, :3, :3], expected_slice, atol=1e-4))
# encoder outputs
encoder_outputs = model.prophetnet.encoder(encoder_ids)[0]
expected_encoder_outputs_slice = torch.tensor(
[[[-1.4260, -0.7628, 0.8453], [-1.4719, -0.1391, 0.7807], [-1.7678, 0.0114, 0.4646]]]
).to(torch_device)
expected_shape_encoder = torch.Size((1, 4, 1024))
self.assertEqual(encoder_outputs.shape, expected_shape_encoder)
self.assertTrue(torch.allclose(encoder_outputs[:, :3, :3], expected_encoder_outputs_slice, atol=1e-4))
# decoder outputs
decoder_outputs = model.prophetnet.decoder(
decoder_prev_ids,
encoder_hidden_states=encoder_outputs,
)
predicting_streams = decoder_outputs[1].view(1, model.config.ngram, 14, -1)
predicting_streams_logits = model.lm_head(predicting_streams)
next_first_stream_logits = predicting_streams_logits[:, 0]
self.assertTrue(torch.allclose(next_first_stream_logits[:, :3, :3], expected_slice, atol=1e-4))
@slow
def test_ntg_hidden_states(self):
model = XLMProphetNetForConditionalGeneration.from_pretrained(
"microsoft/xprophetnet-large-wiki100-cased-xglue-ntg"
)
model.to(torch_device)
encoder_ids = torch.tensor([[17, 96208, 103471, 2]]).to(torch_device)
decoder_prev_ids = torch.tensor(
[[2, 250, 9953, 34, 69489, 1620, 32, 118424, 624, 210, 105, 2913, 1032, 351]]
).to(torch_device)
output = model(
input_ids=encoder_ids, attention_mask=None, encoder_outputs=None, decoder_input_ids=decoder_prev_ids
)
output_predited_logis = output[0]
expected_shape = torch.Size((1, 14, 250012))
self.assertEqual(output_predited_logis.shape, expected_shape)
# compare the actual values for a slice.
expected_slice = torch.tensor(
[[[-9.2253, -9.7173, -6.3529], [-7.6701, -9.0145, -1.9382], [-8.0195, -7.0004, -0.1523]]]
).to(torch_device)
self.assertTrue(torch.allclose(output_predited_logis[:, :3, :3], expected_slice, atol=1e-4))
@slow
def test_xprophetnet_ntg_inference(self):
model = XLMProphetNetForConditionalGeneration.from_pretrained(
"microsoft/xprophetnet-large-wiki100-cased-xglue-ntg"
)
model.to(torch_device)
model.config.max_length = 512
tokenizer = XLMProphetNetTokenizer.from_pretrained("microsoft/xprophetnet-large-wiki100-cased-xglue-ntg")
EN_SENTENCE = (
"Microsoft Corporation intends to officially end free support for the Windows 7 operating system after"
" January 14, 2020, according to the official portal of the organization. From that day, users of this"
" system will not be able to receive security updates, which could make their computers vulnerable to"
" cyber attacks."
)
RU_SENTENCE = (
"орпорация Microsoft намерена официально прекратить бесплатную поддержку операционной системы Windows 7"
" после 14 января 2020 года, сообщается на официальном портале организации . С указанного дня пользователи"
" этой системы не смогут получать обновления безопасности, из-за чего их компьютеры могут стать уязвимыми"
" к кибератакам."
)
ZH_SENTENCE = "根据该组织的官方门户网站微软公司打算在2020年1月14日之后正式终止对Windows 7操作系统的免费支持。从那时起该系统的用户将无法接收安全更新这可能会使他们的计算机容易受到网络攻击。"
input_ids = tokenizer(
[EN_SENTENCE, RU_SENTENCE, ZH_SENTENCE], padding=True, max_length=255, return_tensors="pt"
).input_ids
input_ids = input_ids.to(torch_device)
summary_ids = model.generate(
input_ids, num_beams=10, length_penalty=1.0, no_repeat_ngram_size=3, early_stopping=True
)
generated_titles = [tokenizer.decode(g, skip_special_tokens=True) for g in summary_ids]
EXPECTED_TITLE_EN = "Microsoft to end Windows 7 free support after January 14, 2020"
EXPECTED_TITLE_RU = "Microsoft намерена прекратить бесплатную поддержку Windows 7 после 14 января 2020 года"
EXPECTED_TITLE_ZH = "微软打算终止对Windows 7操作系统的免费支持"
self.assertListEqual(
[EXPECTED_TITLE_EN, EXPECTED_TITLE_RU, EXPECTED_TITLE_ZH],
generated_titles,
)
summary_ids_beam1 = model.generate(
input_ids, num_beams=1, length_penalty=1.0, no_repeat_ngram_size=3, early_stopping=True
)
generated_titles_beam1_tok = [
tokenizer.convert_ids_to_tokens(g, skip_special_tokens=True) for g in summary_ids_beam1
]
EXPECTED_TITLE_EN_BEAM1_TOK = "▁Microsoft ▁to ▁end ▁free ▁support ▁for ▁Windows ▁7".split(" ")
EXPECTED_TITLE_RU_BEAM1_TOK = "▁Microsoft ▁намерен а ▁прекрати ть ▁бес плат ную ▁поддержку ▁Windows ▁7 ▁после ▁14 ▁января ▁2020 ▁года".split(
" "
)
EXPECTED_TITLE_ZH_BEAM1_TOK = "微软 公司 打算 终止 对 Windows ▁7 操作 系统的 免费 支持".split(" ")
self.assertListEqual(
[EXPECTED_TITLE_EN_BEAM1_TOK, EXPECTED_TITLE_RU_BEAM1_TOK, EXPECTED_TITLE_ZH_BEAM1_TOK],
generated_titles_beam1_tok,
)