NerPipeline (TokenClassification) now outputs offsets of words (#8781)
* NerPipeline (TokenClassification) now outputs offsets of words - It happens that the offsets are missing, it forces the user to pattern match the "word" from his input, which is not always feasible. For instance if a sentence contains the same word twice, then there is no way to know which is which. - This PR proposes to fix that by outputting 2 new keys for this pipelines outputs, "start" and "end", which correspond to the string offsets of the word. That means that we should always have the invariant: ```python input[entity["start"]: entity["end"]] == entity["entity_group"] # or entity["entity"] if not grouped ``` * Fixing doc style
This commit is contained in:
parent
5fd3d81ec9
commit
d8fc26e919
|
@ -1420,9 +1420,14 @@ class TokenClassificationPipeline(Pipeline):
|
|||
|
||||
- **word** (:obj:`str`) -- The token/word classified.
|
||||
- **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
|
||||
- **entity** (:obj:`str`) -- The entity predicted for that token/word.
|
||||
- **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when
|
||||
`grouped_entities` is set to True.
|
||||
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
|
||||
corresponding token in the sentence.
|
||||
- **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the sentence.
|
||||
Only exists if the offsets are available within the tokenizer
|
||||
- **end** (:obj:`int`, `optional`) -- The index of the end of the corresponding entity in the sentence.
|
||||
Only exists if the offsets are available within the tokenizer
|
||||
"""
|
||||
|
||||
inputs, offset_mappings = self._args_parser(inputs, **kwargs)
|
||||
|
@ -1486,11 +1491,16 @@ class TokenClassificationPipeline(Pipeline):
|
|||
else:
|
||||
word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
|
||||
|
||||
start_ind = None
|
||||
end_ind = None
|
||||
|
||||
entity = {
|
||||
"word": word,
|
||||
"score": score[idx][label_idx].item(),
|
||||
"entity": self.model.config.id2label[label_idx],
|
||||
"index": idx,
|
||||
"start": start_ind,
|
||||
"end": end_ind,
|
||||
}
|
||||
|
||||
if self.grouped_entities and self.ignore_subwords:
|
||||
|
@ -1524,6 +1534,8 @@ class TokenClassificationPipeline(Pipeline):
|
|||
"entity_group": entity,
|
||||
"score": np.mean(scores),
|
||||
"word": self.tokenizer.convert_tokens_to_string(tokens),
|
||||
"start": entities[0]["start"],
|
||||
"end": entities[-1]["end"],
|
||||
}
|
||||
return entity_group
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ import unittest
|
|||
|
||||
from transformers import AutoTokenizer, pipeline
|
||||
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
|
||||
from transformers.testing_utils import require_tf, require_torch
|
||||
from transformers.testing_utils import require_tf, require_torch, slow
|
||||
|
||||
from .test_pipelines_common import CustomInputPipelineCommonMixin
|
||||
|
||||
|
@ -18,55 +18,207 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||
large_models = [] # Models tested with the @slow decorator
|
||||
|
||||
def _test_pipeline(self, nlp: Pipeline):
|
||||
output_keys = {"entity", "word", "score"}
|
||||
output_keys = {"entity", "word", "score", "start", "end"}
|
||||
if nlp.grouped_entities:
|
||||
output_keys = {"entity_group", "word", "score"}
|
||||
output_keys = {"entity_group", "word", "score", "start", "end"}
|
||||
|
||||
ungrouped_ner_inputs = [
|
||||
[
|
||||
{"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "is_subword": False, "word": "Cons"},
|
||||
{"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "is_subword": True, "word": "##uelo"},
|
||||
{"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "is_subword": False, "word": "Ara"},
|
||||
{"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "is_subword": True, "word": "##új"},
|
||||
{"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "is_subword": True, "word": "##o"},
|
||||
{"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "is_subword": False, "word": "No"},
|
||||
{"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "is_subword": True, "word": "##guera"},
|
||||
{"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "is_subword": False, "word": "Andrés"},
|
||||
{"entity": "I-PER", "index": 16, "score": 0.999740719795227, "is_subword": False, "word": "Pas"},
|
||||
{"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "is_subword": True, "word": "##tran"},
|
||||
{"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "is_subword": True, "word": "##a"},
|
||||
{"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "is_subword": False, "word": "Far"},
|
||||
{"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "is_subword": True, "word": "##c"},
|
||||
{
|
||||
"entity": "B-PER",
|
||||
"index": 1,
|
||||
"score": 0.9994944930076599,
|
||||
"is_subword": False,
|
||||
"word": "Cons",
|
||||
"start": 0,
|
||||
"end": 4,
|
||||
},
|
||||
{
|
||||
"entity": "B-PER",
|
||||
"index": 2,
|
||||
"score": 0.8025449514389038,
|
||||
"is_subword": True,
|
||||
"word": "##uelo",
|
||||
"start": 4,
|
||||
"end": 8,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 3,
|
||||
"score": 0.9993102550506592,
|
||||
"is_subword": False,
|
||||
"word": "Ara",
|
||||
"start": 9,
|
||||
"end": 11,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 4,
|
||||
"score": 0.9993743896484375,
|
||||
"is_subword": True,
|
||||
"word": "##új",
|
||||
"start": 11,
|
||||
"end": 13,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 5,
|
||||
"score": 0.9992871880531311,
|
||||
"is_subword": True,
|
||||
"word": "##o",
|
||||
"start": 13,
|
||||
"end": 14,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 6,
|
||||
"score": 0.9993029236793518,
|
||||
"is_subword": False,
|
||||
"word": "No",
|
||||
"start": 15,
|
||||
"end": 17,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 7,
|
||||
"score": 0.9981776475906372,
|
||||
"is_subword": True,
|
||||
"word": "##guera",
|
||||
"start": 17,
|
||||
"end": 22,
|
||||
},
|
||||
{
|
||||
"entity": "B-PER",
|
||||
"index": 15,
|
||||
"score": 0.9998136162757874,
|
||||
"is_subword": False,
|
||||
"word": "Andrés",
|
||||
"start": 23,
|
||||
"end": 28,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 16,
|
||||
"score": 0.999740719795227,
|
||||
"is_subword": False,
|
||||
"word": "Pas",
|
||||
"start": 29,
|
||||
"end": 32,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 17,
|
||||
"score": 0.9997414350509644,
|
||||
"is_subword": True,
|
||||
"word": "##tran",
|
||||
"start": 32,
|
||||
"end": 36,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 18,
|
||||
"score": 0.9996136426925659,
|
||||
"is_subword": True,
|
||||
"word": "##a",
|
||||
"start": 36,
|
||||
"end": 37,
|
||||
},
|
||||
{
|
||||
"entity": "B-ORG",
|
||||
"index": 28,
|
||||
"score": 0.9989739060401917,
|
||||
"is_subword": False,
|
||||
"word": "Far",
|
||||
"start": 39,
|
||||
"end": 42,
|
||||
},
|
||||
{
|
||||
"entity": "I-ORG",
|
||||
"index": 29,
|
||||
"score": 0.7188422083854675,
|
||||
"is_subword": True,
|
||||
"word": "##c",
|
||||
"start": 42,
|
||||
"end": 43,
|
||||
},
|
||||
],
|
||||
[
|
||||
{"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "is_subword": False, "word": "En"},
|
||||
{"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "is_subword": True, "word": "##zo"},
|
||||
{"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "is_subword": False, "word": "UN"},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 1,
|
||||
"score": 0.9968166351318359,
|
||||
"is_subword": False,
|
||||
"word": "En",
|
||||
"start": 0,
|
||||
"end": 2,
|
||||
},
|
||||
{
|
||||
"entity": "I-PER",
|
||||
"index": 2,
|
||||
"score": 0.9957635998725891,
|
||||
"is_subword": True,
|
||||
"word": "##zo",
|
||||
"start": 2,
|
||||
"end": 4,
|
||||
},
|
||||
{
|
||||
"entity": "I-ORG",
|
||||
"index": 7,
|
||||
"score": 0.9986497163772583,
|
||||
"is_subword": False,
|
||||
"word": "UN",
|
||||
"start": 11,
|
||||
"end": 13,
|
||||
},
|
||||
],
|
||||
]
|
||||
|
||||
expected_grouped_ner_results = [
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera"},
|
||||
{"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana"},
|
||||
{"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc"},
|
||||
{
|
||||
"entity_group": "PER",
|
||||
"score": 0.999369223912557,
|
||||
"word": "Consuelo Araújo Noguera",
|
||||
"start": 0,
|
||||
"end": 22,
|
||||
},
|
||||
{
|
||||
"entity_group": "PER",
|
||||
"score": 0.9997771680355072,
|
||||
"word": "Andrés Pastrana",
|
||||
"start": 23,
|
||||
"end": 37,
|
||||
},
|
||||
{"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc", "start": 39, "end": 43},
|
||||
],
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo"},
|
||||
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
|
||||
{"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
|
||||
],
|
||||
]
|
||||
|
||||
expected_grouped_ner_results_w_subword = [
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons"},
|
||||
{"entity_group": "PER", "score": 0.9663328925768534, "word": "##uelo Araújo Noguera"},
|
||||
{"entity_group": "PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
|
||||
{"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc"},
|
||||
{"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons", "start": 0, "end": 4},
|
||||
{
|
||||
"entity_group": "PER",
|
||||
"score": 0.9663328925768534,
|
||||
"word": "##uelo Araújo Noguera",
|
||||
"start": 4,
|
||||
"end": 22,
|
||||
},
|
||||
{
|
||||
"entity_group": "PER",
|
||||
"score": 0.9997273534536362,
|
||||
"word": "Andrés Pastrana",
|
||||
"start": 23,
|
||||
"end": 37,
|
||||
},
|
||||
{"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc", "start": 39, "end": 43},
|
||||
],
|
||||
[
|
||||
{"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo"},
|
||||
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
|
||||
{"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo", "start": 0, "end": 4},
|
||||
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
|
||||
],
|
||||
]
|
||||
|
||||
|
@ -164,6 +316,34 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
|
|||
nlp = pipeline(task="ner", model=model_name)
|
||||
self._test_pipeline(nlp)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_simple(self):
|
||||
nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
|
||||
output = nlp("Hello Sarah Jessica Parker who Jessica lives in New York")
|
||||
|
||||
def simplify(output):
|
||||
for i in range(len(output)):
|
||||
output[i]["score"] = round(output[i]["score"], 3)
|
||||
return output
|
||||
|
||||
output = simplify(output)
|
||||
|
||||
self.assertEqual(
|
||||
output,
|
||||
[
|
||||
{
|
||||
"entity_group": "PER",
|
||||
"score": 0.996,
|
||||
"word": "Sarah Jessica Parker",
|
||||
"start": 6,
|
||||
"end": 26,
|
||||
},
|
||||
{"entity_group": "PER", "score": 0.977, "word": "Jessica", "start": 31, "end": 38},
|
||||
{"entity_group": "LOC", "score": 0.999, "word": "New York", "start": 48, "end": 56},
|
||||
],
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
|
||||
for model_name in self.small_models:
|
||||
|
|
Loading…
Reference in New Issue