NerPipeline (TokenClassification) now outputs offsets of words (#8781)

* NerPipeline (TokenClassification) now outputs offsets of words

- The offsets are currently missing from the output, which forces the user to
pattern match the "word" against their input, which is not always feasible.
For instance, if a sentence contains the same word twice, then there
is no way to know which is which.
- This PR proposes to fix that by outputting 2 new keys in this
pipeline's outputs, "start" and "end", which correspond to the string
offsets of the word. That means that we should always have the
invariant:

```python
input[entity["start"]: entity["end"]] == entity["word"]
                                    # grouped or not (ignoring subword markers such as "##")
```

* Fixing doc style
This commit is contained in:
Nicolas Patry 2020-11-30 20:05:08 +01:00 committed by GitHub
parent 5fd3d81ec9
commit d8fc26e919
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 223 additions and 31 deletions

View File

@ -1420,9 +1420,14 @@ class TokenClassificationPipeline(Pipeline):
- **word** (:obj:`str`) -- The token/word classified.
- **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
- **entity** (:obj:`str`) -- The entity predicted for that token/word.
- **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when
`grouped_entities` is set to True).
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
corresponding token in the sentence.
- **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the sentence.
Only exists if the offsets are available within the tokenizer.
- **end** (:obj:`int`, `optional`) -- The index of the end of the corresponding entity in the sentence.
Only exists if the offsets are available within the tokenizer.
"""
inputs, offset_mappings = self._args_parser(inputs, **kwargs)
@ -1486,11 +1491,16 @@ class TokenClassificationPipeline(Pipeline):
else:
word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
start_ind = None
end_ind = None
entity = {
"word": word,
"score": score[idx][label_idx].item(),
"entity": self.model.config.id2label[label_idx],
"index": idx,
"start": start_ind,
"end": end_ind,
}
if self.grouped_entities and self.ignore_subwords:
@ -1524,6 +1534,8 @@ class TokenClassificationPipeline(Pipeline):
"entity_group": entity,
"score": np.mean(scores),
"word": self.tokenizer.convert_tokens_to_string(tokens),
"start": entities[0]["start"],
"end": entities[-1]["end"],
}
return entity_group

View File

@ -2,7 +2,7 @@ import unittest
from transformers import AutoTokenizer, pipeline
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
from transformers.testing_utils import require_tf, require_torch
from transformers.testing_utils import require_tf, require_torch, slow
from .test_pipelines_common import CustomInputPipelineCommonMixin
@ -18,55 +18,207 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
large_models = [] # Models tested with the @slow decorator
def _test_pipeline(self, nlp: Pipeline):
output_keys = {"entity", "word", "score"}
output_keys = {"entity", "word", "score", "start", "end"}
if nlp.grouped_entities:
output_keys = {"entity_group", "word", "score"}
output_keys = {"entity_group", "word", "score", "start", "end"}
ungrouped_ner_inputs = [
[
{"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "is_subword": False, "word": "Cons"},
{"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "is_subword": True, "word": "##uelo"},
{"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "is_subword": False, "word": "Ara"},
{"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "is_subword": True, "word": "##új"},
{"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "is_subword": True, "word": "##o"},
{"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "is_subword": False, "word": "No"},
{"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "is_subword": True, "word": "##guera"},
{"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "is_subword": False, "word": "Andrés"},
{"entity": "I-PER", "index": 16, "score": 0.999740719795227, "is_subword": False, "word": "Pas"},
{"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "is_subword": True, "word": "##tran"},
{"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "is_subword": True, "word": "##a"},
{"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "is_subword": False, "word": "Far"},
{"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "is_subword": True, "word": "##c"},
{
"entity": "B-PER",
"index": 1,
"score": 0.9994944930076599,
"is_subword": False,
"word": "Cons",
"start": 0,
"end": 4,
},
{
"entity": "B-PER",
"index": 2,
"score": 0.8025449514389038,
"is_subword": True,
"word": "##uelo",
"start": 4,
"end": 8,
},
{
"entity": "I-PER",
"index": 3,
"score": 0.9993102550506592,
"is_subword": False,
"word": "Ara",
"start": 9,
"end": 11,
},
{
"entity": "I-PER",
"index": 4,
"score": 0.9993743896484375,
"is_subword": True,
"word": "##új",
"start": 11,
"end": 13,
},
{
"entity": "I-PER",
"index": 5,
"score": 0.9992871880531311,
"is_subword": True,
"word": "##o",
"start": 13,
"end": 14,
},
{
"entity": "I-PER",
"index": 6,
"score": 0.9993029236793518,
"is_subword": False,
"word": "No",
"start": 15,
"end": 17,
},
{
"entity": "I-PER",
"index": 7,
"score": 0.9981776475906372,
"is_subword": True,
"word": "##guera",
"start": 17,
"end": 22,
},
{
"entity": "B-PER",
"index": 15,
"score": 0.9998136162757874,
"is_subword": False,
"word": "Andrés",
"start": 23,
"end": 28,
},
{
"entity": "I-PER",
"index": 16,
"score": 0.999740719795227,
"is_subword": False,
"word": "Pas",
"start": 29,
"end": 32,
},
{
"entity": "I-PER",
"index": 17,
"score": 0.9997414350509644,
"is_subword": True,
"word": "##tran",
"start": 32,
"end": 36,
},
{
"entity": "I-PER",
"index": 18,
"score": 0.9996136426925659,
"is_subword": True,
"word": "##a",
"start": 36,
"end": 37,
},
{
"entity": "B-ORG",
"index": 28,
"score": 0.9989739060401917,
"is_subword": False,
"word": "Far",
"start": 39,
"end": 42,
},
{
"entity": "I-ORG",
"index": 29,
"score": 0.7188422083854675,
"is_subword": True,
"word": "##c",
"start": 42,
"end": 43,
},
],
[
{"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "is_subword": False, "word": "En"},
{"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "is_subword": True, "word": "##zo"},
{"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "is_subword": False, "word": "UN"},
{
"entity": "I-PER",
"index": 1,
"score": 0.9968166351318359,
"is_subword": False,
"word": "En",
"start": 0,
"end": 2,
},
{
"entity": "I-PER",
"index": 2,
"score": 0.9957635998725891,
"is_subword": True,
"word": "##zo",
"start": 2,
"end": 4,
},
{
"entity": "I-ORG",
"index": 7,
"score": 0.9986497163772583,
"is_subword": False,
"word": "UN",
"start": 11,
"end": 13,
},
],
]
expected_grouped_ner_results = [
[
{"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera"},
{"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana"},
{"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc"},
{
"entity_group": "PER",
"score": 0.999369223912557,
"word": "Consuelo Araújo Noguera",
"start": 0,
"end": 22,
},
{
"entity_group": "PER",
"score": 0.9997771680355072,
"word": "Andrés Pastrana",
"start": 23,
"end": 37,
},
{"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc", "start": 39, "end": 43},
],
[
{"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo"},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
{"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
],
]
expected_grouped_ner_results_w_subword = [
[
{"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons"},
{"entity_group": "PER", "score": 0.9663328925768534, "word": "##uelo Araújo Noguera"},
{"entity_group": "PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
{"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc"},
{"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons", "start": 0, "end": 4},
{
"entity_group": "PER",
"score": 0.9663328925768534,
"word": "##uelo Araújo Noguera",
"start": 4,
"end": 22,
},
{
"entity_group": "PER",
"score": 0.9997273534536362,
"word": "Andrés Pastrana",
"start": 23,
"end": 37,
},
{"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc", "start": 39, "end": 43},
],
[
{"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo"},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
{"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
],
]
@ -164,6 +316,34 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
nlp = pipeline(task="ner", model=model_name)
self._test_pipeline(nlp)
@slow
@require_torch
def test_simple(self):
    """Check grouped NER output, including ``start``/``end`` offsets, on a real model.

    The input sentence contains "Jessica" twice, so the two occurrences can
    only be told apart by the offsets attached to each entity group.
    """
    sentence = "Hello Sarah Jessica Parker who Jessica lives in New York"
    nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
    predictions = nlp(sentence)

    # Scores are rounded in place to 3 decimals so they can be compared
    # against the literal values below.
    for prediction in predictions:
        prediction["score"] = round(prediction["score"], 3)

    expected = [
        {
            "entity_group": "PER",
            "score": 0.996,
            "word": "Sarah Jessica Parker",
            "start": 6,
            "end": 26,
        },
        {"entity_group": "PER", "score": 0.977, "word": "Jessica", "start": 31, "end": 38},
        {"entity_group": "LOC", "score": 0.999, "word": "New York", "start": 48, "end": 56},
    ]
    self.assertEqual(predictions, expected)
@require_torch
def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
for model_name in self.small_models: