* implement MLukeTokenizer and LukeForMaskedLM

* update tests

* update docs

* add LukeForMaskedLM to check_repo.py

* update README

* fix test and specify the entity pad id in tokenization_(m)luke

* fix EntityPredictionHeadTransform
Ryokan RI 2021-12-07 14:25:28 +09:00 committed by GitHub
parent 4cdb67caba
commit 30646a0a3c
24 changed files with 3107 additions and 234 deletions


@ -275,6 +275,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.


@ -261,6 +261,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron_bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.


@ -285,6 +285,7 @@ conda install -c huggingface transformers
1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (来自 Facebook) 伴随论文 [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) 由 Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan 发布。
1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron_bert)** (来自 NVIDIA) 伴随论文 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 由 Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 发布。
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (来自 NVIDIA) 伴随论文 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 由 Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 发布。
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (来自 Studio Ousia) 伴随论文 [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) 由 Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka 发布。
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (来自 Microsoft Research) 伴随论文 [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 由 Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu 发布。
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (来自 Google AI) 伴随论文 [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 由 Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel 发布。
1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (来自 Google) 伴随论文 [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) 由 Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu 发布。


@ -297,6 +297,7 @@ conda install -c huggingface transformers
1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron_bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.


@ -135,6 +135,7 @@ conversion utilities for the following models.
1. **[LED](model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[Longformer](model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
1. **[LUKE](model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
1. **[mLUKE](model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
1. **[LXMERT](model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M2M100](model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MarianMT](model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.


@ -20,7 +20,7 @@ Rust library `tokenizers <https://github.com/huggingface/tokenizers>`__. The "Fa
1. a significant speed-up in particular when doing batched tokenization and
2. additional methods to map between the original string (character and words) and the token space (e.g. getting the
index of the token comprising a given character or the span of characters corresponding to a given token). Currently
no "Fast" implementation is available for the SentencePiece-based tokenizers (for T5, ALBERT, CamemBERT, XLMRoBERTa
no "Fast" implementation is available for the SentencePiece-based tokenizers (for T5, ALBERT, CamemBERT, XLM-RoBERTa
and XLNet models).
The base classes :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast`


@ -137,6 +137,12 @@ LukeModel
.. autoclass:: transformers.LukeModel
:members: forward
LukeForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LukeForMaskedLM
:members: forward
LukeForEntityClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@ -0,0 +1,66 @@
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
mLUKE
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The mLUKE model was proposed in `mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models
<https://arxiv.org/abs/2110.08151>`__ by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka. It's a multilingual extension
of the `LUKE model <https://arxiv.org/abs/2010.01057>`__ trained on the basis of XLM-RoBERTa.

It adds entity embeddings to the XLM-RoBERTa architecture, which helps improve performance on various downstream
tasks involving reasoning about entities, such as named entity recognition, extractive question answering, relation
classification, and cloze-style knowledge completion.
The abstract from the paper is the following:
*Recent studies have shown that multilingual pretrained language models can be effectively improved with cross-lingual
alignment information from Wikipedia entities. However, existing methods only exploit entity information in pretraining
and do not explicitly use entities in downstream tasks. In this study, we explore the effectiveness of leveraging
entity representations for downstream cross-lingual tasks. We train a multilingual language model with 24 languages
with entity representations and show the model consistently outperforms word-based pretrained models in various
cross-lingual transfer tasks. We also analyze the model and the key insight is that incorporating entity
representations into the input allows us to extract more language-agnostic features. We also evaluate the model with a
multilingual cloze prompt task with the mLAMA dataset. We show that entity-based prompt elicits correct factual
knowledge more likely than using only word representations.*
One can directly plug in the weights of mLUKE into a LUKE model, like so:

.. code-block::

    from transformers import LukeModel

    model = LukeModel.from_pretrained('studio-ousia/mluke-base')

Note that mLUKE has its own tokenizer, :class:`~transformers.MLukeTokenizer`. You can initialize it as follows:

.. code-block::

    from transformers import MLukeTokenizer

    tokenizer = MLukeTokenizer.from_pretrained('studio-ousia/mluke-base')
As mLUKE's architecture is equivalent to that of LUKE, one can refer to :doc:`LUKE's documentation page <luke>` for all
tips, code examples and notebooks.
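
For example, here is a minimal sketch of encoding a text with a single entity span and extracting word and entity
representations (assuming the ``studio-ousia/mluke-base`` checkpoint; the text and span are only illustrative):

.. code-block::

    from transformers import MLukeTokenizer, LukeModel

    tokenizer = MLukeTokenizer.from_pretrained('studio-ousia/mluke-base')
    model = LukeModel.from_pretrained('studio-ousia/mluke-base')

    text = "Tokyo is the capital of Japan."
    entity_spans = [(0, 5)]  # character-based span of "Tokyo"

    inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
    outputs = model(**inputs)

    word_states = outputs.last_hidden_state           # shape (1, sequence_length, hidden_size)
    entity_states = outputs.entity_last_hidden_state  # shape (1, 1, hidden_size)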
This model was contributed by `ryo0634 <https://huggingface.co/ryo0634>`__. The original code can be found `here
<https://github.com/studio-ousia/luke>`__.
MLukeTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MLukeTokenizer
:members: __call__, save_vocabulary


@ -17,8 +17,6 @@ Most of the models available in this library are mono-lingual models (English, C
models are available and have different mechanisms than mono-lingual models. This page details the usage of these
models.
The two models that currently support multiple languages are BERT and XLM.
XLM
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -127,3 +125,17 @@ Two XLM-RoBERTa checkpoints can be used for multi-lingual tasks:
- ``xlm-roberta-base`` (Masked language modeling, 100 languages)
- ``xlm-roberta-large`` (Masked language modeling, 100 languages)
mLUKE
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
mLUKE is based on XLM-RoBERTa and further trained on Wikipedia articles in 24 languages with masked language modeling
as well as a masked entity prediction objective.

The model can be used in the same way as other models based solely on word-piece inputs, but it can also take entity
representations as additional input to achieve further performance gains on entity-related tasks such as relation
extraction, named entity recognition and question answering (see :doc:`LUKE <model_doc/luke>`).
Currently, one mLUKE checkpoint is available:
- ``studio-ousia/mluke-base`` (Masked language modeling + Masked entity prediction, 24 languages)
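
A minimal sketch of cloze-style usage with this checkpoint, mirroring the verification step of the conversion script
added in this commit:

.. code-block::

    from transformers import MLukeTokenizer, LukeForMaskedLM

    tokenizer = MLukeTokenizer.from_pretrained('studio-ousia/mluke-base')
    model = LukeForMaskedLM.from_pretrained('studio-ousia/mluke-base')

    text = "Tokyo is the capital of <mask>."
    entity_spans = [(24, 30)]  # character-based span of "<mask>"

    inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
    outputs = model(**inputs)

    # most likely word for the masked token
    mask_index = inputs["input_ids"][0].tolist().index(tokenizer.convert_tokens_to_ids("<mask>"))
    predicted_word = tokenizer.decode(outputs.logits[0, mask_index].argmax(dim=-1))

    # most likely entity for the masked span
    predicted_entity_id = outputs.entity_logits[0, 0].argmax().item()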


@ -245,6 +245,7 @@ _import_structure = {
"models.mbart": ["MBartConfig"],
"models.mbart50": [],
"models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
"models.mluke": [],
"models.mmbt": ["MMBTConfig"],
"models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"],
"models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"],
@ -379,6 +380,7 @@ if is_sentencepiece_available():
_import_structure["models.marian"].append("MarianTokenizer")
_import_structure["models.mbart"].append("MBartTokenizer")
_import_structure["models.mbart50"].append("MBart50Tokenizer")
_import_structure["models.mluke"].append("MLukeTokenizer")
_import_structure["models.mt5"].append("MT5Tokenizer")
_import_structure["models.pegasus"].append("PegasusTokenizer")
_import_structure["models.reformer"].append("ReformerTokenizer")
@ -1037,6 +1039,7 @@ if is_torch_available():
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
"LukeForMaskedLM",
"LukeModel",
"LukePreTrainedModel",
]
@ -2368,6 +2371,7 @@ if TYPE_CHECKING:
from .models.m2m_100 import M2M100Tokenizer
from .models.marian import MarianTokenizer
from .models.mbart import MBart50Tokenizer, MBartTokenizer
from .models.mluke import MLukeTokenizer
from .models.mt5 import MT5Tokenizer
from .models.pegasus import PegasusTokenizer
from .models.reformer import ReformerTokenizer
@ -2904,6 +2908,7 @@ if TYPE_CHECKING:
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeModel,
LukePreTrainedModel,
)


@ -71,6 +71,7 @@ from . import (
mbart50,
megatron_bert,
megatron_gpt2,
mluke,
mmbt,
mobilebert,
mpnet,


@ -178,6 +178,7 @@ else:
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("luke", ("LukeTokenizer", None)),
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
("canine", ("CanineTokenizer", None)),
("bertweet", ("BertweetTokenizer", None)),


@ -32,6 +32,7 @@ if is_torch_available():
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",
"LukeForMaskedLM",
"LukeModel",
"LukePreTrainedModel",
]
@ -47,6 +48,7 @@ if TYPE_CHECKING:
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeModel,
LukePreTrainedModel,
)


@ -22,7 +22,7 @@ import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...activations import ACT2FN, gelu
from ...file_utils import (
ModelOutput,
add_start_docstrings,
@ -110,6 +110,49 @@ class BaseLukeModelOutput(BaseModelOutput):
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LukeMaskedLMOutput(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` or :obj:`entity_labels` is provided):
The sum of masked language modeling (MLM) loss and entity prediction loss.
mlm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Masked language modeling (MLM) loss.
mep_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`entity_labels` is provided):
Masked entity prediction (MEP) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
entity_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, entity_length, config.entity_vocab_size)`):
Prediction scores of the entity prediction head (scores for each entity vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output
of each layer plus the initial entity embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
mlm_loss: Optional[torch.FloatTensor] = None
mep_loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
entity_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class EntityClassificationOutput(ModelOutput):
"""
@ -674,6 +717,38 @@ class LukePooler(nn.Module):
return pooled_output
class EntityPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.entity_emb_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.entity_emb_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class EntityPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transform = EntityPredictionHeadTransform(config)
self.decoder = nn.Linear(config.entity_emb_size, config.entity_vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.entity_vocab_size))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
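# Shape sketch for the entity prediction head above (sizes are illustrative):
#   entity hidden states                           (batch_size, entity_length, hidden_size)
#   -> transform (dense + activation + LayerNorm)  (batch_size, entity_length, entity_emb_size)
#   -> decoder (tied to the entity embedding
#      matrix in LukeForMaskedLM.tie_weights) + bias  (batch_size, entity_length, entity_vocab_size)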
class LukePreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@ -1013,6 +1088,170 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
return incremental_indices.long() + padding_idx
# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead
class LukeLMHead(nn.Module):
"""Roberta Head for masked language modeling."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.decoder.bias = self.bias
def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = self.decoder(x)
return x
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
@add_start_docstrings(
"""
The LUKE model with a language modeling head and entity prediction head on top for masked language modeling and
masked entity prediction.
""",
LUKE_START_DOCSTRING,
)
class LukeForMaskedLM(LukePreTrainedModel):
_keys_to_ignore_on_save = [
r"lm_head.decoder.weight",
r"lm_head.decoder.bias",
r"entity_predictions.decoder.weight",
]
_keys_to_ignore_on_load_missing = [
r"position_ids",
r"lm_head.decoder.weight",
r"lm_head.decoder.bias",
r"entity_predictions.decoder.weight",
]
def __init__(self, config):
super().__init__(config)
self.luke = LukeModel(config)
self.lm_head = LukeLMHead(config)
self.entity_predictions = EntityPredictionHead(config)
self.loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
# Initialize weights and apply final processing
self.post_init()
def tie_weights(self):
super().tie_weights()
self._tie_or_clone_weights(self.entity_predictions.decoder, self.luke.entity_embeddings.entity_embeddings)
def get_output_embeddings(self):
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=LukeMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
entity_ids=None,
entity_attention_mask=None,
entity_token_type_ids=None,
entity_position_ids=None,
labels=None,
entity_labels=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
entity_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length)`, `optional`):
Labels for computing the masked entity prediction (MEP) loss. Indices should be in ``[-100, 0, ...,
config.entity_vocab_size]`` (see ``entity_ids`` docstring) Entities with indices set to ``-100`` are ignored
(masked), the loss is only computed for the entities with labels in ``[0, ..., config.entity_vocab_size]``
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.luke(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
loss = None
mlm_loss = None
logits = self.lm_head(outputs.last_hidden_state)
if labels is not None:
mlm_loss = self.loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1))
if loss is None:
loss = mlm_loss
mep_loss = None
entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
if entity_labels is not None:
mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1))
if loss is None:
loss = mep_loss
else:
loss = loss + mep_loss
if not return_dict:
output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions)
if mlm_loss is not None and mep_loss is not None:
return (loss, mlm_loss, mep_loss) + output
elif mlm_loss is not None:
return (loss, mlm_loss) + output
elif mep_loss is not None:
return (loss, mep_loss) + output
else:
return output
return LukeMaskedLMOutput(
loss=loss,
mlm_loss=mlm_loss,
mep_loss=mep_loss,
logits=logits,
entity_logits=entity_logits,
hidden_states=outputs.hidden_states,
entity_hidden_states=outputs.entity_hidden_states,
attentions=outputs.attentions,
)
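# A minimal usage sketch for LukeForMaskedLM (tensor names and shapes are illustrative; the encoding would
# normally come from LukeTokenizer or MLukeTokenizer called with ``entity_spans``):
#
#     outputs = model(
#         input_ids=input_ids,                      # (batch_size, sequence_length)
#         entity_ids=entity_ids,                    # (batch_size, entity_length)
#         entity_position_ids=entity_position_ids,  # (batch_size, entity_length, max_mention_length)
#         labels=labels,                            # masked word targets
#         entity_labels=entity_labels,              # masked entity targets
#     )
#     outputs.loss           # mlm_loss + mep_loss when both label sets are provided
#     outputs.logits         # (batch_size, sequence_length, vocab_size)
#     outputs.entity_logits  # (batch_size, entity_length, entity_vocab_size)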
@add_start_docstrings(
"""
The LUKE model with a classification head on top (a linear layer on top of the hidden state of the first entity


@ -312,17 +312,15 @@ class LukeTokenizer(RobertaTokenizer):
# Input type checking for clearer error
is_valid_single_text = isinstance(text, str)
is_valid_batch_text = isinstance(text, (list, tuple)) and (len(text) == 0 or (isinstance(text[0], str)))
assert (
is_valid_single_text or is_valid_batch_text
), "text input must be of type `str` (single example) or `List[str]` (batch)."
if not (is_valid_single_text or is_valid_batch_text):
raise ValueError("text input must be of type `str` (single example) or `List[str]` (batch).")
is_valid_single_text_pair = isinstance(text_pair, str)
is_valid_batch_text_pair = isinstance(text_pair, (list, tuple)) and (
len(text_pair) == 0 or isinstance(text_pair[0], str)
)
assert (
text_pair is None or is_valid_single_text_pair or is_valid_batch_text_pair
), "text_pair input must be of type `str` (single example) or `List[str]` (batch)."
if not (text_pair is None or is_valid_single_text_pair or is_valid_batch_text_pair):
raise ValueError("text_pair input must be of type `str` (single example) or `List[str]` (batch).")
is_batched = bool(isinstance(text, (list, tuple)))
@ -391,105 +389,6 @@ class LukeTokenizer(RobertaTokenizer):
**kwargs,
)
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def encode_plus(
self,
text: Union[TextInput],
text_pair: Optional[Union[TextInput]] = None,
entity_spans: Optional[EntitySpanInput] = None,
entity_spans_pair: Optional[EntitySpanInput] = None,
entities: Optional[EntityInput] = None,
entities_pair: Optional[EntityInput] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
max_entity_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: Optional[bool] = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
"""
Tokenize and prepare for the model a sequence or a pair of sequences.
.. warning:: This method is deprecated, ``__call__`` should be used instead.
Args:
text (:obj:`str`):
The first sequence to be encoded. Each sequence must be a string.
text_pair (:obj:`str`):
The second sequence to be encoded. Each sequence must be a string.
entity_spans (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`)::
The first sequence of entity spans to be encoded. The sequence consists of tuples each with two
integers denoting character-based start and end positions of entities. If you specify
:obj:`"entity_classification"` or :obj:`"entity_pair_classification"` as the ``task`` argument in the
constructor, the length of each sequence must be 1 or 2, respectively. If you specify ``entities``, the
length of the sequence must be equal to the length of ``entities``.
entity_spans_pair (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`)::
The second sequence of entity spans to be encoded. The sequence consists of tuples each with two
integers denoting character-based start and end positions of entities. If you specify the ``task``
argument in the constructor, this argument is ignored. If you specify ``entities_pair``, the length of
the sequence must be equal to the length of ``entities_pair``.
entities (:obj:`List[str]` `optional`)::
The first sequence of entities to be encoded. The sequence consists of strings representing entities,
i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los Angeles). This argument
is ignored if you specify the ``task`` argument in the constructor. The length of the sequence must be
equal to the length of ``entity_spans``. If you specify ``entity_spans`` without specifying this
argument, the entity sequence is automatically constructed by filling it with the [MASK] entity.
entities_pair (:obj:`List[str]`, :obj:`List[List[str]]`, `optional`)::
The second sequence of entities to be encoded. The sequence consists of strings representing entities,
i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los Angeles). This argument
is ignored if you specify the ``task`` argument in the constructor. The length of the sequence must be
equal to the length of ``entity_spans_pair``. If you specify ``entity_spans_pair`` without specifying
this argument, the entity sequence is automatically constructed by filling it with the [MASK] entity.
max_entity_length (:obj:`int`, `optional`):
The maximum length of the entity sequence.
"""
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
padding=padding,
truncation=truncation,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
verbose=verbose,
**kwargs,
)
return self._encode_plus(
text=text,
text_pair=text_pair,
entity_spans=entity_spans,
entity_spans_pair=entity_spans_pair,
entities=entities,
entities_pair=entities_pair,
add_special_tokens=add_special_tokens,
padding_strategy=padding_strategy,
truncation_strategy=truncation_strategy,
max_length=max_length,
max_entity_length=max_entity_length,
stride=stride,
is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
**kwargs,
)
def _encode_plus(
self,
text: Union[TextInput],
@ -571,89 +470,6 @@ class LukeTokenizer(RobertaTokenizer):
verbose=verbose,
)
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def batch_encode_plus(
self,
batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]],
batch_entity_spans_or_entity_spans_pairs: Optional[
Union[List[EntitySpanInput], List[Tuple[EntitySpanInput, EntitySpanInput]]]
] = None,
batch_entities_or_entities_pairs: Optional[
Union[List[EntityInput], List[Tuple[EntityInput, EntityInput]]]
] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
max_entity_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: Optional[bool] = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
"""
Tokenize and prepare for the model a list of sequences or a list of pairs of sequences.
.. warning::
This method is deprecated, ``__call__`` should be used instead.
Args:
batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`):
Batch of sequences or pair of sequences to be encoded. This can be a list of string or a list of pair
of string (see details in ``encode_plus``).
batch_entity_spans_or_entity_spans_pairs (:obj:`List[List[Tuple[int, int]]]`,
:obj:`List[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]`, `optional`)::
Batch of entity span sequences or pairs of entity span sequences to be encoded (see details in
``encode_plus``).
batch_entities_or_entities_pairs (:obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`,
`optional`):
Batch of entity sequences or pairs of entity sequences to be encoded (see details in ``encode_plus``).
max_entity_length (:obj:`int`, `optional`):
The maximum length of the entity sequence.
"""
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
padding=padding,
truncation=truncation,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
verbose=verbose,
**kwargs,
)
return self._batch_encode_plus(
batch_text_or_text_pairs=batch_text_or_text_pairs,
batch_entity_spans_or_entity_spans_pairs=batch_entity_spans_or_entity_spans_pairs,
batch_entities_or_entities_pairs=batch_entities_or_entities_pairs,
add_special_tokens=add_special_tokens,
padding_strategy=padding_strategy,
truncation_strategy=truncation_strategy,
max_length=max_length,
max_entity_length=max_entity_length,
stride=stride,
is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
**kwargs,
)
def _batch_encode_plus(
self,
batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]],
@ -713,11 +529,12 @@ class LukeTokenizer(RobertaTokenizer):
entity_spans, entity_spans_pair = None, None
if batch_entity_spans_or_entity_spans_pairs is not None:
entity_spans_or_entity_spans_pairs = batch_entity_spans_or_entity_spans_pairs[index]
if entity_spans_or_entity_spans_pairs:
if isinstance(entity_spans_or_entity_spans_pairs[0][0], int):
entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs, None
else:
entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs
if len(entity_spans_or_entity_spans_pairs) > 0 and isinstance(
entity_spans_or_entity_spans_pairs[0], list
):
entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs
else:
entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs, None
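# Illustrative per-example formats handled by the branch above (spans are hypothetical):
#   [(0, 7), (17, 28)]      -> entity spans for a single sequence
#   ([(0, 7)], [(17, 28)])  -> a pair: (entity_spans, entity_spans_pair)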
(
first_ids,
@ -761,6 +578,25 @@ class LukeTokenizer(RobertaTokenizer):
return BatchEncoding(batch_outputs)
def _check_entity_input_format(self, entities: Optional[EntityInput], entity_spans: Optional[EntitySpanInput]):
if not isinstance(entity_spans, list):
raise ValueError("entity_spans should be given as a list")
elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple):
raise ValueError(
"entity_spans should be given as a list of tuples " "containing the start and end character indices"
)
if entities is not None:
if not isinstance(entities, list):
raise ValueError("If you specify entities, they should be given as a list")
if len(entities) > 0 and not isinstance(entities[0], str):
raise ValueError("If you specify entities, they should be given as a list of entity names")
if len(entities) != len(entity_spans):
raise ValueError("If you specify entities, entities and entity_spans must be the same length")
def _create_input_sequence(
self,
text: Union[TextInput],
@ -816,15 +652,7 @@ class LukeTokenizer(RobertaTokenizer):
if entity_spans is None:
first_ids = get_input_ids(text)
else:
assert isinstance(entity_spans, list) and (
len(entity_spans) == 0 or isinstance(entity_spans[0], tuple)
), "entity_spans should be given as a list of tuples containing the start and end character indices"
assert entities is None or (
isinstance(entities, list) and (len(entities) == 0 or isinstance(entities[0], str))
), "If you specify entities, they should be given as a list of entity names"
assert entities is None or len(entities) == len(
entity_spans
), "If you specify entities, entities and entity_spans must be the same length"
self._check_entity_input_format(entities, entity_spans)
first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)
if entities is None:
@ -836,16 +664,7 @@ class LukeTokenizer(RobertaTokenizer):
if entity_spans_pair is None:
second_ids = get_input_ids(text_pair)
else:
assert isinstance(entity_spans_pair, list) and (
len(entity_spans_pair) == 0 or isinstance(entity_spans_pair[0], tuple)
), "entity_spans_pair should be given as a list of tuples containing the start and end character indices"
assert entities_pair is None or (
isinstance(entities_pair, list)
and (len(entities_pair) == 0 or isinstance(entities_pair[0], str))
), "If you specify entities_pair, they should be given as a list of entity names"
assert entities_pair is None or len(entities_pair) == len(
entity_spans_pair
), "If you specify entities_pair, entities_pair and entity_spans_pair must be the same length"
self._check_entity_input_format(entities_pair, entity_spans_pair)
second_ids, second_entity_token_spans = get_input_ids_and_entity_token_spans(
text_pair, entity_spans_pair
@ -856,10 +675,11 @@ class LukeTokenizer(RobertaTokenizer):
second_entity_ids = [self.entity_vocab.get(entity, unk_entity_id) for entity in entities_pair]
elif self.task == "entity_classification":
assert (
isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple)
), "Entity spans should be a list containing a single tuple containing the start and end character indices of an entity"
if not (isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple)):
raise ValueError(
"Entity spans should be a list containing a single tuple "
"containing the start and end character indices of an entity"
)
first_entity_ids = [self.entity_vocab["[MASK]"]]
first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)
@ -876,12 +696,16 @@ class LukeTokenizer(RobertaTokenizer):
first_entity_token_spans = [(entity_token_start, entity_token_end + 2)]
elif self.task == "entity_pair_classification":
assert (
if not (
isinstance(entity_spans, list)
and len(entity_spans) == 2
and isinstance(entity_spans[0], tuple)
and isinstance(entity_spans[1], tuple)
), "Entity spans should be provided as a list of tuples, each tuple containing the start and end character indices of an entity"
):
raise ValueError(
"Entity spans should be provided as a list of two tuples, "
"each tuple containing the start and end character indices of an entity"
)
head_span, tail_span = entity_spans
first_entity_ids = [self.entity_vocab["[MASK]"], self.entity_vocab["[MASK2]"]]
@ -907,9 +731,11 @@ class LukeTokenizer(RobertaTokenizer):
elif self.task == "entity_span_classification":
mask_entity_id = self.entity_vocab["[MASK]"]
assert isinstance(entity_spans, list) and isinstance(
entity_spans[0], tuple
), "Entity spans should be provided as a list of tuples, each tuple containing the start and end character indices of an entity"
if not (isinstance(entity_spans, list) and len(entity_spans) > 0 and isinstance(entity_spans[0], tuple)):
raise ValueError(
"Entity spans should be provided as a list of tuples, "
"each tuple containing the start and end character indices of an entity"
)
first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)
first_entity_ids = [mask_entity_id] * len(entity_spans)
@ -1218,7 +1044,6 @@ class LukeTokenizer(RobertaTokenizer):
self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
# Padding
# To do: add padding of entities
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
encoded_inputs = self.pad(
encoded_inputs,
@ -1369,9 +1194,8 @@ class LukeTokenizer(RobertaTokenizer):
return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
batch_size = len(required_input)
assert all(
len(v) == batch_size for v in encoded_inputs.values()
), "Some items in the output dictionary have a different batch size than others."
if any(len(v) != batch_size for v in encoded_inputs.values()):
raise ValueError("Some items in the output dictionary have a different batch size than others.")
if padding_strategy == PaddingStrategy.LONGEST:
max_length = max(len(inputs) for inputs in required_input)
@ -1487,7 +1311,9 @@ class LukeTokenizer(RobertaTokenizer):
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
if entities_provided:
encoded_inputs["entity_ids"] = encoded_inputs["entity_ids"] + [0] * entity_difference
encoded_inputs["entity_ids"] = (
encoded_inputs["entity_ids"] + [self.entity_vocab["[PAD]"]] * entity_difference
)
encoded_inputs["entity_position_ids"] = (
encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference
)
@ -1516,7 +1342,9 @@ class LukeTokenizer(RobertaTokenizer):
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
if entities_provided:
encoded_inputs["entity_ids"] = [0] * entity_difference + encoded_inputs["entity_ids"]
encoded_inputs["entity_ids"] = [self.entity_vocab["[PAD]"]] * entity_difference + encoded_inputs[
"entity_ids"
]
encoded_inputs["entity_position_ids"] = [
[-1] * self.max_mention_length
] * entity_difference + encoded_inputs["entity_position_ids"]


@ -0,0 +1,38 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_sentencepiece_available
_import_structure = {}
if is_sentencepiece_available():
_import_structure["tokenization_mluke"] = ["MLukeTokenizer"]
if TYPE_CHECKING:
if is_sentencepiece_available():
from .tokenization_mluke import MLukeTokenizer
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
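# Usage sketch: with the lazy module above, importing ``MLukeTokenizer`` from ``transformers.models.mluke``
# only loads ``tokenization_mluke`` when the attribute is first accessed, and only if sentencepiece is available.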


@ -0,0 +1,228 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert mLUKE checkpoint."""
import argparse
import json
import os
from collections import OrderedDict
import torch
from transformers import LukeConfig, LukeForMaskedLM, MLukeTokenizer, XLMRobertaTokenizer
from transformers.tokenization_utils_base import AddedToken
@torch.no_grad()
def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, pytorch_dump_folder_path, model_size):
# Load configuration defined in the metadata file
with open(metadata_path) as metadata_file:
metadata = json.load(metadata_file)
config = LukeConfig(use_entity_aware_attention=True, **metadata["model_config"])
# Load in the weights from the checkpoint_path
state_dict = torch.load(checkpoint_path, map_location="cpu")["module"]
# Load the entity vocab file
entity_vocab = load_original_entity_vocab(entity_vocab_path)
# add an entry for [MASK2]
entity_vocab["[MASK2]"] = max(entity_vocab.values()) + 1
config.entity_vocab_size += 1
tokenizer = XLMRobertaTokenizer.from_pretrained(metadata["model_config"]["bert_model_name"])
# Add special tokens to the token vocabulary for downstream tasks
entity_token_1 = AddedToken("<ent>", lstrip=False, rstrip=False)
entity_token_2 = AddedToken("<ent2>", lstrip=False, rstrip=False)
tokenizer.add_special_tokens(dict(additional_special_tokens=[entity_token_1, entity_token_2]))
config.vocab_size += 2
print(f"Saving tokenizer to {pytorch_dump_folder_path}")
tokenizer.save_pretrained(pytorch_dump_folder_path)
with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"), "r") as f:
tokenizer_config = json.load(f)
tokenizer_config["tokenizer_class"] = "MLukeTokenizer"
with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"), "w") as f:
json.dump(tokenizer_config, f)
with open(os.path.join(pytorch_dump_folder_path, MLukeTokenizer.vocab_files_names["entity_vocab_file"]), "w") as f:
json.dump(entity_vocab, f)
tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)
# Initialize the embeddings of the special tokens
ent_init_index = tokenizer.convert_tokens_to_ids(["@"])[0]
ent2_init_index = tokenizer.convert_tokens_to_ids(["#"])[0]
word_emb = state_dict["embeddings.word_embeddings.weight"]
ent_emb = word_emb[ent_init_index].unsqueeze(0)
ent2_emb = word_emb[ent2_init_index].unsqueeze(0)
state_dict["embeddings.word_embeddings.weight"] = torch.cat([word_emb, ent_emb, ent2_emb])
# add special tokens for 'entity_predictions.bias'
for bias_name in ["lm_head.decoder.bias", "lm_head.bias"]:
decoder_bias = state_dict[bias_name]
ent_decoder_bias = decoder_bias[ent_init_index].unsqueeze(0)
ent2_decoder_bias = decoder_bias[ent2_init_index].unsqueeze(0)
state_dict[bias_name] = torch.cat([decoder_bias, ent_decoder_bias, ent2_decoder_bias])
# Initialize the query layers of the entity-aware self-attention mechanism
for layer_index in range(config.num_hidden_layers):
for matrix_name in ["query.weight", "query.bias"]:
prefix = f"encoder.layer.{layer_index}.attention.self."
state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix + matrix_name]
state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix + matrix_name]
state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix + matrix_name]
# Initialize the embedding of the [MASK2] entity using that of the [MASK] entity for downstream tasks
entity_emb = state_dict["entity_embeddings.entity_embeddings.weight"]
entity_mask_emb = entity_emb[entity_vocab["[MASK]"]].unsqueeze(0)
state_dict["entity_embeddings.entity_embeddings.weight"] = torch.cat([entity_emb, entity_mask_emb])
# add [MASK2] for 'entity_predictions.bias'
entity_prediction_bias = state_dict["entity_predictions.bias"]
entity_mask_bias = entity_prediction_bias[entity_vocab["[MASK]"]].unsqueeze(0)
state_dict["entity_predictions.bias"] = torch.cat([entity_prediction_bias, entity_mask_bias])
model = LukeForMaskedLM(config=config).eval()
state_dict.pop("entity_predictions.decoder.weight")
state_dict.pop("lm_head.decoder.weight")
state_dict.pop("lm_head.decoder.bias")
state_dict_for_hugging_face = OrderedDict()
for key, value in state_dict.items():
if not (key.startswith("lm_head") or key.startswith("entity_predictions")):
state_dict_for_hugging_face[f"luke.{key}"] = state_dict[key]
else:
state_dict_for_hugging_face[key] = state_dict[key]
missing_keys, unexpected_keys = model.load_state_dict(state_dict_for_hugging_face, strict=False)
if set(unexpected_keys) != {"luke.embeddings.position_ids"}:
raise ValueError(f"Unexpected unexpected_keys: {unexpected_keys}")
if set(missing_keys) != {
"lm_head.decoder.weight",
"lm_head.decoder.bias",
"entity_predictions.decoder.weight",
}:
raise ValueError(f"Unexpected missing_keys: {missing_keys}")
model.tie_weights()
assert (model.luke.embeddings.word_embeddings.weight == model.lm_head.decoder.weight).all()
assert (model.luke.entity_embeddings.entity_embeddings.weight == model.entity_predictions.decoder.weight).all()
# Check outputs
tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path, task="entity_classification")
text = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
span = (0, 9)
encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")
outputs = model(**encoding)
# Verify word hidden states
if model_size == "large":
raise NotImplementedError
else: # base
expected_shape = torch.Size((1, 33, 768))
expected_slice = torch.tensor([[0.0892, 0.0596, -0.2819], [0.0134, 0.1199, 0.0573], [-0.0169, 0.0927, 0.0644]])
if not (outputs.last_hidden_state.shape == expected_shape):
raise ValueError(
f"Outputs.last_hidden_state.shape is {outputs.last_hidden_state.shape}, Expected shape is {expected_shape}"
)
if not torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
raise ValueError
# Verify entity hidden states
if model_size == "large":
raise NotImplementedError
else: # base
expected_shape = torch.Size((1, 1, 768))
expected_slice = torch.tensor([[-0.1482, 0.0609, 0.0322]])
if not (outputs.entity_last_hidden_state.shape == expected_shape):
raise ValueError(
f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is {expected_shape}"
)
if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
raise ValueError
# Verify masked word/entity prediction
tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)
text = "Tokyo is the capital of <mask>."
span = (24, 30)
encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")
outputs = model(**encoding)
input_ids = encoding["input_ids"][0].tolist()
mask_position_id = input_ids.index(tokenizer.convert_tokens_to_ids("<mask>"))
predicted_id = outputs.logits[0][mask_position_id].argmax(dim=-1)
assert "Japan" == tokenizer.decode(predicted_id)
predicted_entity_id = outputs.entity_logits[0][0].argmax().item()
multilingual_predicted_entities = [
entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id == predicted_entity_id
]
assert [e for e in multilingual_predicted_entities if e.startswith("en:")][0] == "en:Japan"
# Finally, save our PyTorch model and tokenizer
print("Saving PyTorch model to {}".format(pytorch_dump_folder_path))
model.save_pretrained(pytorch_dump_folder_path)
def load_original_entity_vocab(entity_vocab_path):
SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]"]
data = [json.loads(line) for line in open(entity_vocab_path)]
new_mapping = {}
for entry in data:
entity_id = entry["id"]
for entity_name, language in entry["entities"]:
if entity_name in SPECIAL_TOKENS:
new_mapping[entity_name] = entity_id
break
new_entity_name = f"{language}:{entity_name}"
new_mapping[new_entity_name] = entity_id
return new_mapping
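# Sketch of the JSON-lines entity vocabulary format expected above (entries are hypothetical):
#   {"id": 0, "entities": [["[PAD]", "en"]]}
#   {"id": 3, "entities": [["Japan", "en"], ["日本", "ja"]]}
# which load_original_entity_vocab maps to {"[PAD]": 0, "en:Japan": 3, "ja:日本": 3}.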
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--checkpoint_path", type=str, help="Path to a pytorch_model.bin file.")
parser.add_argument(
"--metadata_path", default=None, type=str, help="Path to a metadata.json file, defining the configuration."
)
parser.add_argument(
"--entity_vocab_path",
default=None,
type=str,
help="Path to an entity_vocab.tsv file, containing the entity vocabulary.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to where to dump the output PyTorch model."
)
parser.add_argument(
"--model_size", default="base", type=str, choices=["base", "large"], help="Size of the model to be converted."
)
args = parser.parse_args()
convert_luke_checkpoint(
args.checkpoint_path,
args.metadata_path,
args.entity_vocab_path,
args.pytorch_dump_folder_path,
args.model_size,
)
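For reference, a minimal sketch of how the pieces above fit together: the file passed via --entity_vocab_path is a JSON-lines file in which each line carries an id and a list of (name, language) pairs, and load_original_entity_vocab flattens it into the "language:name" keys used by the tokenizer. The script filename, paths, and sample lines below are hypothetical illustrations, not shipped data.

# Hypothetical invocation of the conversion script defined above
# (the filename is an assumption; the flags match the argparse setup):
#   python convert_mluke_checkpoint.py \
#       --checkpoint_path mluke/pytorch_model.bin \
#       --metadata_path mluke/metadata.json \
#       --entity_vocab_path mluke/entity_vocab.jsonl \
#       --pytorch_dump_folder_path converted_mluke_base \
#       --model_size base

import json

# Two made-up lines illustrating the assumed entity_vocab format.
sample_lines = [
    '{"id": 0, "entities": [["[PAD]", "en"]]}',
    '{"id": 3, "entities": [["Japan", "en"], ["日本", "ja"]]}',
]

new_mapping = {}
for line in sample_lines:
    entry = json.loads(line)
    entity_id = entry["id"]
    for entity_name, language in entry["entities"]:
        if entity_name in ["[MASK]", "[PAD]", "[UNK]"]:
            new_mapping[entity_name] = entity_id  # special tokens keep their bare name
            break
        new_mapping[f"{language}:{entity_name}"] = entity_id  # e.g. "en:Japan"

print(new_mapping)  # {'[PAD]': 0, 'en:Japan': 3, 'ja:日本': 3}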

File diff suppressed because it is too large

View File

@ -3020,6 +3020,18 @@ class LukeForEntitySpanClassification:
requires_backends(self, ["torch"])
class LukeForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LukeModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
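These dummy definitions keep `from transformers import LukeForMaskedLM` working in an environment without PyTorch; `requires_backends` only raises once the class is actually used. A small sketch of that behaviour (the printed messages are paraphrased, not the library's exact wording):

# Sketch: the dummy object only complains at use time, not at import time.
from transformers import LukeForMaskedLM, is_torch_available

if not is_torch_available():
    try:
        LukeForMaskedLM()  # the dummy __init__ calls requires_backends(self, ["torch"])
    except ImportError as err:
        print("LukeForMaskedLM needs the torch backend:", err)
else:
    print("torch is installed, so the real LukeForMaskedLM was imported")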

View File

@ -110,6 +110,15 @@ class MBartTokenizer:
requires_backends(cls, ["sentencepiece"])
class MLukeTokenizer:
def __init__(self, *args, **kwargs):
requires_backends(self, ["sentencepiece"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["sentencepiece"])
class MT5Tokenizer:
def __init__(self, *args, **kwargs):
requires_backends(self, ["sentencepiece"])

View File

@ -29,6 +29,7 @@ if is_torch_available():
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
LukeForMaskedLM,
LukeModel,
LukeTokenizer,
)
@ -138,12 +139,17 @@ class LukeModelTester:
)
sequence_labels = None
labels = None
entity_labels = None
entity_classification_labels = None
entity_pair_classification_labels = None
entity_span_classification_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)
entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
entity_pair_classification_labels = ids_tensor(
[self.batch_size], self.num_entity_pair_classification_labels
@ -164,6 +170,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
@ -199,6 +207,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
@ -226,6 +236,44 @@ class LukeModelTester:
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_masked_lm(
self,
config,
input_ids,
attention_mask,
token_type_ids,
entity_ids,
entity_attention_mask,
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
):
config.num_labels = self.num_entity_classification_labels
model = LukeForMaskedLM(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
entity_ids=entity_ids,
entity_attention_mask=entity_attention_mask,
entity_token_type_ids=entity_token_type_ids,
entity_position_ids=entity_position_ids,
labels=labels,
entity_labels=entity_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual(
result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
)
def create_and_check_for_entity_classification(
self,
config,
@ -237,6 +285,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
@ -269,6 +319,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
@ -301,6 +353,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
@ -341,6 +395,8 @@ class LukeModelTester:
entity_token_type_ids,
entity_position_ids,
sequence_labels,
labels,
entity_labels,
entity_classification_labels,
entity_pair_classification_labels,
entity_span_classification_labels,
@ -363,6 +419,7 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
LukeModel,
LukeForMaskedLM,
LukeForEntityClassification,
LukeForEntityPairClassification,
LukeForEntitySpanClassification,
@ -396,6 +453,18 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
dtype=torch.long,
device=torch_device,
)
elif model_class == LukeForMaskedLM:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length),
dtype=torch.long,
device=torch_device,
)
inputs_dict["entity_labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.entity_length),
dtype=torch.long,
device=torch_device,
)
return inputs_dict
def setUp(self):
@ -415,6 +484,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
model = LukeModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_entity_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
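For context, a minimal end-to-end sketch of what LukeForMaskedLM adds, mirroring the masked word/entity check performed in the conversion script above; "./mluke-base" stands for a hypothetical local folder holding a converted checkpoint.

# Masked word + entity prediction with LukeForMaskedLM (sketch).
import torch

from transformers import LukeForMaskedLM, MLukeTokenizer

tokenizer = MLukeTokenizer.from_pretrained("./mluke-base")
model = LukeForMaskedLM.from_pretrained("./mluke-base")
model.eval()

text = "Tokyo is the capital of <mask>."
entity_span = (24, 30)  # character span covering "<mask>"
encoding = tokenizer(text, entity_spans=[entity_span], return_tensors="pt")

with torch.no_grad():
    outputs = model(**encoding)

# Word-level prediction at the <mask> position
input_ids = encoding["input_ids"][0].tolist()
mask_position = input_ids.index(tokenizer.convert_tokens_to_ids("<mask>"))
predicted_token_id = outputs.logits[0][mask_position].argmax(dim=-1)
print(tokenizer.decode(predicted_token_id))  # expected: "Japan"

# Entity-level prediction for the masked entity span
predicted_entity_id = outputs.entity_logits[0][0].argmax().item()
predicted_entities = [
    entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id == predicted_entity_id
]
print(predicted_entities)  # expected to include "en:Japan"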

View File

@ -23,7 +23,7 @@ from transformers.testing_utils import require_torch, slow
from .test_tokenization_common import TokenizerTesterMixin
class Luke(TokenizerTesterMixin, unittest.TestCase):
class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = LukeTokenizer
test_rust_tokenizer = False
from_pretrained_kwargs = {"cls_token": "<s>"}
@ -79,8 +79,8 @@ class Luke(TokenizerTesterMixin, unittest.TestCase):
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
assert encoded_sentence == encoded_text_from_decode
assert encoded_pair == encoded_pair_from_decode
self.assertEqual(encoded_sentence, encoded_text_from_decode)
self.assertEqual(encoded_pair, encoded_pair_from_decode)
def get_clean_sequence(self, tokenizer, max_length=20) -> Tuple[str, list]:
txt = "Beyonce lives in Los Angeles"
@ -159,6 +159,81 @@ class Luke(TokenizerTesterMixin, unittest.TestCase):
tokens_p_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
)
def test_padding_entity_inputs(self):
tokenizer = self.get_tokenizer()
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
span = (15, 34)
pad_id = tokenizer.entity_vocab["[PAD]"]
mask_id = tokenizer.entity_vocab["[MASK]"]
encoding = tokenizer([sentence, sentence], entity_spans=[[span], [span, span]], padding=True)
self.assertEqual(encoding["entity_ids"], [[mask_id, pad_id], [mask_id, mask_id]])
# test with a sentence with no entity
encoding = tokenizer([sentence, sentence], entity_spans=[[], [span, span]], padding=True)
self.assertEqual(encoding["entity_ids"], [[pad_id, pad_id], [mask_id, mask_id]])
def test_if_tokenize_single_text_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer()
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
spans = [(15, 34)]
entities = ["East Asian language"]
with self.assertRaises(ValueError):
tokenizer(sentence, entities=tuple(entities), entity_spans=spans)
with self.assertRaises(ValueError):
tokenizer(sentence, entities=entities, entity_spans=tuple(spans))
with self.assertRaises(ValueError):
tokenizer(sentence, entities=[0], entity_spans=spans)
with self.assertRaises(ValueError):
tokenizer(sentence, entities=entities, entity_spans=[0])
with self.assertRaises(ValueError):
tokenizer(sentence, entities=entities, entity_spans=spans + [(0, 9)])
def test_if_tokenize_entity_classification_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer(task="entity_classification")
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
span = (15, 34)
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[span, span])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[0])
def test_if_tokenize_entity_pair_classification_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer(task="entity_pair_classification")
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
# head and tail information
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[0, 0])
def test_if_tokenize_entity_span_classification_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer(task="entity_span_classification")
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[0, 0, 0])
@require_torch
class LukeTokenizerIntegrationTests(unittest.TestCase):

View File

@ -0,0 +1,666 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Tuple
from transformers.models.mluke.tokenization_mluke import MLukeTokenizer
from transformers.testing_utils import require_torch, slow
from .test_tokenization_common import TokenizerTesterMixin
class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MLukeTokenizer
test_rust_tokenizer = False
from_pretrained_kwargs = {"cls_token": "<s>"}
def setUp(self):
super().setUp()
self.special_tokens_map = {"entity_token_1": "<ent>", "entity_token_2": "<ent2>"}
def get_tokenizer(self, task=None, **kwargs):
kwargs.update(self.special_tokens_map)
kwargs.update({"task": task})
return self.tokenizer_class.from_pretrained("studio-ousia/mluke-base", **kwargs)
def get_input_output_texts(self, tokenizer):
input_text = "lower newer"
output_text = "lower newer"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/mluke-base")
text = "lower newer"
spm_tokens = ["▁lower", "▁new", "er"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, spm_tokens)
input_tokens = tokens + [tokenizer.unk_token]
input_spm_tokens = [92319, 3525, 56, 3]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_spm_tokens)
def mluke_dict_integration_testing(self):
tokenizer = self.get_tokenizer()
self.assertListEqual(tokenizer.encode("Hello world!", add_special_tokens=False), [35378, 8999, 38])
self.assertListEqual(
tokenizer.encode("Hello world! cécé herlolip 418", add_special_tokens=False),
[35378, 8999, 38, 33273, 11676, 604, 365, 21392, 201, 1819],
)
@slow
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/mluke-base")
text = tokenizer.encode("sequence builders", add_special_tokens=False)
text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)
encoded_text_from_decode = tokenizer.encode(
"sequence builders", add_special_tokens=True, add_prefix_space=False
)
encoded_pair_from_decode = tokenizer.encode(
"sequence builders", "multi-sequence build", add_special_tokens=True, add_prefix_space=False
)
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
self.assertEqual(encoded_sentence, encoded_text_from_decode)
self.assertEqual(encoded_pair, encoded_pair_from_decode)
def get_clean_sequence(self, tokenizer, max_length=20) -> Tuple[str, list]:
txt = "Beyonce lives in Los Angeles"
ids = tokenizer.encode(txt, add_special_tokens=False)
return txt, ids
def test_pretokenized_inputs(self):
pass
def test_embeded_special_tokens(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest("{} ({})".format(tokenizer.__class__.__name__, pretrained_name)):
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
sentence = "A, <mask> AllenNLP sentence."
tokens_r = tokenizer_r.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
# token_type_ids should put 0 everywhere
self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))
# attention_mask should put 1 everywhere, so sum over length should be 1
self.assertEqual(
sum(tokens_r["attention_mask"]) / len(tokens_r["attention_mask"]),
sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]),
)
tokens_p_str = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])
# Rust correctly handles the space before the mask while python doesn't
self.assertSequenceEqual(tokens_p["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])
self.assertSequenceEqual(
tokens_p_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
)
def test_padding_entity_inputs(self):
tokenizer = self.get_tokenizer()
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
span = (15, 34)
pad_id = tokenizer.entity_vocab["[PAD]"]
mask_id = tokenizer.entity_vocab["[MASK]"]
encoding = tokenizer([sentence, sentence], entity_spans=[[span], [span, span]], padding=True)
self.assertEqual(encoding["entity_ids"], [[mask_id, pad_id], [mask_id, mask_id]])
# test with a sentence with no entity
encoding = tokenizer([sentence, sentence], entity_spans=[[], [span, span]], padding=True)
self.assertEqual(encoding["entity_ids"], [[pad_id, pad_id], [mask_id, mask_id]])
def test_if_tokenize_single_text_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer()
sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and Afghanistan."
entities = ["en:ISO 639-3"]
spans = [(0, 9)]
with self.assertRaises(ValueError):
tokenizer(sentence, entities=tuple(entities), entity_spans=spans)
with self.assertRaises(ValueError):
tokenizer(sentence, entities=entities, entity_spans=tuple(spans))
with self.assertRaises(ValueError):
tokenizer(sentence, entities=[0], entity_spans=spans)
with self.assertRaises(ValueError):
tokenizer(sentence, entities=entities, entity_spans=[0])
with self.assertRaises(ValueError):
tokenizer(sentence, entities=entities, entity_spans=spans + [(0, 9)])
def test_if_tokenize_entity_classification_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer(task="entity_classification")
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
span = (15, 34)
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[span, span])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[0])
def test_if_tokenize_entity_pair_classification_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer(task="entity_pair_classification")
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
# head and tail information
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[0, 0])
def test_if_tokenize_entity_span_classification_raise_error_with_invalid_inputs(self):
tokenizer = self.get_tokenizer(task="entity_span_classification")
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[])
with self.assertRaises(ValueError):
tokenizer(sentence, entity_spans=[0, 0, 0])
@require_torch
class MLukeTokenizerIntegrationTests(unittest.TestCase):
tokenizer_class = MLukeTokenizer
from_pretrained_kwargs = {"cls_token": "<s>"}
@classmethod
def setUpClass(cls):
cls.tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base", return_token_type_ids=True)
cls.entity_classification_tokenizer = MLukeTokenizer.from_pretrained(
"studio-ousia/mluke-base", return_token_type_ids=True, task="entity_classification"
)
cls.entity_pair_tokenizer = MLukeTokenizer.from_pretrained(
"studio-ousia/mluke-base", return_token_type_ids=True, task="entity_pair_classification"
)
cls.entity_span_tokenizer = MLukeTokenizer.from_pretrained(
"studio-ousia/mluke-base", return_token_type_ids=True, task="entity_span_classification"
)
def test_single_text_no_padding_or_truncation(self):
tokenizer = self.tokenizer
sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
entities = ["en:ISO 639-3", "DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
spans = [(0, 9), (59, 63), (68, 75), (77, 88)]
encoding = tokenizer(sentence, entities=entities, entity_spans=spans, return_token_type_ids=True)
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
"<s> ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).</s>",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
)
self.assertEqual(tokenizer.decode(encoding["input_ids"][17], spaces_between_special_tokens=False), "Iran")
self.assertEqual(
tokenizer.decode(encoding["input_ids"][19:25], spaces_between_special_tokens=False), "アフガニスタン"
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][26], spaces_between_special_tokens=False), "Afghanistan"
)
self.assertEqual(
encoding["entity_ids"],
[
tokenizer.entity_vocab["en:ISO 639-3"],
tokenizer.entity_vocab["[UNK]"],
tokenizer.entity_vocab["ja:アフガニスタン"],
tokenizer.entity_vocab["en:Afghanistan"],
],
)
self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1, 1])
self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0, 0])
# fmt: off
self.assertEqual(
encoding["entity_position_ids"],
[
[1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[19, 20, 21, 22, 23, 24, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
]
)
# fmt: on
def test_single_text_only_entity_spans_no_padding_or_truncation(self):
tokenizer = self.tokenizer
sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
entities = ["en:ISO 639-3", "DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
spans = [(0, 9), (59, 63), (68, 75), (77, 88)]
encoding = tokenizer(sentence, entities=entities, entity_spans=spans, return_token_type_ids=True)
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
"<s> ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).</s>",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
)
self.assertEqual(tokenizer.decode(encoding["input_ids"][17], spaces_between_special_tokens=False), "Iran")
self.assertEqual(
tokenizer.decode(encoding["input_ids"][20:25], spaces_between_special_tokens=False), "アフガニスタン"
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][26], spaces_between_special_tokens=False), "Afghanistan"
)
self.assertEqual(
encoding["entity_ids"],
[
tokenizer.entity_vocab["en:ISO 639-3"],
tokenizer.entity_vocab["[UNK]"],
tokenizer.entity_vocab["ja:アフガニスタン"],
tokenizer.entity_vocab["en:Afghanistan"],
],
)
self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1, 1])
self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0, 0])
# fmt: off
self.assertEqual(
encoding["entity_position_ids"],
[
[1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[19, 20, 21, 22, 23, 24, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
]
)
# fmt: on
def test_single_text_padding_pytorch_tensors(self):
tokenizer = self.tokenizer
sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
entities = ["en:ISO 639-3", "DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
spans = [(0, 9), (59, 63), (68, 75), (77, 88)]
encoding = tokenizer(
sentence,
entities=entities,
entity_spans=spans,
return_token_type_ids=True,
padding="max_length",
max_length=30,
max_entity_length=16,
return_tensors="pt",
)
# test words
self.assertEqual(encoding["input_ids"].shape, (1, 30))
self.assertEqual(encoding["attention_mask"].shape, (1, 30))
self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
# test entities
self.assertEqual(encoding["entity_ids"].shape, (1, 16))
self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16))
self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16))
self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length))
def test_text_pair_no_padding_or_truncation(self):
tokenizer = self.tokenizer
sentence = "ISO 639-3 uses the code fas"
sentence_pair = "for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
entities = ["en:ISO 639-3"]
entities_pair = ["DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
spans = [(0, 9)]
spans_pair = [(31, 35), (40, 47), (49, 60)]
encoding = tokenizer(
sentence,
sentence_pair,
entities=entities,
entities_pair=entities_pair,
entity_spans=spans,
entity_spans_pair=spans_pair,
return_token_type_ids=True,
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
"<s> ISO 639-3 uses the code fas</s></s> for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).</s>",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
)
self.assertEqual(tokenizer.decode(encoding["input_ids"][19], spaces_between_special_tokens=False), "Iran")
self.assertEqual(
tokenizer.decode(encoding["input_ids"][21:27], spaces_between_special_tokens=False), "アフガニスタン"
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][28], spaces_between_special_tokens=False), "Afghanistan"
)
self.assertEqual(
encoding["entity_ids"],
[
tokenizer.entity_vocab["en:ISO 639-3"],
tokenizer.entity_vocab["[UNK]"],
tokenizer.entity_vocab["ja:アフガニスタン"],
tokenizer.entity_vocab["en:Afghanistan"],
],
)
self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1, 1])
self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0, 0])
# fmt: off
self.assertEqual(
encoding["entity_position_ids"],
[
[1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[19, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[21, 22, 23, 24, 25, 26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[28, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
]
)
# fmt: on
def test_text_pair_only_entity_spans_no_padding_or_truncation(self):
tokenizer = self.tokenizer
sentence = "ISO 639-3 uses the code fas"
sentence_pair = "for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
entities = ["en:ISO 639-3"]
entities_pair = ["DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
spans = [(0, 9)]
spans_pair = [(31, 35), (40, 47), (49, 60)]
encoding = tokenizer(
sentence,
sentence_pair,
entities=entities,
entities_pair=entities_pair,
entity_spans=spans,
entity_spans_pair=spans_pair,
return_token_type_ids=True,
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
"<s> ISO 639-3 uses the code fas</s></s> for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).</s>",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
)
self.assertEqual(tokenizer.decode(encoding["input_ids"][19], spaces_between_special_tokens=False), "Iran")
self.assertEqual(
tokenizer.decode(encoding["input_ids"][21:27], spaces_between_special_tokens=False), "アフガニスタン"
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][28], spaces_between_special_tokens=False), "Afghanistan"
)
self.assertEqual(
encoding["entity_ids"],
[
tokenizer.entity_vocab["en:ISO 639-3"],
tokenizer.entity_vocab["[UNK]"],
tokenizer.entity_vocab["ja:アフガニスタン"],
tokenizer.entity_vocab["en:Afghanistan"],
],
)
# fmt: off
self.assertEqual(
encoding["entity_position_ids"],
[
[1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[19, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[21, 22, 23, 24, 25, 26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[28, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
]
)
# fmt: on
def test_text_pair_padding_pytorch_tensors(self):
tokenizer = self.tokenizer
sentence = "ISO 639-3 uses the code fas"
sentence_pair = "for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
entities = ["en:ISO 639-3"]
entities_pair = ["DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
spans = [(0, 9)]
spans_pair = [(31, 35), (40, 47), (49, 60)]
encoding = tokenizer(
sentence,
sentence_pair,
entities=entities,
entities_pair=entities_pair,
entity_spans=spans,
entity_spans_pair=spans_pair,
return_token_type_ids=True,
padding="max_length",
max_length=40,
max_entity_length=16,
return_tensors="pt",
)
# test words
self.assertEqual(encoding["input_ids"].shape, (1, 40))
self.assertEqual(encoding["attention_mask"].shape, (1, 40))
self.assertEqual(encoding["token_type_ids"].shape, (1, 40))
# test entities
self.assertEqual(encoding["entity_ids"].shape, (1, 16))
self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16))
self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16))
self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length))
def test_entity_classification_no_padding_or_truncation(self):
tokenizer = self.entity_classification_tokenizer
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
span = (15, 34)
encoding = tokenizer(sentence, entity_spans=[span], return_token_type_ids=True)
# test words
self.assertEqual(len(encoding["input_ids"]), 23)
self.assertEqual(len(encoding["attention_mask"]), 23)
self.assertEqual(len(encoding["token_type_ids"]), 23)
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
"<s> Japanese is an<ent>East Asian language<ent>spoken by about 128 million people, primarily in Japan.</s>",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][4:9], spaces_between_special_tokens=False),
"<ent>East Asian language<ent>",
)
# test entities
mask_id = tokenizer.entity_vocab["[MASK]"]
self.assertEqual(encoding["entity_ids"], [mask_id])
self.assertEqual(encoding["entity_attention_mask"], [1])
self.assertEqual(encoding["entity_token_type_ids"], [0])
# fmt: off
self.assertEqual(
encoding["entity_position_ids"],
[[4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]]
)
# fmt: on
def test_entity_classification_padding_pytorch_tensors(self):
tokenizer = self.entity_classification_tokenizer
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
span = (15, 34)
encoding = tokenizer(
sentence, entity_spans=[span], return_token_type_ids=True, padding="max_length", return_tensors="pt"
)
# test words
self.assertEqual(encoding["input_ids"].shape, (1, 512))
self.assertEqual(encoding["attention_mask"].shape, (1, 512))
self.assertEqual(encoding["token_type_ids"].shape, (1, 512))
# test entities
self.assertEqual(encoding["entity_ids"].shape, (1, 1))
self.assertEqual(encoding["entity_attention_mask"].shape, (1, 1))
self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 1))
self.assertEqual(
encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length)
)
def test_entity_pair_classification_no_padding_or_truncation(self):
tokenizer = self.entity_pair_tokenizer
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
# head and tail information
spans = [(0, 8), (84, 89)]
encoding = tokenizer(sentence, entity_spans=spans, return_token_type_ids=True)
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
"<s><ent>Japanese<ent>is an East Asian language spoken by about 128 million people, primarily in<ent2>Japan<ent2>.</s>",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][1:4], spaces_between_special_tokens=False),
"<ent>Japanese<ent>",
)
self.assertEqual(
tokenizer.decode(encoding["input_ids"][20:23], spaces_between_special_tokens=False), "<ent2>Japan<ent2>"
)
mask_id = tokenizer.entity_vocab["[MASK]"]
mask2_id = tokenizer.entity_vocab["[MASK2]"]
self.assertEqual(encoding["entity_ids"], [mask_id, mask2_id])
self.assertEqual(encoding["entity_attention_mask"], [1, 1])
self.assertEqual(encoding["entity_token_type_ids"], [0, 0])
# fmt: off
self.assertEqual(
encoding["entity_position_ids"],
[
[1, 2, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[20, 21, 22, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
]
)
# fmt: on
def test_entity_pair_classification_padding_pytorch_tensors(self):
tokenizer = self.entity_pair_tokenizer
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
# head and tail information
spans = [(0, 8), (84, 89)]
encoding = tokenizer(
sentence,
entity_spans=spans,
return_token_type_ids=True,
padding="max_length",
max_length=30,
return_tensors="pt",
)
# test words
self.assertEqual(encoding["input_ids"].shape, (1, 30))
self.assertEqual(encoding["attention_mask"].shape, (1, 30))
self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
# test entities
self.assertEqual(encoding["entity_ids"].shape, (1, 2))
self.assertEqual(encoding["entity_attention_mask"].shape, (1, 2))
self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 2))
self.assertEqual(
encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length)
)
def test_entity_span_classification_no_padding_or_truncation(self):
tokenizer = self.entity_span_tokenizer
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
spans = [(0, 8), (15, 34), (84, 89)]
encoding = tokenizer(sentence, entity_spans=spans, return_token_type_ids=True)
self.assertEqual(
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
"<s> Japanese is an East Asian language spoken by about 128 million people, primarily in Japan.</s>",
)
mask_id = tokenizer.entity_vocab["[MASK]"]
self.assertEqual(encoding["entity_ids"], [mask_id, mask_id, mask_id])
self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1])
self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0])
# fmt: off
self.assertEqual(
encoding["entity_position_ids"],
[
[1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[18, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]]
)
# fmt: on
self.assertEqual(encoding["entity_start_positions"], [1, 4, 18])
self.assertEqual(encoding["entity_end_positions"], [1, 6, 18])
def test_entity_span_classification_padding_pytorch_tensors(self):
tokenizer = self.entity_span_tokenizer
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
spans = [(0, 8), (15, 34), (84, 89)]
encoding = tokenizer(
sentence,
entity_spans=spans,
return_token_type_ids=True,
padding="max_length",
max_length=30,
max_entity_length=16,
return_tensors="pt",
)
# test words
self.assertEqual(encoding["input_ids"].shape, (1, 30))
self.assertEqual(encoding["attention_mask"].shape, (1, 30))
self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
# test entities
self.assertEqual(encoding["entity_ids"].shape, (1, 16))
self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16))
self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16))
self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length))
self.assertEqual(encoding["entity_start_positions"].shape, (1, 16))
self.assertEqual(encoding["entity_end_positions"].shape, (1, 16))

View File

@ -116,6 +116,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"DPRReader",
"FlaubertForQuestionAnswering",
"GPT2DoubleHeadsModel",
"LukeForMaskedLM",
"LukeForEntityClassification",
"LukeForEntityPairClassification",
"LukeForEntitySpanClassification",