[TokenClassification] Label realignment for subword aggregation (#11680)

* [TokenClassification] Label realignment for subword aggregation

Tentative replacement for https://github.com/huggingface/transformers/pull/11622/files

- Added `AggregationStrategy`
- `ignore_subwords` and `grouped_entities` arguments are now fused
  into `aggregation_strategy`. This makes more sense, since
  `ignore_subwords=True` with `grouped_entities=False` had no
  meaning.
- Added two new aggregation strategies, MAX and AVERAGE (see the usage
  sketch after this list).
- AVERAGE requires a bit more information than the others; for now this
  case is handled slightly specifically, and we should keep that in mind
  for future changes.
- Testing has been modified to reflect the new argument, and to check the
  correct deprecation behavior and the new aggregation_strategy.
- Put the testing arguments and expected results for aggregation_strategy
  close together, so that readers can understand what is supposed to
  happen.
- `aggregate` is now tested only on a small model, as it is not meaningful
  to test it globally for all models.
- Previous tests are unchanged in desired output.
- Added a new test case that showcases better the difference between the
  FIRST, MAX and AVERAGE strategies.
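
For reference, a minimal usage sketch of the new argument (the model name is
taken from the tests below; any token-classification checkpoint should behave
similarly, this is not part of the diff itself):

    from transformers import pipeline

    # "none", "simple", "first", "max" and "average" are the valid values;
    # "first"/"max"/"average" need a fast tokenizer so that subwords can be
    # realigned onto whole words.
    token_classifier = pipeline("ner", model="dslim/bert-base-NER", aggregation_strategy="max")
    print(token_classifier("Hello Sarah Jessica Parker who lives in New York"))

    # Deprecated arguments still work and map onto the new one (with a warning):
    #   grouped_entities=True, ignore_subwords=True  -> aggregation_strategy="first"
    #   grouped_entities=True, ignore_subwords=False -> aggregation_strategy="simple"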

* Wrong framework.

* Addressing three issues.

1- Tags might not follow the B-, I- convention, so any tag should work now
(a bare tag is assumed to be B-TAG; see the sketch below).
2- Fixed an issue with AVERAGE that led to a substantial code change.
3- The testing suite was not checking for the "index" key with the "none"
strategy. This is now fixed.
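
For clarity, the tag parsing now boils down to this (a standalone sketch
mirroring the `get_tag` helper added in the diff below):

    def get_tag(entity_name: str):
        # "B-PER" -> ("B", "PER"); "I-PER" -> ("I", "PER")
        if entity_name.startswith("B-"):
            return "B", entity_name[2:]
        if entity_name.startswith("I-"):
            return "I", entity_name[2:]
        # A tag outside the B-, I- convention ("PER", "LOC", ...) starts a
        # new entity, i.e. it is treated as "B-TAG".
        return "B", entity_name

    print(get_tag("I-PER"))  # ('I', 'PER')
    print(get_tag("PER"))    # ('B', 'PER')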

The issue is that "O" could never be chosen by the AVERAGE strategy, because
those tokens were filtered out beforehand, so their scores were not counted
in the average. Filtering on ignore_labels now happens at the very end of
the pipeline, which fixes that issue.
It is a bit hard to make sure this stays like that, because we do not have
an end-to-end test for that behavior.
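
A small numeric sketch of the failure mode (hypothetical scores over three
labels ["O", "B-TAG", "I-TAG"] for one word split into two subtokens):

    import numpy as np

    scores = np.array([
        [0.6, 0.3, 0.1],  # subtoken 1: argmax is "O" -> used to be filtered out early
        [0.3, 0.2, 0.5],  # subtoken 2: argmax is "I-TAG"
    ])
    print(scores.mean(axis=0).argmax())  # 0 -> averaged over both subtokens, "O" wins
    # With early filtering, only subtoken 2 reached the average, so the word
    # came out as "I-TAG"; filtering on ignore_labels at the very end fixes this.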

* Formatting.

* Adding formatting to code + cleaner handling of B-, I- tags.

Co-authored-by: Francesco Rubbo <rubbo.francesco@gmail.com>
Co-authored-by: elk-cloner <rezakakhki.rk@gmail.com>

* Typo.

Co-authored-by: Francesco Rubbo <rubbo.francesco@gmail.com>
Co-authored-by: elk-cloner <rezakakhki.rk@gmail.com>
Nicolas Patry, 2021-05-18 09:53:20 +02:00, committed by GitHub
parent c73e35323d
commit b88e0e016d
4 changed files with 579 additions and 336 deletions

src/transformers/pipelines/__init__.py

@@ -48,7 +48,12 @@ from .table_question_answering import TableQuestionAnsweringArgumentHandler, Tab
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
from .text_classification import TextClassificationPipeline
from .text_generation import TextGenerationPipeline
from .token_classification import NerPipeline, TokenClassificationArgumentHandler, TokenClassificationPipeline
from .token_classification import (
AggregationStrategy,
NerPipeline,
TokenClassificationArgumentHandler,
TokenClassificationPipeline,
)
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline

src/transformers/pipelines/token_classification.py

@@ -1,8 +1,9 @@
from typing import TYPE_CHECKING, List, Optional, Union
import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..file_utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.bert.tokenization_bert import BasicTokenizer
from ..tokenization_utils import PreTrainedTokenizer
@@ -48,13 +49,43 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
return inputs, offset_mapping
class AggregationStrategy(ExplicitEnum):
"""All the valid aggregation strategies for TokenClassificationPipeline"""
NONE = "none"
SIMPLE = "simple"
FIRST = "first"
AVERAGE = "average"
MAX = "max"
@add_end_docstrings(
PIPELINE_INIT_ARGS,
r"""
ignore_labels (:obj:`List[str]`, defaults to :obj:`["O"]`):
A list of labels to ignore.
grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to group the tokens corresponding to the same entity together in the predictions or not.
DEPRECATED, use :obj:`aggregation_strategy` instead. Whether or not to group the tokens corresponding to
the same entity together in the predictions.
aggregation_strategy (:obj:`str`, `optional`, defaults to :obj:`"none"`): The strategy to fuse (or not) tokens based on the model prediction.
- "none" : Will simply not do any aggregation and simply return raw results from the model
- "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C,
I-TAG), (D, B-TAG2) (E, B-TAG2) will end up being [{"word": ABC, "entity": "TAG"}, {"word": "D",
"entity": "TAG2"}, {"word": "E", "entity": "TAG2"}] Notice that two consecutive B tags will end up as
different entities. On word based languages, we might end up splitting words undesirably : Imagine
Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
"NAME"}]. Look for FIRST, MAX, AVERAGE for ways to mitigate that and disambiguate words (on languages
that support that meaning, which is basically tokens separated by a space). These mitigations will
only work on real words, "New york" might still be tagged with two different entities.
- "first" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words,
cannot end up with different tags. Words will simply use the tag of the first token of the word when
there is ambiguity.
- "average" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words,
cannot end up with different tags. scores will be averaged first across tokens, and then the maximum
label is applied.
- "max" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words,
cannot end up with different tags. Word entity will simply be the token with the maximum score.
""",
)
class TokenClassificationPipeline(Pipeline):
@@ -84,8 +115,9 @@ class TokenClassificationPipeline(Pipeline):
binary_output: bool = False,
ignore_labels=["O"],
task: str = "",
grouped_entities: bool = False,
ignore_subwords: bool = False,
grouped_entities: Optional[bool] = None,
ignore_subwords: Optional[bool] = None,
aggregation_strategy: Optional[AggregationStrategy] = None,
):
super().__init__(
model=model,
@@ -106,15 +138,40 @@ class TokenClassificationPipeline(Pipeline):
self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
self._args_parser = args_parser
self.ignore_labels = ignore_labels
self.grouped_entities = grouped_entities
self.ignore_subwords = ignore_subwords
if self.ignore_subwords and not self.tokenizer.is_fast:
if aggregation_strategy is None:
aggregation_strategy = AggregationStrategy.NONE
if grouped_entities is not None or ignore_subwords is not None:
if grouped_entities and ignore_subwords:
aggregation_strategy = AggregationStrategy.FIRST
elif grouped_entities and not ignore_subwords:
aggregation_strategy = AggregationStrategy.SIMPLE
else:
aggregation_strategy = AggregationStrategy.NONE
if grouped_entities is not None:
warnings.warn(
f'`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if ignore_subwords is not None:
warnings.warn(
f'`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
)
if isinstance(aggregation_strategy, str):
aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
if (
aggregation_strategy in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
and not self.tokenizer.is_fast
):
raise ValueError(
"Slow tokenizers cannot ignore subwords. Please set the `ignore_subwords` option"
"to `False` or use a fast tokenizer."
"Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option"
'to `"simple"` or use a fast tokenizer.'
)
self.aggregation_strategy = aggregation_strategy
def __call__(self, inputs: Union[str, List[str]], **kwargs):
"""
Classify each token of the text(s) given as inputs.
@@ -125,14 +182,14 @@ class TokenClassificationPipeline(Pipeline):
Return:
A list or a list of list of :obj:`dict`: Each result comes as a list of dictionaries (one for each token in
the corresponding input, or each entity if this pipeline was instantiated with
:obj:`grouped_entities=True`) with the following keys:
the corresponding input, or each entity if this pipeline was instantiated with an aggregation_strategy)
with the following keys:
- **word** (:obj:`str`) -- The token/word classified.
- **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
- **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when
`grouped_entities` is set to True.
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
`aggregation_strategy` is not :obj:`"none"`).
- **index** (:obj:`int`, only present when ``aggregation_strategy="none"``) -- The index of the
corresponding token in the sentence.
- **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the sentence.
Only exists if the offsets are available within the tokenizer
@@ -176,58 +233,142 @@ class TokenClassificationPipeline(Pipeline):
entities = self.model(**tokens)[0][0].cpu().numpy()
input_ids = tokens["input_ids"].cpu().numpy()[0]
score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
labels_idx = score.argmax(axis=-1)
entities = []
# Filter to labels not in `self.ignore_labels`
# Filter special_tokens
filtered_labels_idx = [
(idx, label_idx)
for idx, label_idx in enumerate(labels_idx)
if (self.model.config.id2label[label_idx] not in self.ignore_labels) and not special_tokens_mask[idx]
scores = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
pre_entities = self.gather_pre_entities(sentence, input_ids, scores, offset_mapping, special_tokens_mask)
grouped_entities = self.aggregate(pre_entities, self.aggregation_strategy)
# Filter anything that is in self.ignore_labels
entities = [
entity
for entity in grouped_entities
if entity.get("entity", None) not in self.ignore_labels
and entity.get("entity_group", None) not in self.ignore_labels
]
for idx, label_idx in filtered_labels_idx:
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
word_ref = sentence[start_ind:end_ind]
word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
is_subword = len(word_ref) != len(word)
if int(input_ids[idx]) == self.tokenizer.unk_token_id:
word = word_ref
is_subword = False
else:
word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
start_ind = None
end_ind = None
entity = {
"word": word,
"score": score[idx][label_idx].item(),
"entity": self.model.config.id2label[label_idx],
"index": idx,
"start": start_ind,
"end": end_ind,
}
if self.grouped_entities and self.ignore_subwords:
entity["is_subword"] = is_subword
entities += [entity]
if self.grouped_entities:
answers += [self.group_entities(entities)]
# Append ungrouped entities
else:
answers += [entities]
answers.append(entities)
if len(answers) == 1:
return answers[0]
return answers
def gather_pre_entities(
self,
sentence: str,
input_ids: np.ndarray,
scores: np.ndarray,
offset_mapping: Optional[List[Tuple[int, int]]],
special_tokens_mask: np.ndarray,
) -> List[dict]:
"""Fuse various numpy arrays into dicts with all the information needed for aggregation"""
pre_entities = []
for idx, token_scores in enumerate(scores):
# Filter special_tokens, they should only occur
# at the sentence boundaries since we're not encoding pairs of
# sentences so we don't have to keep track of those.
if special_tokens_mask[idx]:
continue
word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
word_ref = sentence[start_ind:end_ind]
is_subword = len(word_ref) != len(word)
if int(input_ids[idx]) == self.tokenizer.unk_token_id:
word = word_ref
is_subword = False
else:
start_ind = None
end_ind = None
is_subword = False
pre_entity = {
"word": word,
"scores": token_scores,
"start": start_ind,
"end": end_ind,
"index": idx,
"is_subword": is_subword,
}
pre_entities.append(pre_entity)
return pre_entities
def aggregate(self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}:
entities = []
for pre_entity in pre_entities:
entity_idx = pre_entity["scores"].argmax()
score = pre_entity["scores"][entity_idx]
entity = {
"entity": self.model.config.id2label[entity_idx],
"score": score,
"index": pre_entity["index"],
"word": pre_entity["word"],
"start": pre_entity["start"],
"end": pre_entity["end"],
}
entities.append(entity)
else:
entities = self.aggregate_words(pre_entities, aggregation_strategy)
if aggregation_strategy == AggregationStrategy.NONE:
return entities
return self.group_entities(entities)
def aggregate_word(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> dict:
word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
if aggregation_strategy == AggregationStrategy.FIRST:
scores = entities[0]["scores"]
idx = scores.argmax()
score = scores[idx]
entity = self.model.config.id2label[idx]
elif aggregation_strategy == AggregationStrategy.MAX:
max_entity = max(entities, key=lambda entity: entity["scores"].max())
scores = max_entity["scores"]
idx = scores.argmax()
score = scores[idx]
entity = self.model.config.id2label[idx]
elif aggregation_strategy == AggregationStrategy.AVERAGE:
scores = np.stack([entity["scores"] for entity in entities])
average_scores = np.nanmean(scores, axis=0)
entity_idx = average_scores.argmax()
entity = self.model.config.id2label[entity_idx]
score = average_scores[entity_idx]
else:
raise ValueError("Invalid aggregation_strategy")
new_entity = {
"entity": entity,
"score": score,
"word": word,
"start": entities[0]["start"],
"end": entities[-1]["end"],
}
return new_entity
def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
"""
Override tokens from a given word that disagree to force agreement on word boundaries.
Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with the FIRST strategy as
microsoft| company| B-ENT I-ENT
"""
assert aggregation_strategy not in {
AggregationStrategy.NONE,
AggregationStrategy.SIMPLE,
}, "NONE and SIMPLE strategies are invalid"
word_entities = []
word_group = None
for entity in entities:
if word_group is None:
word_group = [entity]
elif entity["is_subword"]:
word_group.append(entity)
else:
word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
word_group = [entity]
# Last item
word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
return word_entities
def group_sub_entities(self, entities: List[dict]) -> dict:
"""
Group together the adjacent tokens with the same entity predicted.
@@ -249,6 +390,19 @@ class TokenClassificationPipeline(Pipeline):
}
return entity_group
def get_tag(self, entity_name: str) -> Tuple[str, str]:
if entity_name.startswith("B-"):
bi = "B"
tag = entity_name[2:]
elif entity_name.startswith("I-"):
bi = "I"
tag = entity_name[2:]
else:
# It's not in B-, I- format
bi = "B"
tag = entity_name
return bi, tag
def group_entities(self, entities: List[dict]) -> List[dict]:
"""
Find and group together the adjacent tokens with the same entity predicted.
@@ -260,45 +414,29 @@
entity_groups = []
entity_group_disagg = []
if entities:
last_idx = entities[-1]["index"]
for entity in entities:
is_last_idx = entity["index"] == last_idx
is_subword = self.ignore_subwords and entity["is_subword"]
if not entity_group_disagg:
entity_group_disagg += [entity]
if is_last_idx:
entity_groups += [self.group_sub_entities(entity_group_disagg)]
entity_group_disagg.append(entity)
continue
# If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
# The split is meant to account for the "B" and "I" suffixes
# If the current entity is similar and adjacent to the previous entity,
# append it to the disaggregated entity group
# The split is meant to account for the "B" and "I" prefixes
# Shouldn't merge if both entities are B-type
if (
(
entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
and entity["entity"].split("-")[0] != "B"
)
and entity["index"] == entity_group_disagg[-1]["index"] + 1
) or is_subword:
# Modify subword type to be previous_type
if is_subword:
entity["entity"] = entity_group_disagg[-1]["entity"].split("-")[-1]
entity["score"] = np.nan # set ignored scores to nan and use np.nanmean
bi, tag = self.get_tag(entity["entity"])
last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])
entity_group_disagg += [entity]
# Group the entities at the last entity
if is_last_idx:
entity_groups += [self.group_sub_entities(entity_group_disagg)]
# If the current entity is different from the previous entity, aggregate the disaggregated entity group
if tag == last_tag and bi != "B":
# Modify subword type to be previous_type
entity_group_disagg.append(entity)
else:
entity_groups += [self.group_sub_entities(entity_group_disagg)]
# If the current entity is different from the previous entity
# aggregate the disaggregated entity group
entity_groups.append(self.group_sub_entities(entity_group_disagg))
entity_group_disagg = [entity]
# If it's the last entity, add it to the entity groups
if is_last_idx:
entity_groups += [self.group_sub_entities(entity_group_disagg)]
if entity_group_disagg:
# it's the last entity, add it to the entity groups
entity_groups.append(self.group_sub_entities(entity_group_disagg))
return entity_groups

src/transformers/testing_utils.py

@@ -1207,19 +1207,25 @@ def nested_simplify(obj, decimals=3):
Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
within tests.
"""
import numpy as np
from transformers.tokenization_utils import BatchEncoding
if isinstance(obj, list):
return [nested_simplify(item, decimals) for item in obj]
elif isinstance(obj, np.ndarray):
return nested_simplify(obj.tolist())
elif isinstance(obj, (dict, BatchEncoding)):
return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
elif isinstance(obj, (str, int)):
elif isinstance(obj, (str, int, np.int64)):
return obj
elif is_torch_available() and isinstance(obj, torch.Tensor):
return nested_simplify(obj.tolist())
return nested_simplify(obj.tolist(), decimals)
elif is_tf_available() and tf.is_tensor(obj):
return nested_simplify(obj.numpy().tolist())
elif isinstance(obj, float):
return round(obj, decimals)
elif isinstance(obj, np.float32):
return nested_simplify(obj.item(), decimals)
else:
raise Exception(f"Not supported: {type(obj)}")

tests/test_pipelines_token_classification.py

@@ -14,16 +14,15 @@
import unittest
from transformers import AutoTokenizer, is_torch_available, pipeline
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
from transformers.testing_utils import require_tf, require_torch, slow
import numpy as np
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
from transformers.pipelines import AggregationStrategy, Pipeline, TokenClassificationArgumentHandler
from transformers.testing_utils import nested_simplify, require_tf, require_torch, slow
from .test_pipelines_common import CustomInputPipelineCommonMixin
if is_torch_available():
import numpy as np
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
@@ -35,210 +34,10 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
large_models = [] # Models tested with the @slow decorator
def _test_pipeline(self, nlp: Pipeline):
output_keys = {"entity", "word", "score", "start", "end"}
if nlp.grouped_entities:
output_keys = {"entity", "word", "score", "start", "end", "index"}
if nlp.aggregation_strategy != AggregationStrategy.NONE:
output_keys = {"entity_group", "word", "score", "start", "end"}
ungrouped_ner_inputs = [
[
{
"entity": "B-PER",
"index": 1,
"score": 0.9994944930076599,
"is_subword": False,
"word": "Cons",
"start": 0,
"end": 4,
},
{
"entity": "B-PER",
"index": 2,
"score": 0.8025449514389038,
"is_subword": True,
"word": "##uelo",
"start": 4,
"end": 8,
},
{
"entity": "I-PER",
"index": 3,
"score": 0.9993102550506592,
"is_subword": False,
"word": "Ara",
"start": 9,
"end": 11,
},
{
"entity": "I-PER",
"index": 4,
"score": 0.9993743896484375,
"is_subword": True,
"word": "##új",
"start": 11,
"end": 13,
},
{
"entity": "I-PER",
"index": 5,
"score": 0.9992871880531311,
"is_subword": True,
"word": "##o",
"start": 13,
"end": 14,
},
{
"entity": "I-PER",
"index": 6,
"score": 0.9993029236793518,
"is_subword": False,
"word": "No",
"start": 15,
"end": 17,
},
{
"entity": "I-PER",
"index": 7,
"score": 0.9981776475906372,
"is_subword": True,
"word": "##guera",
"start": 17,
"end": 22,
},
{
"entity": "B-PER",
"index": 15,
"score": 0.9998136162757874,
"is_subword": False,
"word": "Andrés",
"start": 23,
"end": 28,
},
{
"entity": "I-PER",
"index": 16,
"score": 0.999740719795227,
"is_subword": False,
"word": "Pas",
"start": 29,
"end": 32,
},
{
"entity": "I-PER",
"index": 17,
"score": 0.9997414350509644,
"is_subword": True,
"word": "##tran",
"start": 32,
"end": 36,
},
{
"entity": "I-PER",
"index": 18,
"score": 0.9996136426925659,
"is_subword": True,
"word": "##a",
"start": 36,
"end": 37,
},
{
"entity": "B-ORG",
"index": 28,
"score": 0.9989739060401917,
"is_subword": False,
"word": "Far",
"start": 39,
"end": 42,
},
{
"entity": "I-ORG",
"index": 29,
"score": 0.7188422083854675,
"is_subword": True,
"word": "##c",
"start": 42,
"end": 43,
},
],
[
{
"entity": "I-PER",
"index": 1,
"score": 0.9968166351318359,
"is_subword": False,
"word": "En",
"start": 0,
"end": 2,
},
{
"entity": "I-PER",
"index": 2,
"score": 0.9957635998725891,
"is_subword": True,
"word": "##zo",
"start": 2,
"end": 4,
},
{
"entity": "I-ORG",
"index": 7,
"score": 0.9986497163772583,
"is_subword": False,
"word": "UN",
"start": 11,
"end": 13,
},
],
]
expected_grouped_ner_results = [
[
{
"entity_group": "PER",
"score": 0.999369223912557,
"word": "Consuelo Araújo Noguera",
"start": 0,
"end": 22,
},
{
"entity_group": "PER",
"score": 0.9997771680355072,
"word": "Andrés Pastrana",
"start": 23,
"end": 37,
},
{"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc", "start": 39, "end": 43},
],
[
{"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
],
]
expected_grouped_ner_results_w_subword = [
[
{"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons", "start": 0, "end": 4},
{
"entity_group": "PER",
"score": 0.9663328925768534,
"word": "##uelo Araújo Noguera",
"start": 4,
"end": 22,
},
{
"entity_group": "PER",
"score": 0.9997273534536362,
"word": "Andrés Pastrana",
"start": 23,
"end": 37,
},
{"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc", "start": 39, "end": 43},
],
[
{"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
],
]
self.assertIsNotNone(nlp)
mono_result = nlp(VALID_INPUTS[0])
@@ -262,15 +61,306 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
for key in output_keys:
self.assertIn(key, result)
if nlp.grouped_entities:
if nlp.ignore_subwords:
for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
else:
for ungrouped_input, grouped_result in zip(
ungrouped_ner_inputs, expected_grouped_ner_results_w_subword
):
self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
@require_torch
@slow
def test_spanish_bert(self):
# https://github.com/huggingface/transformers/pull/4987
NER_MODEL = "mrm8488/bert-spanish-cased-finetuned-ner"
model = AutoModelForTokenClassification.from_pretrained(NER_MODEL)
tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
sentence = """Consuelo Araújo Noguera, ministra de cultura del presidente Andrés Pastrana (1998.2002) fue asesinada por las Farc luego de haber permanecido secuestrada por algunos meses."""
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer)
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output[:3]),
[
{"entity": "B-PER", "score": 0.999, "word": "Cons", "start": 0, "end": 4, "index": 1},
{"entity": "B-PER", "score": 0.803, "word": "##uelo", "start": 4, "end": 8, "index": 2},
{"entity": "I-PER", "score": 0.999, "word": "Ara", "start": 9, "end": 12, "index": 3},
],
)
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output[:3]),
[
{"entity_group": "PER", "score": 0.999, "word": "Cons", "start": 0, "end": 4},
{"entity_group": "PER", "score": 0.966, "word": "##uelo Araújo Noguera", "start": 4, "end": 23},
{"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
],
)
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="first")
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output[:3]),
[
{"entity_group": "PER", "score": 0.999, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23},
{"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
{"entity_group": "ORG", "score": 0.999, "word": "Farc", "start": 110, "end": 114},
],
)
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="max")
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output[:3]),
[
{"entity_group": "PER", "score": 0.999, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23},
{"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
{"entity_group": "ORG", "score": 0.999, "word": "Farc", "start": 110, "end": 114},
],
)
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="average")
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output[:3]),
[
{"entity_group": "PER", "score": 0.966, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23},
{"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
{"entity_group": "ORG", "score": 0.542, "word": "Farc", "start": 110, "end": 114},
],
)
@require_torch
@slow
def test_dbmdz_english(self):
# Other sentence
NER_MODEL = "dbmdz/bert-large-cased-finetuned-conll03-english"
model = AutoModelForTokenClassification.from_pretrained(NER_MODEL)
tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
sentence = """Enzo works at the the UN"""
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer)
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
{"entity": "I-PER", "score": 0.997, "word": "En", "start": 0, "end": 2, "index": 1},
{"entity": "I-PER", "score": 0.996, "word": "##zo", "start": 2, "end": 4, "index": 2},
{"entity": "I-ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24, "index": 7},
],
)
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
],
)
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="first")
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output[:3]),
[
{"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
],
)
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="max")
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output[:3]),
[
{"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
],
)
token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="average")
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
],
)
@require_torch
def test_aggregation_strategy(self):
model_name = self.small_models[0]
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
# Just to understand scores indexes in this test
self.assertEqual(
token_classifier.model.config.id2label,
{0: "O", 1: "B-MISC", 2: "I-MISC", 3: "B-PER", 4: "I-PER", 5: "B-ORG", 6: "I-ORG", 7: "B-LOC", 8: "I-LOC"},
)
example = [
{
# fmt: off
"scores": np.array([0, 0, 0, 0, 0.9968166351318359, 0, 0, 0]),
"index": 1,
"is_subword": False,
"word": "En",
"start": 0,
"end": 2,
},
{
# fmt: off
"scores": np.array([0, 0, 0, 0, 0.9957635998725891, 0, 0, 0]),
"index": 2,
"is_subword": True,
"word": "##zo",
"start": 2,
"end": 4,
},
{
# fmt: off
"scores": np.array([0, 0, 0, 0, 0, 0.9986497163772583, 0, 0, ]),
# fmt: on
"index": 7,
"word": "UN",
"is_subword": False,
"start": 11,
"end": 13,
},
]
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.NONE)),
[
{"end": 2, "entity": "I-PER", "score": 0.997, "start": 0, "word": "En", "index": 1},
{"end": 4, "entity": "I-PER", "score": 0.996, "start": 2, "word": "##zo", "index": 2},
{"end": 13, "entity": "B-ORG", "score": 0.999, "start": 11, "word": "UN", "index": 7},
],
)
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.SIMPLE)),
[
{"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
],
)
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.FIRST)),
[
{"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
],
)
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.MAX)),
[
{"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
],
)
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.AVERAGE)),
[
{"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
],
)
@require_torch
def test_aggregation_strategy_example2(self):
model_name = self.small_models[0]
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
# Just to understand scores indexes in this test
self.assertEqual(
token_classifier.model.config.id2label,
{0: "O", 1: "B-MISC", 2: "I-MISC", 3: "B-PER", 4: "I-PER", 5: "B-ORG", 6: "I-ORG", 7: "B-LOC", 8: "I-LOC"},
)
example = [
{
# Necessary for AVERAGE
"scores": np.array([0, 0.55, 0, 0.45, 0, 0, 0, 0, 0, 0]),
"is_subword": False,
"index": 1,
"word": "Ra",
"start": 0,
"end": 2,
},
{
"scores": np.array([0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0]),
"is_subword": True,
"word": "##ma",
"start": 2,
"end": 4,
"index": 2,
},
{
# 4th score will have the highest average
# 4th score is B-PER for this model
# It does not correspond to the argmax of any of the subtokens.
"scores": np.array([0, 0, 0, 0.4, 0, 0, 0.6, 0, 0, 0]),
"is_subword": True,
"word": "##zotti",
"start": 11,
"end": 13,
"index": 3,
},
]
self.assertEqual(
token_classifier.aggregate(example, AggregationStrategy.NONE),
[
{"end": 2, "entity": "B-MISC", "score": 0.55, "start": 0, "word": "Ra", "index": 1},
{"end": 4, "entity": "B-LOC", "score": 0.8, "start": 2, "word": "##ma", "index": 2},
{"end": 13, "entity": "I-ORG", "score": 0.6, "start": 11, "word": "##zotti", "index": 3},
],
)
self.assertEqual(
token_classifier.aggregate(example, AggregationStrategy.FIRST),
[{"entity_group": "MISC", "score": 0.55, "word": "Ramazotti", "start": 0, "end": 13}],
)
self.assertEqual(
token_classifier.aggregate(example, AggregationStrategy.MAX),
[{"entity_group": "LOC", "score": 0.8, "word": "Ramazotti", "start": 0, "end": 13}],
)
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.AVERAGE)),
[{"entity_group": "PER", "score": 0.35, "word": "Ramazotti", "start": 0, "end": 13}],
)
@require_torch
def test_gather_pre_entities(self):
model_name = self.small_models[0]
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
sentence = "Hello there"
tokens = tokenizer(
sentence,
return_attention_mask=False,
return_tensors="pt",
truncation=True,
return_special_tokens_mask=True,
return_offsets_mapping=True,
)
offset_mapping = tokens.pop("offset_mapping").cpu().numpy()[0]
special_tokens_mask = tokens.pop("special_tokens_mask").cpu().numpy()[0]
input_ids = tokens["input_ids"].numpy()[0]
# First element is [CLS]
scores = np.array([[1, 0, 0], [0.1, 0.3, 0.6], [0.8, 0.1, 0.1]])
pre_entities = nlp.gather_pre_entities(sentence, input_ids, scores, offset_mapping, special_tokens_mask)
self.assertEqual(
nested_simplify(pre_entities),
[
{"word": "Hello", "scores": [0.1, 0.3, 0.6], "start": 0, "end": 5, "is_subword": False, "index": 1},
{
"word": "there",
"scores": [0.8, 0.1, 0.1],
"index": 2,
"start": 6,
"end": 11,
"is_subword": False,
},
],
)
@require_tf
def test_tf_only(self):
@@ -295,8 +385,7 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
model=model_name,
tokenizer=tokenizer,
framework="tf",
grouped_entities=True,
ignore_subwords=True,
aggregation_strategy=AggregationStrategy.FIRST,
)
self._test_pipeline(nlp)
@@ -307,18 +396,23 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
model=model_name,
tokenizer=tokenizer,
framework="tf",
grouped_entities=True,
ignore_subwords=False,
aggregation_strategy=AggregationStrategy.SIMPLE,
)
self._test_pipeline(nlp)
@require_torch
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
for model_name in self.small_models:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model_name = self.small_models[0]
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
with self.assertRaises(ValueError):
pipeline(task="ner", model=model_name, tokenizer=tokenizer, ignore_subwords=True, use_fast=False)
with self.assertRaises(ValueError):
pipeline(task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.FIRST)
with self.assertRaises(ValueError):
pipeline(
task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.AVERAGE
)
with self.assertRaises(ValueError):
pipeline(task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.MAX)
@require_torch
def test_pt_defaults_slow_tokenizer(self):
@@ -333,27 +427,27 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
nlp = pipeline(task="ner", model=model_name)
self._test_pipeline(nlp)
@slow
@require_torch
def test_warnings(self):
with self.assertWarns(UserWarning):
token_classifier = pipeline(task="ner", model=self.small_models[0], grouped_entities=True)
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE)
with self.assertWarns(UserWarning):
token_classifier = pipeline(
task="ner", model=self.small_models[0], grouped_entities=True, ignore_subwords=True
)
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST)
@slow
@require_torch
def test_simple(self):
nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
nlp = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy=AggregationStrategy.SIMPLE)
sentence = "Hello Sarah Jessica Parker who Jessica lives in New York"
sentence2 = "This is a simple test"
output = nlp(sentence)
def simplify(output):
if isinstance(output, (list, tuple)):
return [simplify(item) for item in output]
elif isinstance(output, dict):
return {simplify(k): simplify(v) for k, v in output.items()}
elif isinstance(output, (str, int, np.int64)):
return output
elif isinstance(output, float):
return round(output, 3)
else:
raise Exception(f"Cannot handle {type(output)}")
output_ = simplify(output)
output_ = nested_simplify(output)
self.assertEqual(
output_,
@@ -371,7 +465,7 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
)
output = nlp([sentence, sentence2])
output_ = simplify(output)
output_ = nested_simplify(output)
self.assertEqual(
output_,
@@ -390,14 +484,14 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
for model_name in self.small_models:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
nlp = pipeline(
task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.FIRST
)
self._test_pipeline(nlp)
for model_name in self.small_models:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
nlp = pipeline(
task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=False
task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.SIMPLE
)
self._test_pipeline(nlp)