Add mLUKE (#14640)
* implement MLukeTokenizer and LukeForMaskedLM
* update tests
* update docs
* add LukeForMaskedLM to check_repo.py
* update README
* fix test and specify the entity pad id in tokenization_(m)luke
* fix EntityPredictionHeadTransform
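A minimal sketch of what the new pieces enable together (assuming the converted studio-ousia/mluke-base weights are available on the Hub; the sentence and character span are purely illustrative):

from transformers import MLukeTokenizer, LukeForMaskedLM

tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base")
model = LukeForMaskedLM.from_pretrained("studio-ousia/mluke-base")

text = "Tokyo is the capital of Japan."
inputs = tokenizer(text, entity_spans=[(0, 5)], return_tensors="pt")  # character span of "Tokyo"
outputs = model(**inputs)

# scores from the word-level LM head and the new entity prediction head
print(outputs.logits.shape)         # (1, sequence_length, vocab_size)
print(outputs.entity_logits.shape)  # (1, entity_length, entity_vocab_size)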
This commit is contained in: parent 4cdb67caba, commit 30646a0a3c

@ -275,6 +275,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[LED](https://huggingface.co/docs/transformers/model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[Longformer](https://huggingface.co/docs/transformers/model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LUKE](https://huggingface.co/docs/transformers/model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
|
||||
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
|
||||
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
|
||||
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
|
||||
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
|
||||
|
|
|
@ -261,6 +261,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
|
|||
1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
|
||||
1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron_bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
|
||||
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
|
||||
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
|
||||
|
|
|
@ -285,6 +285,7 @@ conda install -c huggingface transformers
|
|||
1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (来自 Facebook) 伴随论文 [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) 由 Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan 发布。
|
||||
1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron_bert)** (来自 NVIDIA) 伴随论文 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 由 Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 发布。
|
||||
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (来自 NVIDIA) 伴随论文 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 由 Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 发布。
|
||||
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (来自 Studio Ousia) 伴随论文 [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) 由 Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka 发布。
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (来自 Microsoft Research) 伴随论文 [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) 由 Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu 发布。
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (来自 Google AI) 伴随论文 [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) 由 Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel 发布。
|
||||
1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (来自 Google) 伴随论文 [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) 由 Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu 发布。
|
||||
|
|
|
@ -297,6 +297,7 @@ conda install -c huggingface transformers
|
|||
1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
|
||||
1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron_bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
|
||||
1. **[Megatron-GPT2](https://huggingface.co/docs/transformers/model_doc/megatron_gpt2)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
|
||||
1. **[mLUKE](https://huggingface.co/docs/transformers/model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
|
||||
1. **[MPNet](https://huggingface.co/docs/transformers/model_doc/mpnet)** (from Microsoft Research) released with the paper [MPNet: Masked and Permuted Pre-training for Language Understanding](https://arxiv.org/abs/2004.09297) by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu.
|
||||
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
1. **[Pegasus](https://huggingface.co/docs/transformers/model_doc/pegasus)** (from Google) released with the paper [PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization](https://arxiv.org/abs/1912.08777) by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu.
|
||||
|
|
|
@ -135,6 +135,7 @@ conversion utilities for the following models.
|
|||
1. **[LED](model_doc/led)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[Longformer](model_doc/longformer)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LUKE](model_doc/luke)** (from Studio Ousia) released with the paper [LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention](https://arxiv.org/abs/2010.01057) by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
|
||||
1. **[mLUKE](model_doc/mluke)** (from Studio Ousia) released with the paper [mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models](https://arxiv.org/abs/2110.08151) by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka.
|
||||
1. **[LXMERT](model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
|
||||
1. **[M2M100](model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
|
||||
1. **[MarianMT](model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
|
||||
|
|
|
@ -20,7 +20,7 @@ Rust library `tokenizers <https://github.com/huggingface/tokenizers>`__. The "Fa
|
|||
1. a significant speed-up in particular when doing batched tokenization and
|
||||
2. additional methods to map between the original string (character and words) and the token space (e.g. getting the
|
||||
index of the token comprising a given character or the span of characters corresponding to a given token). Currently
|
||||
no "Fast" implementation is available for the SentencePiece-based tokenizers (for T5, ALBERT, CamemBERT, XLMRoBERTa
|
||||
no "Fast" implementation is available for the SentencePiece-based tokenizers (for T5, ALBERT, CamemBERT, XLM-RoBERTa
|
||||
and XLNet models).
|
||||
|
||||
The base classes :class:`~transformers.PreTrainedTokenizer` and :class:`~transformers.PreTrainedTokenizerFast`
|
||||
|
|
|
@ -137,6 +137,12 @@ LukeModel
|
|||
.. autoclass:: transformers.LukeModel
|
||||
:members: forward
|
||||
|
||||
LukeForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.LukeForMaskedLM
|
||||
:members: forward
|
||||
|
||||
|
||||
LukeForEntityClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
..
|
||||
Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
mLUKE
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The mLUKE model was proposed in `mLUKE: The Power of Entity Representations in Multilingual Pretrained Language Models
|
||||
<https://arxiv.org/abs/2110.08151>`__ by Ryokan Ri, Ikuya Yamada, and Yoshimasa Tsuruoka. It's a multilingual extension
|
||||
of the `LUKE model <https://arxiv.org/abs/2010.01057>`__ trained on the basis of XLM-RoBERTa.
|
||||
|
||||
It is based on XLM-RoBERTa and adds entity embeddings, which help improve performance on various downstream tasks
involving reasoning about entities, such as named entity recognition, extractive question answering, relation
classification and cloze-style knowledge completion.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Recent studies have shown that multilingual pretrained language models can be effectively improved with cross-lingual
|
||||
alignment information from Wikipedia entities. However, existing methods only exploit entity information in pretraining
|
||||
and do not explicitly use entities in downstream tasks. In this study, we explore the effectiveness of leveraging
|
||||
entity representations for downstream cross-lingual tasks. We train a multilingual language model with 24 languages
|
||||
with entity representations and show the model consistently outperforms word-based pretrained models in various
|
||||
cross-lingual transfer tasks. We also analyze the model and the key insight is that incorporating entity
|
||||
representations into the input allows us to extract more language-agnostic features. We also evaluate the model with a
|
||||
multilingual cloze prompt task with the mLAMA dataset. We show that entity-based prompt elicits correct factual
|
||||
knowledge more likely than using only word representations.*
|
||||
|
||||
One can directly plug in the weights of mLUKE into a LUKE model, like so:
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import LukeModel
|
||||
|
||||
model = LukeModel.from_pretrained('studio-ousia/mluke-base')
|
||||
|
||||
Note that mLUKE has its own tokenizer, :class:`~transformers.MLukeTokenizer`. You can initialize it as follows:
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import MLukeTokenizer
|
||||
|
||||
tokenizer = MLukeTokenizer.from_pretrained('studio-ousia/mluke-base')
|
||||
|
||||
|
||||
As mLUKE's architecture is equivalent to that of LUKE, one can refer to :doc:`LUKE's documentation page <luke>` for all
|
||||
tips, code examples and notebooks.
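As a rough, illustrative sketch of the cloze-style usage mentioned in the abstract (the example sentence, the character
span and the exact format of the entries in the entity vocabulary are assumptions, not guaranteed outputs):

.. code-block::

from transformers import MLukeTokenizer, LukeForMaskedLM

tokenizer = MLukeTokenizer.from_pretrained('studio-ousia/mluke-base')
model = LukeForMaskedLM.from_pretrained('studio-ousia/mluke-base')

text = "Tokyo is the capital of Japan."
# when only entity_spans are given, the span is encoded as the [MASK] entity
inputs = tokenizer(text, entity_spans=[(24, 29)], return_tensors="pt")  # character span of "Japan"
outputs = model(**inputs)

# highest-scoring entry of the entity vocabulary for the masked mention
predicted_id = int(outputs.entity_logits[0, 0].argmax())
id_to_entity = {idx: entity for entity, idx in tokenizer.entity_vocab.items()}
print(id_to_entity[predicted_id])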
|
||||
|
||||
This model was contributed by `ryo0634 <https://huggingface.co/ryo0634>`__. The original code can be found `here
|
||||
<https://github.com/studio-ousia/luke>`__.
|
||||
|
||||
MLukeTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MLukeTokenizer
|
||||
:members: __call__, save_vocabulary
|
|
@ -17,8 +17,6 @@ Most of the models available in this library are mono-lingual models (English, C
|
|||
models are available and have different mechanisms than mono-lingual models. This page details the usage of these
models.
|
||||
|
||||
The two models that currently support multiple languages are BERT and XLM.
|
||||
|
||||
XLM
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
@ -127,3 +125,17 @@ Two XLM-RoBERTa checkpoints can be used for multi-lingual tasks:
|
|||
|
||||
- ``xlm-roberta-base`` (Masked language modeling, 100 languages)
|
||||
- ``xlm-roberta-large`` (Masked language modeling, 100 languages)
|
||||
|
||||
mLUKE
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
mLUKE is based on XLM-RoBERTa and further trained on Wikipedia articles in 24 languages with a masked language
modeling objective as well as a masked entity prediction objective.

The model can be used in the same way as other models that rely solely on word-piece inputs, but it can also be used
with entity representations to achieve further performance gains on entity-related tasks such as relation extraction,
named entity recognition and question answering (see :doc:`LUKE <model_doc/luke>`); a short sketch follows the
checkpoint list below.
|
||||
|
||||
Currently, one mLUKE checkpoint is available:
|
||||
|
||||
- ``studio-ousia/mluke-base`` (Masked language modeling + Masked entity prediction, 100 languages)
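A minimal sketch of using the checkpoint with an entity span (the sentence and span are illustrative only):

.. code-block::

from transformers import MLukeTokenizer, LukeModel

tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base")
model = LukeModel.from_pretrained("studio-ousia/mluke-base")

text = "Tokyo is the capital of Japan."
inputs = tokenizer(text, entity_spans=[(0, 5)], return_tensors="pt")  # character span of "Tokyo"
outputs = model(**inputs)

word_hidden_states = outputs.last_hidden_state           # (batch_size, sequence_length, hidden_size)
entity_hidden_states = outputs.entity_last_hidden_state  # (batch_size, entity_length, hidden_size)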
|
||||
|
|
|
@ -245,6 +245,7 @@ _import_structure = {
|
|||
"models.mbart": ["MBartConfig"],
|
||||
"models.mbart50": [],
|
||||
"models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
|
||||
"models.mluke": [],
|
||||
"models.mmbt": ["MMBTConfig"],
|
||||
"models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"],
|
||||
"models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"],
|
||||
|
@ -379,6 +380,7 @@ if is_sentencepiece_available():
|
|||
_import_structure["models.marian"].append("MarianTokenizer")
|
||||
_import_structure["models.mbart"].append("MBartTokenizer")
|
||||
_import_structure["models.mbart50"].append("MBart50Tokenizer")
|
||||
_import_structure["models.mluke"].append("MLukeTokenizer")
|
||||
_import_structure["models.mt5"].append("MT5Tokenizer")
|
||||
_import_structure["models.pegasus"].append("PegasusTokenizer")
|
||||
_import_structure["models.reformer"].append("ReformerTokenizer")
|
||||
|
@ -1037,6 +1039,7 @@ if is_torch_available():
|
|||
"LukeForEntityClassification",
|
||||
"LukeForEntityPairClassification",
|
||||
"LukeForEntitySpanClassification",
|
||||
"LukeForMaskedLM",
|
||||
"LukeModel",
|
||||
"LukePreTrainedModel",
|
||||
]
|
||||
|
@ -2368,6 +2371,7 @@ if TYPE_CHECKING:
|
|||
from .models.m2m_100 import M2M100Tokenizer
|
||||
from .models.marian import MarianTokenizer
|
||||
from .models.mbart import MBart50Tokenizer, MBartTokenizer
|
||||
from .models.mluke import MLukeTokenizer
|
||||
from .models.mt5 import MT5Tokenizer
|
||||
from .models.pegasus import PegasusTokenizer
|
||||
from .models.reformer import ReformerTokenizer
|
||||
|
@ -2904,6 +2908,7 @@ if TYPE_CHECKING:
|
|||
LukeForEntityClassification,
|
||||
LukeForEntityPairClassification,
|
||||
LukeForEntitySpanClassification,
|
||||
LukeForMaskedLM,
|
||||
LukeModel,
|
||||
LukePreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -71,6 +71,7 @@ from . import (
|
|||
mbart50,
|
||||
megatron_bert,
|
||||
megatron_gpt2,
|
||||
mluke,
|
||||
mmbt,
|
||||
mobilebert,
|
||||
mpnet,
|
||||
|
|
|
@ -178,6 +178,7 @@ else:
|
|||
("hubert", ("Wav2Vec2CTCTokenizer", None)),
|
||||
("gpt_neo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("luke", ("LukeTokenizer", None)),
|
||||
("mluke", ("MLukeTokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("bigbird_pegasus", ("PegasusTokenizer", "PegasusTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("canine", ("CanineTokenizer", None)),
|
||||
("bertweet", ("BertweetTokenizer", None)),
|
||||
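With the ("mluke", ...) entry registered above, and given that the conversion script later in this commit writes
"tokenizer_class": "MLukeTokenizer" into tokenizer_config.json, AutoTokenizer should resolve mLUKE checkpoints to the
new slow tokenizer. A sketch, assuming sentencepiece is installed:

from transformers import AutoTokenizer, MLukeTokenizer

tokenizer = AutoTokenizer.from_pretrained("studio-ousia/mluke-base")
assert isinstance(tokenizer, MLukeTokenizer)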
|
|
|
@ -32,6 +32,7 @@ if is_torch_available():
|
|||
"LukeForEntityClassification",
|
||||
"LukeForEntityPairClassification",
|
||||
"LukeForEntitySpanClassification",
|
||||
"LukeForMaskedLM",
|
||||
"LukeModel",
|
||||
"LukePreTrainedModel",
|
||||
]
|
||||
|
@ -47,6 +48,7 @@ if TYPE_CHECKING:
|
|||
LukeForEntityClassification,
|
||||
LukeForEntityPairClassification,
|
||||
LukeForEntitySpanClassification,
|
||||
LukeForMaskedLM,
|
||||
LukeModel,
|
||||
LukePreTrainedModel,
|
||||
)
|
||||
|
|
|
@ -22,7 +22,7 @@ import torch
|
|||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...activations import ACT2FN, gelu
|
||||
from ...file_utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
|
@ -110,6 +110,49 @@ class BaseLukeModelOutput(BaseModelOutput):
|
|||
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LukeMaskedLMOutput(ModelOutput):
|
||||
"""
|
||||
Base class for model's outputs, with potential hidden states and attentions.
|
||||
|
||||
Args:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
The sum of masked language modeling (MLM) loss and entity prediction loss.
|
||||
mlm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Masked language modeling (MLM) loss.
|
||||
mep_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`entity_labels` is provided):
Masked entity prediction (MEP) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
entity_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, entity_length, config.entity_vocab_size)`):
Prediction scores of the entity prediction head (scores for each entity vocabulary token before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
entity_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output
|
||||
of each layer plus the initial entity embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
mlm_loss: Optional[torch.FloatTensor] = None
|
||||
mep_loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
entity_logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
entity_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityClassificationOutput(ModelOutput):
|
||||
"""
|
||||
|
@ -674,6 +717,38 @@ class LukePooler(nn.Module):
|
|||
return pooled_output
|
||||
|
||||
|
||||
class EntityPredictionHeadTransform(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.entity_emb_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.transform_act_fn = config.hidden_act
|
||||
self.LayerNorm = nn.LayerNorm(config.entity_emb_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.transform_act_fn(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EntityPredictionHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transform = EntityPredictionHeadTransform(config)
|
||||
self.decoder = nn.Linear(config.entity_emb_size, config.entity_vocab_size, bias=False)
|
||||
self.bias = nn.Parameter(torch.zeros(config.entity_vocab_size))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states) + self.bias
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LukePreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
|
@ -1013,6 +1088,170 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
|
|||
return incremental_indices.long() + padding_idx
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead
|
||||
class LukeLMHead(nn.Module):
|
||||
"""Roberta Head for masked language modeling."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
self.decoder.bias = self.bias
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
x = self.dense(features)
|
||||
x = gelu(x)
|
||||
x = self.layer_norm(x)
|
||||
|
||||
# project back to size of vocabulary with bias
|
||||
x = self.decoder(x)
|
||||
|
||||
return x
|
||||
|
||||
def _tie_weights(self):
|
||||
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
|
||||
self.bias = self.decoder.bias
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The LUKE model with a language modeling head and entity prediction head on top for masked language modeling and
|
||||
masked entity prediction.
|
||||
""",
|
||||
LUKE_START_DOCSTRING,
|
||||
)
|
||||
class LukeForMaskedLM(LukePreTrainedModel):
|
||||
_keys_to_ignore_on_save = [
|
||||
r"lm_head.decoder.weight",
|
||||
r"lm_head.decoder.bias",
|
||||
r"entity_predictions.decoder.weight",
|
||||
]
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"position_ids",
|
||||
r"lm_head.decoder.weight",
|
||||
r"lm_head.decoder.bias",
|
||||
r"entity_predictions.decoder.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.luke = LukeModel(config)
|
||||
|
||||
self.lm_head = LukeLMHead(config)
|
||||
self.entity_predictions = EntityPredictionHead(config)
|
||||
|
||||
self.loss_fn = nn.CrossEntropyLoss(ignore_index=-1)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def tie_weights(self):
|
||||
super().tie_weights()
|
||||
self._tie_or_clone_weights(self.entity_predictions.decoder, self.luke.entity_embeddings.entity_embeddings)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(LUKE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=LukeMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
entity_ids=None,
|
||||
entity_attention_mask=None,
|
||||
entity_token_type_ids=None,
|
||||
entity_position_ids=None,
|
||||
labels=None,
|
||||
entity_labels=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
||||
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
||||
entity_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, entity_length)`, `optional`):
Labels for computing the masked entity prediction loss. Indices should be in ``[-100, 0, ...,
config.entity_vocab_size]`` (see ``entity_ids`` docstring) Entities with indices set to ``-100`` are ignored
(masked), the loss is only computed for the entities with labels in ``[0, ..., config.entity_vocab_size]``
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.luke(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
entity_ids=entity_ids,
|
||||
entity_attention_mask=entity_attention_mask,
|
||||
entity_token_type_ids=entity_token_type_ids,
|
||||
entity_position_ids=entity_position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
loss = None
|
||||
|
||||
mlm_loss = None
|
||||
logits = self.lm_head(outputs.last_hidden_state)
|
||||
if labels is not None:
|
||||
mlm_loss = self.loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
if loss is None:
|
||||
loss = mlm_loss
|
||||
|
||||
mep_loss = None
|
||||
entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
|
||||
if entity_labels is not None:
|
||||
mep_loss = self.loss_fn(entity_logits.view(-1, self.config.entity_vocab_size), entity_labels.view(-1))
|
||||
if loss is None:
|
||||
loss = mep_loss
|
||||
else:
|
||||
loss = loss + mep_loss
|
||||
|
||||
if not return_dict:
|
||||
output = (logits, entity_logits, outputs.hidden_states, outputs.entity_hidden_states, outputs.attentions)
|
||||
if mlm_loss is not None and mep_loss is not None:
|
||||
return (loss, mlm_loss, mep_loss) + output
|
||||
elif mlm_loss is not None:
|
||||
return (loss, mlm_loss) + output
|
||||
elif mep_loss is not None:
|
||||
return (loss, mep_loss) + output
|
||||
else:
|
||||
return output
|
||||
|
||||
return LukeMaskedLMOutput(
|
||||
loss=loss,
|
||||
mlm_loss=mlm_loss,
|
||||
mep_loss=mep_loss,
|
||||
logits=logits,
|
||||
entity_logits=entity_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
entity_hidden_states=outputs.entity_hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
The LUKE model with a classification head on top (a linear layer on top of the hidden state of the first entity
|
||||
|
|
|
@ -312,17 +312,15 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
# Input type checking for clearer error
|
||||
is_valid_single_text = isinstance(text, str)
|
||||
is_valid_batch_text = isinstance(text, (list, tuple)) and (len(text) == 0 or (isinstance(text[0], str)))
|
||||
assert (
|
||||
is_valid_single_text or is_valid_batch_text
|
||||
), "text input must be of type `str` (single example) or `List[str]` (batch)."
|
||||
if not (is_valid_single_text or is_valid_batch_text):
|
||||
raise ValueError("text input must be of type `str` (single example) or `List[str]` (batch).")
|
||||
|
||||
is_valid_single_text_pair = isinstance(text_pair, str)
|
||||
is_valid_batch_text_pair = isinstance(text_pair, (list, tuple)) and (
|
||||
len(text_pair) == 0 or isinstance(text_pair[0], str)
|
||||
)
|
||||
assert (
|
||||
text_pair is None or is_valid_single_text_pair or is_valid_batch_text_pair
|
||||
), "text_pair input must be of type `str` (single example) or `List[str]` (batch)."
|
||||
if not (text_pair is None or is_valid_single_text_pair or is_valid_batch_text_pair):
|
||||
raise ValueError("text_pair input must be of type `str` (single example) or `List[str]` (batch).")
|
||||
|
||||
is_batched = bool(isinstance(text, (list, tuple)))
|
||||
|
||||
|
@ -391,105 +389,6 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
||||
def encode_plus(
|
||||
self,
|
||||
text: Union[TextInput],
|
||||
text_pair: Optional[Union[TextInput]] = None,
|
||||
entity_spans: Optional[EntitySpanInput] = None,
|
||||
entity_spans_pair: Optional[EntitySpanInput] = None,
|
||||
entities: Optional[EntityInput] = None,
|
||||
entities_pair: Optional[EntityInput] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
max_entity_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
is_split_into_words: Optional[bool] = False,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Tokenize and prepare for the model a sequence or a pair of sequences.
|
||||
|
||||
.. warning:: This method is deprecated, ``__call__`` should be used instead.
|
||||
|
||||
Args:
|
||||
text (:obj:`str`):
|
||||
The first sequence to be encoded. Each sequence must be a string.
|
||||
text_pair (:obj:`str`):
|
||||
The second sequence to be encoded. Each sequence must be a string.
|
||||
entity_spans (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`)::
|
||||
The first sequence of entity spans to be encoded. The sequence consists of tuples each with two
|
||||
integers denoting character-based start and end positions of entities. If you specify
|
||||
:obj:`"entity_classification"` or :obj:`"entity_pair_classification"` as the ``task`` argument in the
|
||||
constructor, the length of each sequence must be 1 or 2, respectively. If you specify ``entities``, the
|
||||
length of the sequence must be equal to the length of ``entities``.
|
||||
entity_spans_pair (:obj:`List[Tuple[int, int]]`, :obj:`List[List[Tuple[int, int]]]`, `optional`)::
|
||||
The second sequence of entity spans to be encoded. The sequence consists of tuples each with two
|
||||
integers denoting character-based start and end positions of entities. If you specify the ``task``
|
||||
argument in the constructor, this argument is ignored. If you specify ``entities_pair``, the length of
|
||||
the sequence must be equal to the length of ``entities_pair``.
|
||||
entities (:obj:`List[str]` `optional`)::
|
||||
The first sequence of entities to be encoded. The sequence consists of strings representing entities,
|
||||
i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los Angeles). This argument
|
||||
is ignored if you specify the ``task`` argument in the constructor. The length of the sequence must be
|
||||
equal to the length of ``entity_spans``. If you specify ``entity_spans`` without specifying this
|
||||
argument, the entity sequence is automatically constructed by filling it with the [MASK] entity.
|
||||
entities_pair (:obj:`List[str]`, :obj:`List[List[str]]`, `optional`)::
|
||||
The second sequence of entities to be encoded. The sequence consists of strings representing entities,
|
||||
i.e., special entities (e.g., [MASK]) or entity titles of Wikipedia (e.g., Los Angeles). This argument
|
||||
is ignored if you specify the ``task`` argument in the constructor. The length of the sequence must be
|
||||
equal to the length of ``entity_spans_pair``. If you specify ``entity_spans_pair`` without specifying
|
||||
this argument, the entity sequence is automatically constructed by filling it with the [MASK] entity.
|
||||
max_entity_length (:obj:`int`, `optional`):
|
||||
The maximum length of the entity sequence.
|
||||
"""
|
||||
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
|
||||
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self._encode_plus(
|
||||
text=text,
|
||||
text_pair=text_pair,
|
||||
entity_spans=entity_spans,
|
||||
entity_spans_pair=entity_spans_pair,
|
||||
entities=entities,
|
||||
entities_pair=entities_pair,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
max_entity_length=max_entity_length,
|
||||
stride=stride,
|
||||
is_split_into_words=is_split_into_words,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _encode_plus(
|
||||
self,
|
||||
text: Union[TextInput],
|
||||
|
@ -571,89 +470,6 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
verbose=verbose,
|
||||
)
|
||||
|
||||
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
||||
def batch_encode_plus(
|
||||
self,
|
||||
batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]],
|
||||
batch_entity_spans_or_entity_spans_pairs: Optional[
|
||||
Union[List[EntitySpanInput], List[Tuple[EntitySpanInput, EntitySpanInput]]]
|
||||
] = None,
|
||||
batch_entities_or_entities_pairs: Optional[
|
||||
Union[List[EntityInput], List[Tuple[EntityInput, EntityInput]]]
|
||||
] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
max_entity_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
is_split_into_words: Optional[bool] = False,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Tokenize and prepare for the model a list of sequences or a list of pairs of sequences.
|
||||
|
||||
.. warning::
|
||||
This method is deprecated, ``__call__`` should be used instead.
|
||||
|
||||
|
||||
Args:
|
||||
batch_text_or_text_pairs (:obj:`List[str]`, :obj:`List[Tuple[str, str]]`):
|
||||
Batch of sequences or pair of sequences to be encoded. This can be a list of string or a list of pair
|
||||
of string (see details in ``encode_plus``).
|
||||
batch_entity_spans_or_entity_spans_pairs (:obj:`List[List[Tuple[int, int]]]`,
|
||||
:obj:`List[Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]]`, `optional`)::
|
||||
Batch of entity span sequences or pairs of entity span sequences to be encoded (see details in
|
||||
``encode_plus``).
|
||||
batch_entities_or_entities_pairs (:obj:`List[List[str]]`, :obj:`List[Tuple[List[str], List[str]]]`,
|
||||
`optional`):
|
||||
Batch of entity sequences or pairs of entity sequences to be encoded (see details in ``encode_plus``).
|
||||
max_entity_length (:obj:`int`, `optional`):
|
||||
The maximum length of the entity sequence.
|
||||
"""
|
||||
|
||||
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
|
||||
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self._batch_encode_plus(
|
||||
batch_text_or_text_pairs=batch_text_or_text_pairs,
|
||||
batch_entity_spans_or_entity_spans_pairs=batch_entity_spans_or_entity_spans_pairs,
|
||||
batch_entities_or_entities_pairs=batch_entities_or_entities_pairs,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding_strategy=padding_strategy,
|
||||
truncation_strategy=truncation_strategy,
|
||||
max_length=max_length,
|
||||
max_entity_length=max_entity_length,
|
||||
stride=stride,
|
||||
is_split_into_words=is_split_into_words,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _batch_encode_plus(
|
||||
self,
|
||||
batch_text_or_text_pairs: Union[List[TextInput], List[TextInputPair]],
|
||||
|
@ -713,11 +529,12 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
entity_spans, entity_spans_pair = None, None
|
||||
if batch_entity_spans_or_entity_spans_pairs is not None:
|
||||
entity_spans_or_entity_spans_pairs = batch_entity_spans_or_entity_spans_pairs[index]
|
||||
if entity_spans_or_entity_spans_pairs:
|
||||
if isinstance(entity_spans_or_entity_spans_pairs[0][0], int):
|
||||
entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs, None
|
||||
else:
|
||||
entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs
|
||||
if len(entity_spans_or_entity_spans_pairs) > 0 and isinstance(
|
||||
entity_spans_or_entity_spans_pairs[0], list
|
||||
):
|
||||
entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs
|
||||
else:
|
||||
entity_spans, entity_spans_pair = entity_spans_or_entity_spans_pairs, None
|
||||
|
||||
(
|
||||
first_ids,
|
||||
|
@ -761,6 +578,25 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
|
||||
return BatchEncoding(batch_outputs)
|
||||
|
||||
def _check_entity_input_format(self, entities: Optional[EntityInput], entity_spans: Optional[EntitySpanInput]):
|
||||
if not isinstance(entity_spans, list):
|
||||
raise ValueError("entity_spans should be given as a list")
|
||||
elif len(entity_spans) > 0 and not isinstance(entity_spans[0], tuple):
|
||||
raise ValueError(
|
||||
"entity_spans should be given as a list of tuples " "containing the start and end character indices"
|
||||
)
|
||||
|
||||
if entities is not None:
|
||||
|
||||
if not isinstance(entities, list):
|
||||
raise ValueError("If you specify entities, they should be given as a list")
|
||||
|
||||
if len(entities) > 0 and not isinstance(entities[0], str):
|
||||
raise ValueError("If you specify entities, they should be given as a list of entity names")
|
||||
|
||||
if len(entities) != len(entity_spans):
|
||||
raise ValueError("If you specify entities, entities and entity_spans must be the same length")
|
||||
|
||||
def _create_input_sequence(
|
||||
self,
|
||||
text: Union[TextInput],
|
||||
|
@ -816,15 +652,7 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
if entity_spans is None:
|
||||
first_ids = get_input_ids(text)
|
||||
else:
|
||||
assert isinstance(entity_spans, list) and (
|
||||
len(entity_spans) == 0 or isinstance(entity_spans[0], tuple)
|
||||
), "entity_spans should be given as a list of tuples containing the start and end character indices"
|
||||
assert entities is None or (
|
||||
isinstance(entities, list) and (len(entities) == 0 or isinstance(entities[0], str))
|
||||
), "If you specify entities, they should be given as a list of entity names"
|
||||
assert entities is None or len(entities) == len(
|
||||
entity_spans
|
||||
), "If you specify entities, entities and entity_spans must be the same length"
|
||||
self._check_entity_input_format(entities, entity_spans)
|
||||
|
||||
first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)
|
||||
if entities is None:
|
||||
|
@ -836,16 +664,7 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
if entity_spans_pair is None:
|
||||
second_ids = get_input_ids(text_pair)
|
||||
else:
|
||||
assert isinstance(entity_spans_pair, list) and (
|
||||
len(entity_spans_pair) == 0 or isinstance(entity_spans_pair[0], tuple)
|
||||
), "entity_spans_pair should be given as a list of tuples containing the start and end character indices"
|
||||
assert entities_pair is None or (
|
||||
isinstance(entities_pair, list)
|
||||
and (len(entities_pair) == 0 or isinstance(entities_pair[0], str))
|
||||
), "If you specify entities_pair, they should be given as a list of entity names"
|
||||
assert entities_pair is None or len(entities_pair) == len(
|
||||
entity_spans_pair
|
||||
), "If you specify entities_pair, entities_pair and entity_spans_pair must be the same length"
|
||||
self._check_entity_input_format(entities_pair, entity_spans_pair)
|
||||
|
||||
second_ids, second_entity_token_spans = get_input_ids_and_entity_token_spans(
|
||||
text_pair, entity_spans_pair
|
||||
|
@ -856,10 +675,11 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
second_entity_ids = [self.entity_vocab.get(entity, unk_entity_id) for entity in entities_pair]
|
||||
|
||||
elif self.task == "entity_classification":
|
||||
assert (
|
||||
isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple)
|
||||
), "Entity spans should be a list containing a single tuple containing the start and end character indices of an entity"
|
||||
|
||||
if not (isinstance(entity_spans, list) and len(entity_spans) == 1 and isinstance(entity_spans[0], tuple)):
|
||||
raise ValueError(
|
||||
"Entity spans should be a list containing a single tuple "
|
||||
"containing the start and end character indices of an entity"
|
||||
)
|
||||
first_entity_ids = [self.entity_vocab["[MASK]"]]
|
||||
first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)
|
||||
|
||||
|
@ -876,12 +696,16 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
first_entity_token_spans = [(entity_token_start, entity_token_end + 2)]
|
||||
|
||||
elif self.task == "entity_pair_classification":
|
||||
assert (
|
||||
if not (
|
||||
isinstance(entity_spans, list)
|
||||
and len(entity_spans) == 2
|
||||
and isinstance(entity_spans[0], tuple)
|
||||
and isinstance(entity_spans[1], tuple)
|
||||
), "Entity spans should be provided as a list of tuples, each tuple containing the start and end character indices of an entity"
|
||||
):
|
||||
raise ValueError(
|
||||
"Entity spans should be provided as a list of two tuples, "
|
||||
"each tuple containing the start and end character indices of an entity"
|
||||
)
|
||||
|
||||
head_span, tail_span = entity_spans
|
||||
first_entity_ids = [self.entity_vocab["[MASK]"], self.entity_vocab["[MASK2]"]]
|
||||
|
@ -907,9 +731,11 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
elif self.task == "entity_span_classification":
|
||||
mask_entity_id = self.entity_vocab["[MASK]"]
|
||||
|
||||
assert isinstance(entity_spans, list) and isinstance(
|
||||
entity_spans[0], tuple
|
||||
), "Entity spans should be provided as a list of tuples, each tuple containing the start and end character indices of an entity"
|
||||
if not (isinstance(entity_spans, list) and len(entity_spans) > 0 and isinstance(entity_spans[0], tuple)):
|
||||
raise ValueError(
|
||||
"Entity spans should be provided as a list of tuples, "
|
||||
"each tuple containing the start and end character indices of an entity"
|
||||
)
|
||||
|
||||
first_ids, first_entity_token_spans = get_input_ids_and_entity_token_spans(text, entity_spans)
|
||||
first_entity_ids = [mask_entity_id] * len(entity_spans)
|
||||
|
@ -1218,7 +1044,6 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
|
||||
|
||||
# Padding
|
||||
# To do: add padding of entities
|
||||
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
|
||||
encoded_inputs = self.pad(
|
||||
encoded_inputs,
|
||||
|
@ -1369,9 +1194,8 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
|
||||
|
||||
batch_size = len(required_input)
|
||||
assert all(
|
||||
len(v) == batch_size for v in encoded_inputs.values()
|
||||
), "Some items in the output dictionary have a different batch size than others."
|
||||
if any(len(v) != batch_size for v in encoded_inputs.values()):
|
||||
raise ValueError("Some items in the output dictionary have a different batch size than others.")
|
||||
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = max(len(inputs) for inputs in required_input)
|
||||
|
@ -1487,7 +1311,9 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
|
||||
encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
|
||||
if entities_provided:
|
||||
encoded_inputs["entity_ids"] = encoded_inputs["entity_ids"] + [0] * entity_difference
|
||||
encoded_inputs["entity_ids"] = (
|
||||
encoded_inputs["entity_ids"] + [self.entity_vocab["[PAD]"]] * entity_difference
|
||||
)
|
||||
encoded_inputs["entity_position_ids"] = (
|
||||
encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference
|
||||
)
|
||||
|
@ -1516,7 +1342,9 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
||||
encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
|
||||
if entities_provided:
|
||||
encoded_inputs["entity_ids"] = [0] * entity_difference + encoded_inputs["entity_ids"]
|
||||
encoded_inputs["entity_ids"] = [self.entity_vocab["[PAD]"]] * entity_difference + encoded_inputs[
|
||||
"entity_ids"
|
||||
]
|
||||
encoded_inputs["entity_position_ids"] = [
|
||||
[-1] * self.max_mention_length
|
||||
] * entity_difference + encoded_inputs["entity_position_ids"]
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _LazyModule, is_sentencepiece_available
|
||||
|
||||
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
if is_sentencepiece_available():
|
||||
_import_structure["tokenization_mluke"] = ["MLukeTokenizer"]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if is_sentencepiece_available():
|
||||
from .tokenization_mluke import MLukeTokenizer
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
|
|
@ -0,0 +1,228 @@
|
|||
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert mLUKE checkpoint."""

import argparse
import json
import os
from collections import OrderedDict

import torch

from transformers import LukeConfig, LukeForMaskedLM, MLukeTokenizer, XLMRobertaTokenizer
from transformers.tokenization_utils_base import AddedToken


@torch.no_grad()
def convert_luke_checkpoint(checkpoint_path, metadata_path, entity_vocab_path, pytorch_dump_folder_path, model_size):
    # Load configuration defined in the metadata file
    with open(metadata_path) as metadata_file:
        metadata = json.load(metadata_file)
    config = LukeConfig(use_entity_aware_attention=True, **metadata["model_config"])

    # Load in the weights from the checkpoint_path
    state_dict = torch.load(checkpoint_path, map_location="cpu")["module"]

    # Load the entity vocab file
    entity_vocab = load_original_entity_vocab(entity_vocab_path)
    # add an entry for [MASK2]
    entity_vocab["[MASK2]"] = max(entity_vocab.values()) + 1
    config.entity_vocab_size += 1

    tokenizer = XLMRobertaTokenizer.from_pretrained(metadata["model_config"]["bert_model_name"])

    # Add special tokens to the token vocabulary for downstream tasks
    entity_token_1 = AddedToken("<ent>", lstrip=False, rstrip=False)
    entity_token_2 = AddedToken("<ent2>", lstrip=False, rstrip=False)
    tokenizer.add_special_tokens(dict(additional_special_tokens=[entity_token_1, entity_token_2]))
    config.vocab_size += 2

    print(f"Saving tokenizer to {pytorch_dump_folder_path}")
    tokenizer.save_pretrained(pytorch_dump_folder_path)
    with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"), "r") as f:
        tokenizer_config = json.load(f)
    tokenizer_config["tokenizer_class"] = "MLukeTokenizer"
    with open(os.path.join(pytorch_dump_folder_path, "tokenizer_config.json"), "w") as f:
        json.dump(tokenizer_config, f)

    with open(os.path.join(pytorch_dump_folder_path, MLukeTokenizer.vocab_files_names["entity_vocab_file"]), "w") as f:
        json.dump(entity_vocab, f)

    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)

    # Initialize the embeddings of the special tokens
    ent_init_index = tokenizer.convert_tokens_to_ids(["@"])[0]
    ent2_init_index = tokenizer.convert_tokens_to_ids(["#"])[0]

    word_emb = state_dict["embeddings.word_embeddings.weight"]
    ent_emb = word_emb[ent_init_index].unsqueeze(0)
    ent2_emb = word_emb[ent2_init_index].unsqueeze(0)
    state_dict["embeddings.word_embeddings.weight"] = torch.cat([word_emb, ent_emb, ent2_emb])
    # add special tokens for 'entity_predictions.bias'
    for bias_name in ["lm_head.decoder.bias", "lm_head.bias"]:
        decoder_bias = state_dict[bias_name]
        ent_decoder_bias = decoder_bias[ent_init_index].unsqueeze(0)
        ent2_decoder_bias = decoder_bias[ent2_init_index].unsqueeze(0)
        state_dict[bias_name] = torch.cat([decoder_bias, ent_decoder_bias, ent2_decoder_bias])

    # Initialize the query layers of the entity-aware self-attention mechanism
    for layer_index in range(config.num_hidden_layers):
        for matrix_name in ["query.weight", "query.bias"]:
            prefix = f"encoder.layer.{layer_index}.attention.self."
            state_dict[prefix + "w2e_" + matrix_name] = state_dict[prefix + matrix_name]
            state_dict[prefix + "e2w_" + matrix_name] = state_dict[prefix + matrix_name]
            state_dict[prefix + "e2e_" + matrix_name] = state_dict[prefix + matrix_name]

    # Initialize the embedding of the [MASK2] entity using that of the [MASK] entity for downstream tasks
    entity_emb = state_dict["entity_embeddings.entity_embeddings.weight"]
    entity_mask_emb = entity_emb[entity_vocab["[MASK]"]].unsqueeze(0)
    state_dict["entity_embeddings.entity_embeddings.weight"] = torch.cat([entity_emb, entity_mask_emb])
    # add [MASK2] for 'entity_predictions.bias'
    entity_prediction_bias = state_dict["entity_predictions.bias"]
    entity_mask_bias = entity_prediction_bias[entity_vocab["[MASK]"]].unsqueeze(0)
    state_dict["entity_predictions.bias"] = torch.cat([entity_prediction_bias, entity_mask_bias])

    model = LukeForMaskedLM(config=config).eval()

    state_dict.pop("entity_predictions.decoder.weight")
    state_dict.pop("lm_head.decoder.weight")
    state_dict.pop("lm_head.decoder.bias")
    state_dict_for_hugging_face = OrderedDict()
    for key, value in state_dict.items():
        if not (key.startswith("lm_head") or key.startswith("entity_predictions")):
            state_dict_for_hugging_face[f"luke.{key}"] = state_dict[key]
        else:
            state_dict_for_hugging_face[key] = state_dict[key]

    missing_keys, unexpected_keys = model.load_state_dict(state_dict_for_hugging_face, strict=False)

    if set(unexpected_keys) != {"luke.embeddings.position_ids"}:
        raise ValueError(f"Unexpected unexpected_keys: {unexpected_keys}")
    if set(missing_keys) != {
        "lm_head.decoder.weight",
        "lm_head.decoder.bias",
        "entity_predictions.decoder.weight",
    }:
        raise ValueError(f"Unexpected missing_keys: {missing_keys}")

    model.tie_weights()
    assert (model.luke.embeddings.word_embeddings.weight == model.lm_head.decoder.weight).all()
    assert (model.luke.entity_embeddings.entity_embeddings.weight == model.entity_predictions.decoder.weight).all()

    # Check outputs
    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path, task="entity_classification")

    text = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
    span = (0, 9)
    encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")

    outputs = model(**encoding)

    # Verify word hidden states
    if model_size == "large":
        raise NotImplementedError
    else:  # base
        expected_shape = torch.Size((1, 33, 768))
        expected_slice = torch.tensor([[0.0892, 0.0596, -0.2819], [0.0134, 0.1199, 0.0573], [-0.0169, 0.0927, 0.0644]])

    if not (outputs.last_hidden_state.shape == expected_shape):
        raise ValueError(
            f"Outputs.last_hidden_state.shape is {outputs.last_hidden_state.shape}, Expected shape is {expected_shape}"
        )
    if not torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
        raise ValueError

    # Verify entity hidden states
    if model_size == "large":
        raise NotImplementedError
    else:  # base
        expected_shape = torch.Size((1, 1, 768))
        expected_slice = torch.tensor([[-0.1482, 0.0609, 0.0322]])

    if not (outputs.entity_last_hidden_state.shape == expected_shape):
        raise ValueError(
            f"Outputs.entity_last_hidden_state.shape is {outputs.entity_last_hidden_state.shape}, Expected shape is {expected_shape}"
        )
    if not torch.allclose(outputs.entity_last_hidden_state[0, :3, :3], expected_slice, atol=1e-4):
        raise ValueError

    # Verify masked word/entity prediction
    tokenizer = MLukeTokenizer.from_pretrained(pytorch_dump_folder_path)
    text = "Tokyo is the capital of <mask>."
    span = (24, 30)
    encoding = tokenizer(text, entity_spans=[span], return_tensors="pt")

    outputs = model(**encoding)

    input_ids = encoding["input_ids"][0].tolist()
    mask_position_id = input_ids.index(tokenizer.convert_tokens_to_ids("<mask>"))
    predicted_id = outputs.logits[0][mask_position_id].argmax(dim=-1)
    assert "Japan" == tokenizer.decode(predicted_id)

    predicted_entity_id = outputs.entity_logits[0][0].argmax().item()
    multilingual_predicted_entities = [
        entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id == predicted_entity_id
    ]
    assert [e for e in multilingual_predicted_entities if e.startswith("en:")][0] == "en:Japan"

    # Finally, save our PyTorch model and tokenizer
    print("Saving PyTorch model to {}".format(pytorch_dump_folder_path))
    model.save_pretrained(pytorch_dump_folder_path)


def load_original_entity_vocab(entity_vocab_path):
    SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]"]

    data = [json.loads(line) for line in open(entity_vocab_path)]

    new_mapping = {}
    for entry in data:
        entity_id = entry["id"]
        for entity_name, language in entry["entities"]:
            if entity_name in SPECIAL_TOKENS:
                new_mapping[entity_name] = entity_id
                break
            new_entity_name = f"{language}:{entity_name}"
            new_mapping[new_entity_name] = entity_id
    return new_mapping


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--checkpoint_path", type=str, help="Path to a pytorch_model.bin file.")
    parser.add_argument(
        "--metadata_path", default=None, type=str, help="Path to a metadata.json file, defining the configuration."
    )
    parser.add_argument(
        "--entity_vocab_path",
        default=None,
        type=str,
        help="Path to an entity_vocab.tsv file, containing the entity vocabulary.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to where to dump the output PyTorch model."
    )
    parser.add_argument(
        "--model_size", default="base", type=str, choices=["base", "large"], help="Size of the model to be converted."
    )
    args = parser.parse_args()
    convert_luke_checkpoint(
        args.checkpoint_path,
        args.metadata_path,
        args.entity_vocab_path,
        args.pytorch_dump_folder_path,
        args.model_size,
    )
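The script is meant to be driven by the argparse block above; a minimal sketch of calling the converter programmatically instead, where every path is a placeholder for wherever the original Studio Ousia checkpoint, metadata, and entity vocabulary files live (they are not part of this PR):

# Assumes convert_luke_checkpoint from the script above is in scope; all paths are illustrative.
convert_luke_checkpoint(
    checkpoint_path="mluke_base/pytorch_model.bin",
    metadata_path="mluke_base/metadata.json",
    entity_vocab_path="mluke_base/entity_vocab.tsv",
    pytorch_dump_folder_path="mluke_base_converted",
    model_size="base",
)

The dump folder then contains the model weights, config, and the MLukeTokenizer files written above, so it can be passed straight to the respective from_pretrained methods.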
File diff suppressed because it is too large
@@ -3020,6 +3020,18 @@ class LukeForEntitySpanClassification:
        requires_backends(self, ["torch"])


class LukeForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    def forward(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class LukeModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
@@ -110,6 +110,15 @@ class MBartTokenizer:
        requires_backends(cls, ["sentencepiece"])


class MLukeTokenizer:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["sentencepiece"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["sentencepiece"])


class MT5Tokenizer:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["sentencepiece"])
@@ -29,6 +29,7 @@ if is_torch_available():
        LukeForEntityClassification,
        LukeForEntityPairClassification,
        LukeForEntitySpanClassification,
        LukeForMaskedLM,
        LukeModel,
        LukeTokenizer,
    )
@@ -138,12 +139,17 @@ class LukeModelTester:
        )

        sequence_labels = None
        labels = None
        entity_labels = None
        entity_classification_labels = None
        entity_pair_classification_labels = None
        entity_span_classification_labels = None

        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            entity_labels = ids_tensor([self.batch_size, self.entity_length], self.entity_vocab_size)

            entity_classification_labels = ids_tensor([self.batch_size], self.num_entity_classification_labels)
            entity_pair_classification_labels = ids_tensor(
                [self.batch_size], self.num_entity_pair_classification_labels
@@ -164,6 +170,8 @@ class LukeModelTester:
            entity_token_type_ids,
            entity_position_ids,
            sequence_labels,
            labels,
            entity_labels,
            entity_classification_labels,
            entity_pair_classification_labels,
            entity_span_classification_labels,
@@ -199,6 +207,8 @@ class LukeModelTester:
        entity_token_type_ids,
        entity_position_ids,
        sequence_labels,
        labels,
        entity_labels,
        entity_classification_labels,
        entity_pair_classification_labels,
        entity_span_classification_labels,
@@ -226,6 +236,44 @@ class LukeModelTester:
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_masked_lm(
        self,
        config,
        input_ids,
        attention_mask,
        token_type_ids,
        entity_ids,
        entity_attention_mask,
        entity_token_type_ids,
        entity_position_ids,
        sequence_labels,
        labels,
        entity_labels,
        entity_classification_labels,
        entity_pair_classification_labels,
        entity_span_classification_labels,
    ):
        config.num_labels = self.num_entity_classification_labels
        model = LukeForMaskedLM(config)
        model.to(torch_device)
        model.eval()

        result = model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            entity_ids=entity_ids,
            entity_attention_mask=entity_attention_mask,
            entity_token_type_ids=entity_token_type_ids,
            entity_position_ids=entity_position_ids,
            labels=labels,
            entity_labels=entity_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertEqual(
            result.entity_logits.shape, (self.batch_size, self.entity_length, self.entity_vocab_size)
        )

    def create_and_check_for_entity_classification(
        self,
        config,
@@ -237,6 +285,8 @@ class LukeModelTester:
        entity_token_type_ids,
        entity_position_ids,
        sequence_labels,
        labels,
        entity_labels,
        entity_classification_labels,
        entity_pair_classification_labels,
        entity_span_classification_labels,

@@ -269,6 +319,8 @@ class LukeModelTester:
        entity_token_type_ids,
        entity_position_ids,
        sequence_labels,
        labels,
        entity_labels,
        entity_classification_labels,
        entity_pair_classification_labels,
        entity_span_classification_labels,

@@ -301,6 +353,8 @@ class LukeModelTester:
        entity_token_type_ids,
        entity_position_ids,
        sequence_labels,
        labels,
        entity_labels,
        entity_classification_labels,
        entity_pair_classification_labels,
        entity_span_classification_labels,

@@ -341,6 +395,8 @@ class LukeModelTester:
            entity_token_type_ids,
            entity_position_ids,
            sequence_labels,
            labels,
            entity_labels,
            entity_classification_labels,
            entity_pair_classification_labels,
            entity_span_classification_labels,
@@ -363,6 +419,7 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            LukeModel,
            LukeForMaskedLM,
            LukeForEntityClassification,
            LukeForEntityPairClassification,
            LukeForEntitySpanClassification,
@@ -396,6 +453,18 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
                dtype=torch.long,
                device=torch_device,
            )
        elif model_class == LukeForMaskedLM:
            inputs_dict["labels"] = torch.zeros(
                (self.model_tester.batch_size, self.model_tester.seq_length),
                dtype=torch.long,
                device=torch_device,
            )
            inputs_dict["entity_labels"] = torch.zeros(
                (self.model_tester.batch_size, self.model_tester.entity_length),
                dtype=torch.long,
                device=torch_device,
            )

        return inputs_dict

    def setUp(self):
@@ -415,6 +484,10 @@ class LukeModelTest(ModelTesterMixin, unittest.TestCase):
            model = LukeModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_entity_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_entity_classification(*config_and_inputs)
@@ -23,7 +23,7 @@ from transformers.testing_utils import require_torch, slow
from .test_tokenization_common import TokenizerTesterMixin


-class Luke(TokenizerTesterMixin, unittest.TestCase):
+class LukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = LukeTokenizer
    test_rust_tokenizer = False
    from_pretrained_kwargs = {"cls_token": "<s>"}

@@ -79,8 +79,8 @@ class Luke(TokenizerTesterMixin, unittest.TestCase):
        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

-        assert encoded_sentence == encoded_text_from_decode
-        assert encoded_pair == encoded_pair_from_decode
+        self.assertEqual(encoded_sentence, encoded_text_from_decode)
+        self.assertEqual(encoded_pair, encoded_pair_from_decode)

    def get_clean_sequence(self, tokenizer, max_length=20) -> Tuple[str, list]:
        txt = "Beyonce lives in Los Angeles"

@@ -159,6 +159,81 @@ class Luke(TokenizerTesterMixin, unittest.TestCase):
            tokens_p_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
        )

    def test_padding_entity_inputs(self):
        tokenizer = self.get_tokenizer()

        sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
        span = (15, 34)
        pad_id = tokenizer.entity_vocab["[PAD]"]
        mask_id = tokenizer.entity_vocab["[MASK]"]

        encoding = tokenizer([sentence, sentence], entity_spans=[[span], [span, span]], padding=True)
        self.assertEqual(encoding["entity_ids"], [[mask_id, pad_id], [mask_id, mask_id]])

        # test with a sentence with no entity
        encoding = tokenizer([sentence, sentence], entity_spans=[[], [span, span]], padding=True)
        self.assertEqual(encoding["entity_ids"], [[pad_id, pad_id], [mask_id, mask_id]])

    def test_if_tokenize_single_text_raise_error_with_invalid_inputs(self):
        tokenizer = self.get_tokenizer()

        sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
        spans = [(15, 34)]
        entities = ["East Asian language"]

        with self.assertRaises(ValueError):
            tokenizer(sentence, entities=tuple(entities), entity_spans=spans)

        with self.assertRaises(ValueError):
            tokenizer(sentence, entities=entities, entity_spans=tuple(spans))

        with self.assertRaises(ValueError):
            tokenizer(sentence, entities=[0], entity_spans=spans)

        with self.assertRaises(ValueError):
            tokenizer(sentence, entities=entities, entity_spans=[0])

        with self.assertRaises(ValueError):
            tokenizer(sentence, entities=entities, entity_spans=spans + [(0, 9)])

    def test_if_tokenize_entity_classification_raise_error_with_invalid_inputs(self):
        tokenizer = self.get_tokenizer(task="entity_classification")

        sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
        span = (15, 34)

        with self.assertRaises(ValueError):
            tokenizer(sentence, entity_spans=[])

        with self.assertRaises(ValueError):
            tokenizer(sentence, entity_spans=[span, span])

        with self.assertRaises(ValueError):
            tokenizer(sentence, entity_spans=[0])

    def test_if_tokenize_entity_pair_classification_raise_error_with_invalid_inputs(self):
        tokenizer = self.get_tokenizer(task="entity_pair_classification")

        sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
        # head and tail information

        with self.assertRaises(ValueError):
            tokenizer(sentence, entity_spans=[])

        with self.assertRaises(ValueError):
            tokenizer(sentence, entity_spans=[0, 0])

    def test_if_tokenize_entity_span_classification_raise_error_with_invalid_inputs(self):
        tokenizer = self.get_tokenizer(task="entity_span_classification")

        sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."

        with self.assertRaises(ValueError):
            tokenizer(sentence, entity_spans=[])

        with self.assertRaises(ValueError):
            tokenizer(sentence, entity_spans=[0, 0, 0])


@require_torch
class LukeTokenizerIntegrationTests(unittest.TestCase):
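The new test_padding_entity_inputs above pins down how entity inputs are padded; a small sketch of the same behavior outside the test harness, assuming "studio-ousia/luke-base" is the right checkpoint name for LukeTokenizer:

from transformers import LukeTokenizer

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")

sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
span = (15, 34)

# The shorter entity list in the first example is padded with the entity vocabulary's [PAD] id.
batch = tokenizer([sentence, sentence], entity_spans=[[span], [span, span]], padding=True)
assert batch["entity_ids"][0][-1] == tokenizer.entity_vocab["[PAD]"]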
@@ -0,0 +1,666 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest
from typing import Tuple

from transformers.models.mluke.tokenization_mluke import MLukeTokenizer
from transformers.testing_utils import require_torch, slow

from .test_tokenization_common import TokenizerTesterMixin


class MLukeTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = MLukeTokenizer
    test_rust_tokenizer = False
    from_pretrained_kwargs = {"cls_token": "<s>"}

    def setUp(self):
        super().setUp()

        self.special_tokens_map = {"entity_token_1": "<ent>", "entity_token_2": "<ent2>"}

    def get_tokenizer(self, task=None, **kwargs):
        kwargs.update(self.special_tokens_map)
        kwargs.update({"task": task})
        return self.tokenizer_class.from_pretrained("studio-ousia/mluke-base", **kwargs)

    def get_input_output_texts(self, tokenizer):
        input_text = "lower newer"
        output_text = "lower newer"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/mluke-base")
        text = "lower newer"
        spm_tokens = ["▁lower", "▁new", "er"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, spm_tokens)

        input_tokens = tokens + [tokenizer.unk_token]
        input_spm_tokens = [92319, 3525, 56, 3]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_spm_tokens)

    def mluke_dict_integration_testing(self):
        tokenizer = self.get_tokenizer()

        self.assertListEqual(tokenizer.encode("Hello world!", add_special_tokens=False), [35378, 8999, 38])
        self.assertListEqual(
            tokenizer.encode("Hello world! cécé herlolip 418", add_special_tokens=False),
            [35378, 8999, 38, 33273, 11676, 604, 365, 21392, 201, 1819],
        )

    @slow
    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("studio-ousia/mluke-base")

        text = tokenizer.encode("sequence builders", add_special_tokens=False)
        text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)

        encoded_text_from_decode = tokenizer.encode(
            "sequence builders", add_special_tokens=True, add_prefix_space=False
        )
        encoded_pair_from_decode = tokenizer.encode(
            "sequence builders", "multi-sequence build", add_special_tokens=True, add_prefix_space=False
        )

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        self.assertEqual(encoded_sentence, encoded_text_from_decode)
        self.assertEqual(encoded_pair, encoded_pair_from_decode)

    def get_clean_sequence(self, tokenizer, max_length=20) -> Tuple[str, list]:
        txt = "Beyonce lives in Los Angeles"
        ids = tokenizer.encode(txt, add_special_tokens=False)
        return txt, ids

    def test_pretokenized_inputs(self):
        pass

    def test_embeded_special_tokens(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest("{} ({})".format(tokenizer.__class__.__name__, pretrained_name)):
                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
                tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
                sentence = "A, <mask> AllenNLP sentence."
                tokens_r = tokenizer_r.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)
                tokens_p = tokenizer_p.encode_plus(sentence, add_special_tokens=True, return_token_type_ids=True)

                # token_type_ids should put 0 everywhere
                self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"]))

                # attention_mask should put 1 everywhere, so sum over length should be 1
                self.assertEqual(
                    sum(tokens_r["attention_mask"]) / len(tokens_r["attention_mask"]),
                    sum(tokens_p["attention_mask"]) / len(tokens_p["attention_mask"]),
                )

                tokens_p_str = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"])

                # Rust correctly handles the space before the mask while python doesnt
                self.assertSequenceEqual(tokens_p["input_ids"], [0, 250, 6, 50264, 3823, 487, 21992, 3645, 4, 2])

                self.assertSequenceEqual(
                    tokens_p_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
                )

def test_padding_entity_inputs(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
|
||||
span = (15, 34)
|
||||
pad_id = tokenizer.entity_vocab["[PAD]"]
|
||||
mask_id = tokenizer.entity_vocab["[MASK]"]
|
||||
|
||||
encoding = tokenizer([sentence, sentence], entity_spans=[[span], [span, span]], padding=True)
|
||||
self.assertEqual(encoding["entity_ids"], [[mask_id, pad_id], [mask_id, mask_id]])
|
||||
|
||||
# test with a sentence with no entity
|
||||
encoding = tokenizer([sentence, sentence], entity_spans=[[], [span, span]], padding=True)
|
||||
self.assertEqual(encoding["entity_ids"], [[pad_id, pad_id], [mask_id, mask_id]])
|
||||
|
||||
def test_if_tokenize_single_text_raise_error_with_invalid_inputs(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and Afghanistan."
|
||||
entities = ["en:ISO 639-3"]
|
||||
spans = [(0, 9)]
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entities=tuple(entities), entity_spans=spans)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entities=entities, entity_spans=tuple(spans))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entities=[0], entity_spans=spans)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entities=entities, entity_spans=[0])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entities=entities, entity_spans=spans + [(0, 9)])
|
||||
|
||||
def test_if_tokenize_entity_classification_raise_error_with_invalid_inputs(self):
|
||||
tokenizer = self.get_tokenizer(task="entity_classification")
|
||||
|
||||
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
|
||||
span = (15, 34)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entity_spans=[])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entity_spans=[span, span])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entity_spans=[0])
|
||||
|
||||
def test_if_tokenize_entity_pair_classification_raise_error_with_invalid_inputs(self):
|
||||
tokenizer = self.get_tokenizer(task="entity_pair_classification")
|
||||
|
||||
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
|
||||
# head and tail information
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entity_spans=[])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entity_spans=[0, 0])
|
||||
|
||||
def test_if_tokenize_entity_span_classification_raise_error_with_invalid_inputs(self):
|
||||
tokenizer = self.get_tokenizer(task="entity_span_classification")
|
||||
|
||||
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entity_spans=[])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tokenizer(sentence, entity_spans=[0, 0, 0])
|
||||
|
||||
|
||||
@require_torch
|
||||
class MLukeTokenizerIntegrationTests(unittest.TestCase):
|
||||
tokenizer_class = MLukeTokenizer
|
||||
from_pretrained_kwargs = {"cls_token": "<s>"}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base", return_token_type_ids=True)
|
||||
cls.entity_classification_tokenizer = MLukeTokenizer.from_pretrained(
|
||||
"studio-ousia/mluke-base", return_token_type_ids=True, task="entity_classification"
|
||||
)
|
||||
cls.entity_pair_tokenizer = MLukeTokenizer.from_pretrained(
|
||||
"studio-ousia/mluke-base", return_token_type_ids=True, task="entity_pair_classification"
|
||||
)
|
||||
|
||||
cls.entity_span_tokenizer = MLukeTokenizer.from_pretrained(
|
||||
"studio-ousia/mluke-base", return_token_type_ids=True, task="entity_span_classification"
|
||||
)
|
||||
|
||||
def test_single_text_no_padding_or_truncation(self):
|
||||
tokenizer = self.tokenizer
|
||||
sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
|
||||
entities = ["en:ISO 639-3", "DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
|
||||
spans = [(0, 9), (59, 63), (68, 75), (77, 88)]
|
||||
|
||||
encoding = tokenizer(sentence, entities=entities, entity_spans=spans, return_token_type_ids=True)
|
||||
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
|
||||
"<s> ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).</s>",
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
|
||||
)
|
||||
self.assertEqual(tokenizer.decode(encoding["input_ids"][17], spaces_between_special_tokens=False), "Iran")
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][19:25], spaces_between_special_tokens=False), "アフガニスタン"
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][26], spaces_between_special_tokens=False), "Afghanistan"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
encoding["entity_ids"],
|
||||
[
|
||||
tokenizer.entity_vocab["en:ISO 639-3"],
|
||||
tokenizer.entity_vocab["[UNK]"],
|
||||
tokenizer.entity_vocab["ja:アフガニスタン"],
|
||||
tokenizer.entity_vocab["en:Afghanistan"],
|
||||
],
|
||||
)
|
||||
self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1, 1])
|
||||
self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0, 0])
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
encoding["entity_position_ids"],
|
||||
[
|
||||
[1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[19, 20, 21, 22, 23, 24, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_single_text_only_entity_spans_no_padding_or_truncation(self):
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
|
||||
entities = ["en:ISO 639-3", "DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
|
||||
spans = [(0, 9), (59, 63), (68, 75), (77, 88)]
|
||||
|
||||
encoding = tokenizer(sentence, entities=entities, entity_spans=spans, return_token_type_ids=True)
|
||||
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
|
||||
"<s> ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).</s>",
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
|
||||
)
|
||||
self.assertEqual(tokenizer.decode(encoding["input_ids"][17], spaces_between_special_tokens=False), "Iran")
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][20:25], spaces_between_special_tokens=False), "アフガニスタン"
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][26], spaces_between_special_tokens=False), "Afghanistan"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
encoding["entity_ids"],
|
||||
[
|
||||
tokenizer.entity_vocab["en:ISO 639-3"],
|
||||
tokenizer.entity_vocab["[UNK]"],
|
||||
tokenizer.entity_vocab["ja:アフガニスタン"],
|
||||
tokenizer.entity_vocab["en:Afghanistan"],
|
||||
],
|
||||
)
|
||||
self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1, 1])
|
||||
self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0, 0])
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
encoding["entity_position_ids"],
|
||||
[
|
||||
[1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[19, 20, 21, 22, 23, 24, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_single_text_padding_pytorch_tensors(self):
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
sentence = "ISO 639-3 uses the code fas for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
|
||||
entities = ["en:ISO 639-3", "DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
|
||||
spans = [(0, 9), (59, 63), (68, 75), (77, 88)]
|
||||
|
||||
encoding = tokenizer(
|
||||
sentence,
|
||||
entities=entities,
|
||||
entity_spans=spans,
|
||||
return_token_type_ids=True,
|
||||
padding="max_length",
|
||||
max_length=30,
|
||||
max_entity_length=16,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# test words
|
||||
self.assertEqual(encoding["input_ids"].shape, (1, 30))
|
||||
self.assertEqual(encoding["attention_mask"].shape, (1, 30))
|
||||
self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
|
||||
|
||||
# test entities
|
||||
self.assertEqual(encoding["entity_ids"].shape, (1, 16))
|
||||
self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16))
|
||||
self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16))
|
||||
self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length))
|
||||
|
||||
def test_text_pair_no_padding_or_truncation(self):
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
sentence = "ISO 639-3 uses the code fas"
|
||||
sentence_pair = "for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
|
||||
entities = ["en:ISO 639-3"]
|
||||
entities_pair = ["DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
|
||||
spans = [(0, 9)]
|
||||
spans_pair = [(31, 35), (40, 47), (49, 60)]
|
||||
|
||||
encoding = tokenizer(
|
||||
sentence,
|
||||
sentence_pair,
|
||||
entities=entities,
|
||||
entities_pair=entities_pair,
|
||||
entity_spans=spans,
|
||||
entity_spans_pair=spans_pair,
|
||||
return_token_type_ids=True,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
|
||||
"<s> ISO 639-3 uses the code fas</s></s> for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).</s>",
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
|
||||
)
|
||||
self.assertEqual(tokenizer.decode(encoding["input_ids"][19], spaces_between_special_tokens=False), "Iran")
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][21:27], spaces_between_special_tokens=False), "アフガニスタン"
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][28], spaces_between_special_tokens=False), "Afghanistan"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
encoding["entity_ids"],
|
||||
[
|
||||
tokenizer.entity_vocab["en:ISO 639-3"],
|
||||
tokenizer.entity_vocab["[UNK]"],
|
||||
tokenizer.entity_vocab["ja:アフガニスタン"],
|
||||
tokenizer.entity_vocab["en:Afghanistan"],
|
||||
],
|
||||
)
|
||||
self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1, 1])
|
||||
self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0, 0])
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
encoding["entity_position_ids"],
|
||||
[
|
||||
[1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[19, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[21, 22, 23, 24, 25, 26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[28, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_text_pair_only_entity_spans_no_padding_or_truncation(self):
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
sentence = "ISO 639-3 uses the code fas"
|
||||
sentence_pair = "for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
|
||||
entities = ["en:ISO 639-3"]
|
||||
entities_pair = ["DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
|
||||
spans = [(0, 9)]
|
||||
spans_pair = [(31, 35), (40, 47), (49, 60)]
|
||||
|
||||
encoding = tokenizer(
|
||||
sentence,
|
||||
sentence_pair,
|
||||
entities=entities,
|
||||
entities_pair=entities_pair,
|
||||
entity_spans=spans,
|
||||
entity_spans_pair=spans_pair,
|
||||
return_token_type_ids=True,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
|
||||
"<s> ISO 639-3 uses the code fas</s></s> for the dialects spoken across Iran and アフガニスタン ( Afghanistan ).</s>",
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][1:5], spaces_between_special_tokens=False), "ISO 639-3"
|
||||
)
|
||||
self.assertEqual(tokenizer.decode(encoding["input_ids"][19], spaces_between_special_tokens=False), "Iran")
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][21:27], spaces_between_special_tokens=False), "アフガニスタン"
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][28], spaces_between_special_tokens=False), "Afghanistan"
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
encoding["entity_ids"],
|
||||
[
|
||||
tokenizer.entity_vocab["en:ISO 639-3"],
|
||||
tokenizer.entity_vocab["[UNK]"],
|
||||
tokenizer.entity_vocab["ja:アフガニスタン"],
|
||||
tokenizer.entity_vocab["en:Afghanistan"],
|
||||
],
|
||||
)
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
encoding["entity_position_ids"],
|
||||
[
|
||||
[1, 2, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[19, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[21, 22, 23, 24, 25, 26, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[28, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_text_pair_padding_pytorch_tensors(self):
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
sentence = "ISO 639-3 uses the code fas"
|
||||
sentence_pair = "for the dialects spoken across Iran and アフガニスタン (Afghanistan)."
|
||||
entities = ["en:ISO 639-3"]
|
||||
entities_pair = ["DUMMY_ENTITY", "ja:アフガニスタン", "en:Afghanistan"]
|
||||
spans = [(0, 9)]
|
||||
spans_pair = [(31, 35), (40, 47), (49, 60)]
|
||||
|
||||
encoding = tokenizer(
|
||||
sentence,
|
||||
sentence_pair,
|
||||
entities=entities,
|
||||
entities_pair=entities_pair,
|
||||
entity_spans=spans,
|
||||
entity_spans_pair=spans_pair,
|
||||
return_token_type_ids=True,
|
||||
padding="max_length",
|
||||
max_length=40,
|
||||
max_entity_length=16,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# test words
|
||||
self.assertEqual(encoding["input_ids"].shape, (1, 40))
|
||||
self.assertEqual(encoding["attention_mask"].shape, (1, 40))
|
||||
self.assertEqual(encoding["token_type_ids"].shape, (1, 40))
|
||||
|
||||
# test entities
|
||||
self.assertEqual(encoding["entity_ids"].shape, (1, 16))
|
||||
self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16))
|
||||
self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16))
|
||||
self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length))
|
||||
|
||||
def test_entity_classification_no_padding_or_truncation(self):
|
||||
tokenizer = self.entity_classification_tokenizer
|
||||
|
||||
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
|
||||
span = (15, 34)
|
||||
|
||||
encoding = tokenizer(sentence, entity_spans=[span], return_token_type_ids=True)
|
||||
|
||||
# test words
|
||||
self.assertEqual(len(encoding["input_ids"]), 23)
|
||||
self.assertEqual(len(encoding["attention_mask"]), 23)
|
||||
self.assertEqual(len(encoding["token_type_ids"]), 23)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
|
||||
"<s> Japanese is an<ent>East Asian language<ent>spoken by about 128 million people, primarily in Japan.</s>",
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][4:9], spaces_between_special_tokens=False),
|
||||
"<ent>East Asian language<ent>",
|
||||
)
|
||||
|
||||
# test entities
|
||||
mask_id = tokenizer.entity_vocab["[MASK]"]
|
||||
self.assertEqual(encoding["entity_ids"], [mask_id])
|
||||
self.assertEqual(encoding["entity_attention_mask"], [1])
|
||||
self.assertEqual(encoding["entity_token_type_ids"], [0])
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
encoding["entity_position_ids"],
|
||||
[[4, 5, 6, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_entity_classification_padding_pytorch_tensors(self):
|
||||
tokenizer = self.entity_classification_tokenizer
|
||||
|
||||
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
|
||||
span = (15, 34)
|
||||
|
||||
encoding = tokenizer(
|
||||
sentence, entity_spans=[span], return_token_type_ids=True, padding="max_length", return_tensors="pt"
|
||||
)
|
||||
|
||||
# test words
|
||||
self.assertEqual(encoding["input_ids"].shape, (1, 512))
|
||||
self.assertEqual(encoding["attention_mask"].shape, (1, 512))
|
||||
self.assertEqual(encoding["token_type_ids"].shape, (1, 512))
|
||||
|
||||
# test entities
|
||||
self.assertEqual(encoding["entity_ids"].shape, (1, 1))
|
||||
self.assertEqual(encoding["entity_attention_mask"].shape, (1, 1))
|
||||
self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 1))
|
||||
self.assertEqual(
|
||||
encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length)
|
||||
)
|
||||
|
||||
def test_entity_pair_classification_no_padding_or_truncation(self):
|
||||
tokenizer = self.entity_pair_tokenizer
|
||||
|
||||
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
|
||||
# head and tail information
|
||||
spans = [(0, 8), (84, 89)]
|
||||
|
||||
encoding = tokenizer(sentence, entity_spans=spans, return_token_type_ids=True)
|
||||
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
|
||||
"<s><ent>Japanese<ent>is an East Asian language spoken by about 128 million people, primarily in<ent2>Japan<ent2>.</s>",
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][1:4], spaces_between_special_tokens=False),
|
||||
"<ent>Japanese<ent>",
|
||||
)
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"][20:23], spaces_between_special_tokens=False), "<ent2>Japan<ent2>"
|
||||
)
|
||||
|
||||
mask_id = tokenizer.entity_vocab["[MASK]"]
|
||||
mask2_id = tokenizer.entity_vocab["[MASK2]"]
|
||||
self.assertEqual(encoding["entity_ids"], [mask_id, mask2_id])
|
||||
self.assertEqual(encoding["entity_attention_mask"], [1, 1])
|
||||
self.assertEqual(encoding["entity_token_type_ids"], [0, 0])
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
encoding["entity_position_ids"],
|
||||
[
|
||||
[1, 2, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[20, 21, 22, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_entity_pair_classification_padding_pytorch_tensors(self):
|
||||
tokenizer = self.entity_pair_tokenizer
|
||||
|
||||
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
|
||||
# head and tail information
|
||||
spans = [(0, 8), (84, 89)]
|
||||
|
||||
encoding = tokenizer(
|
||||
sentence,
|
||||
entity_spans=spans,
|
||||
return_token_type_ids=True,
|
||||
padding="max_length",
|
||||
max_length=30,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# test words
|
||||
self.assertEqual(encoding["input_ids"].shape, (1, 30))
|
||||
self.assertEqual(encoding["attention_mask"].shape, (1, 30))
|
||||
self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
|
||||
|
||||
# test entities
|
||||
self.assertEqual(encoding["entity_ids"].shape, (1, 2))
|
||||
self.assertEqual(encoding["entity_attention_mask"].shape, (1, 2))
|
||||
self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 2))
|
||||
self.assertEqual(
|
||||
encoding["entity_position_ids"].shape, (1, tokenizer.max_entity_length, tokenizer.max_mention_length)
|
||||
)
|
||||
|
||||
def test_entity_span_classification_no_padding_or_truncation(self):
|
||||
tokenizer = self.entity_span_tokenizer
|
||||
|
||||
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
|
||||
spans = [(0, 8), (15, 34), (84, 89)]
|
||||
|
||||
encoding = tokenizer(sentence, entity_spans=spans, return_token_type_ids=True)
|
||||
|
||||
self.assertEqual(
|
||||
tokenizer.decode(encoding["input_ids"], spaces_between_special_tokens=False),
|
||||
"<s> Japanese is an East Asian language spoken by about 128 million people, primarily in Japan.</s>",
|
||||
)
|
||||
|
||||
mask_id = tokenizer.entity_vocab["[MASK]"]
|
||||
self.assertEqual(encoding["entity_ids"], [mask_id, mask_id, mask_id])
|
||||
self.assertEqual(encoding["entity_attention_mask"], [1, 1, 1])
|
||||
self.assertEqual(encoding["entity_token_type_ids"], [0, 0, 0])
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
encoding["entity_position_ids"],
|
||||
[
|
||||
[1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
|
||||
[18, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertEqual(encoding["entity_start_positions"], [1, 4, 18])
|
||||
self.assertEqual(encoding["entity_end_positions"], [1, 6, 18])
|
||||
|
||||
def test_entity_span_classification_padding_pytorch_tensors(self):
|
||||
tokenizer = self.entity_span_tokenizer
|
||||
|
||||
sentence = "Japanese is an East Asian language spoken by about 128 million people, primarily in Japan."
|
||||
spans = [(0, 8), (15, 34), (84, 89)]
|
||||
|
||||
encoding = tokenizer(
|
||||
sentence,
|
||||
entity_spans=spans,
|
||||
return_token_type_ids=True,
|
||||
padding="max_length",
|
||||
max_length=30,
|
||||
max_entity_length=16,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# test words
|
||||
self.assertEqual(encoding["input_ids"].shape, (1, 30))
|
||||
self.assertEqual(encoding["attention_mask"].shape, (1, 30))
|
||||
self.assertEqual(encoding["token_type_ids"].shape, (1, 30))
|
||||
|
||||
# test entities
|
||||
self.assertEqual(encoding["entity_ids"].shape, (1, 16))
|
||||
self.assertEqual(encoding["entity_attention_mask"].shape, (1, 16))
|
||||
self.assertEqual(encoding["entity_token_type_ids"].shape, (1, 16))
|
||||
self.assertEqual(encoding["entity_position_ids"].shape, (1, 16, tokenizer.max_mention_length))
|
||||
self.assertEqual(encoding["entity_start_positions"].shape, (1, 16))
|
||||
self.assertEqual(encoding["entity_end_positions"].shape, (1, 16))
|
|
@@ -116,6 +116,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    "DPRReader",
    "FlaubertForQuestionAnswering",
    "GPT2DoubleHeadsModel",
    "LukeForMaskedLM",
    "LukeForEntityClassification",
    "LukeForEntityPairClassification",
    "LukeForEntitySpanClassification",
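LukeForMaskedLM is added to IGNORE_NON_AUTO_CONFIGURED, i.e. it is not expected to appear in the Auto* model mappings, so it is loaded through its concrete class. A short masked-word prediction sketch mirroring the check in the conversion script; whether the hosted studio-ousia/mluke-base checkpoint ships the masked-LM weights is an assumption here:

import torch

from transformers import LukeForMaskedLM, MLukeTokenizer

tokenizer = MLukeTokenizer.from_pretrained("studio-ousia/mluke-base")
model = LukeForMaskedLM.from_pretrained("studio-ousia/mluke-base").eval()

text = "Tokyo is the capital of <mask>."
encoding = tokenizer(text, entity_spans=[(24, 30)], return_tensors="pt")
with torch.no_grad():
    outputs = model(**encoding)

# Decode the word predicted at the <mask> position (the conversion script expects "Japan").
mask_index = encoding["input_ids"][0].tolist().index(tokenizer.convert_tokens_to_ids("<mask>"))
print(tokenizer.decode(outputs.logits[0, mask_index].argmax(dim=-1)))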