From b88e0e016db8cc9dc3e21db1bcc2615a1717ddcb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 18 May 2021 09:53:20 +0200 Subject: [PATCH] [TokenClassification] Label realignment for subword aggregation (#11680) * [TokenClassification] Label realignment for subword aggregation Tentative to replace https://github.com/huggingface/transformers/pull/11622/files - Added `AggregationStrategy` - `ignore_subwords` and `grouped_entities` arguments are now fused into `aggregation_strategy`. It makes more sense anyway because `ignore_subwords=True` with `grouped_entities=False` did not have a meaning anyway. - Added 2 new ways to aggregate which are MAX, and AVERAGE - AVERAGE requires a bit more information than the others, for now this case is slightly specific, we should keep that in mind for future changes. - Testing has been modified to reflect new argument, and to check the correct deprecation and the new aggregation_strategy. - Put the testing argument and testing results for aggregation_strategy, close together, so that readers can understand what is supposed to happen. - `aggregate` is now only tested on a small model as it does not mean anything to test it globally for all models. - Previous tests are unchanged in desired output. - Added a new test case that showcases better the difference between the FIRST, MAX and AVERAGE strategies. * Wrong framework. * Addressing three issues. 1- Tags might not follow B-, I- convention, so any tag should work now (assumed as B-TAG) 2- Fixed an issue with average that leads to a substantial code change. 3- The testing suite was not checking for the "index" key for "none" strategy. This is now fixed. The issue is that "O" could not be chosen by AVERAGE strategy because those tokens were filtered out beforehand, so their relative scores were not counted in the average. Now filtering on ignore_labels will happen at the very end of the pipeline fixing that issue. It's a bit hard to make sure this stays like that because we do not have a end-to-end test for that behavior * Formatting. * Adding formatting to code + cleaner handling of B-, I- tags. Co-authored-by: Francesco Rubbo Co-authored-by: elk-cloner * Typo. Co-authored-by: Francesco Rubbo Co-authored-by: elk-cloner --- src/transformers/pipelines/__init__.py | 7 +- .../pipelines/token_classification.py | 320 +++++++--- src/transformers/testing_utils.py | 10 +- tests/test_pipelines_token_classification.py | 578 ++++++++++-------- 4 files changed, 579 insertions(+), 336 deletions(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 67061060aa..33f3fe12e1 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -48,7 +48,12 @@ from .table_question_answering import TableQuestionAnsweringArgumentHandler, Tab from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline from .text_classification import TextClassificationPipeline from .text_generation import TextGenerationPipeline -from .token_classification import NerPipeline, TokenClassificationArgumentHandler, TokenClassificationPipeline +from .token_classification import ( + AggregationStrategy, + NerPipeline, + TokenClassificationArgumentHandler, + TokenClassificationPipeline, +) from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index d9431c0cb7..3d155dcbfe 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -1,8 +1,9 @@ -from typing import TYPE_CHECKING, List, Optional, Union +import warnings +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import numpy as np -from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available +from ..file_utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available from ..modelcard import ModelCard from ..models.bert.tokenization_bert import BasicTokenizer from ..tokenization_utils import PreTrainedTokenizer @@ -48,13 +49,43 @@ class TokenClassificationArgumentHandler(ArgumentHandler): return inputs, offset_mapping +class AggregationStrategy(ExplicitEnum): + """All the valid aggregation strategies for TokenClassificationPipeline""" + + NONE = "none" + SIMPLE = "simple" + FIRST = "first" + AVERAGE = "average" + MAX = "max" + + @add_end_docstrings( PIPELINE_INIT_ARGS, r""" ignore_labels (:obj:`List[str]`, defaults to :obj:`["O"]`): A list of labels to ignore. grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to group the tokens corresponding to the same entity together in the predictions or not. + DEPRECATED, use :obj:`aggregation_strategy` instead. Whether or not to group the tokens corresponding to + the same entity together in the predictions or not. + aggregation_strategy (:obj:`str`, `optional`, defaults to :obj:`"none"`): The strategy to fuse (or not) tokens based on the model prediction. + + - "none" : Will simply not do any aggregation and simply return raw results from the model + - "simple" : Will attempt to group entities following the default schema. (A, B-TAG), (B, I-TAG), (C, + I-TAG), (D, B-TAG2) (E, B-TAG2) will end up being [{"word": ABC, "entity": "TAG"}, {"word": "D", + "entity": "TAG2"}, {"word": "E", "entity": "TAG2"}] Notice that two consecutive B tags will end up as + different entities. On word based languages, we might end up splitting words undesirably : Imagine + Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity": + "NAME"}]. Look for FIRST, MAX, AVERAGE for ways to mitigate that and disambiguate words (on languages + that support that meaning, which is basically tokens separated by a space). These mitigations will + only work on real words, "New york" might still be tagged with two different entities. + - "first" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words, + cannot end up with different tags. Words will simply use the tag of the first token of the word when + there is ambiguity. + - "average" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words, + cannot end up with different tags. scores will be averaged first across tokens, and then the maximum + label is applied. + - "max" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words, + cannot end up with different tags. Word entity will simply be the token with the maximum score. """, ) class TokenClassificationPipeline(Pipeline): @@ -84,8 +115,9 @@ class TokenClassificationPipeline(Pipeline): binary_output: bool = False, ignore_labels=["O"], task: str = "", - grouped_entities: bool = False, - ignore_subwords: bool = False, + grouped_entities: Optional[bool] = None, + ignore_subwords: Optional[bool] = None, + aggregation_strategy: Optional[AggregationStrategy] = None, ): super().__init__( model=model, @@ -106,15 +138,40 @@ class TokenClassificationPipeline(Pipeline): self._basic_tokenizer = BasicTokenizer(do_lower_case=False) self._args_parser = args_parser self.ignore_labels = ignore_labels - self.grouped_entities = grouped_entities - self.ignore_subwords = ignore_subwords - if self.ignore_subwords and not self.tokenizer.is_fast: + if aggregation_strategy is None: + aggregation_strategy = AggregationStrategy.NONE + if grouped_entities is not None or ignore_subwords is not None: + + if grouped_entities and ignore_subwords: + aggregation_strategy = AggregationStrategy.FIRST + elif grouped_entities and not ignore_subwords: + aggregation_strategy = AggregationStrategy.SIMPLE + else: + aggregation_strategy = AggregationStrategy.NONE + + if grouped_entities is not None: + warnings.warn( + f'`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.' + ) + if ignore_subwords is not None: + warnings.warn( + f'`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.' + ) + if isinstance(aggregation_strategy, str): + aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()] + + if ( + aggregation_strategy in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE} + and not self.tokenizer.is_fast + ): raise ValueError( - "Slow tokenizers cannot ignore subwords. Please set the `ignore_subwords` option" - "to `False` or use a fast tokenizer." + "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option" + 'to `"simple"` or use a fast tokenizer.' ) + self.aggregation_strategy = aggregation_strategy + def __call__(self, inputs: Union[str, List[str]], **kwargs): """ Classify each token of the text(s) given as inputs. @@ -125,14 +182,14 @@ class TokenClassificationPipeline(Pipeline): Return: A list or a list of list of :obj:`dict`: Each result comes as a list of dictionaries (one for each token in - the corresponding input, or each entity if this pipeline was instantiated with - :obj:`grouped_entities=True`) with the following keys: + the corresponding input, or each entity if this pipeline was instantiated with an aggregation_strategy) + with the following keys: - **word** (:obj:`str`) -- The token/word classified. - **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`. - **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when - `grouped_entities` is set to True. - - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the + `aggregation_strategy` is not :obj:`"none"`. + - **index** (:obj:`int`, only present when ``aggregation_strategy="none"``) -- The index of the corresponding token in the sentence. - **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the sentence. Only exists if the offsets are available within the tokenizer @@ -176,58 +233,142 @@ class TokenClassificationPipeline(Pipeline): entities = self.model(**tokens)[0][0].cpu().numpy() input_ids = tokens["input_ids"].cpu().numpy()[0] - score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True) - labels_idx = score.argmax(axis=-1) - - entities = [] - # Filter to labels not in `self.ignore_labels` - # Filter special_tokens - filtered_labels_idx = [ - (idx, label_idx) - for idx, label_idx in enumerate(labels_idx) - if (self.model.config.id2label[label_idx] not in self.ignore_labels) and not special_tokens_mask[idx] + scores = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True) + pre_entities = self.gather_pre_entities(sentence, input_ids, scores, offset_mapping, special_tokens_mask) + grouped_entities = self.aggregate(pre_entities, self.aggregation_strategy) + # Filter anything that is in self.ignore_labels + entities = [ + entity + for entity in grouped_entities + if entity.get("entity", None) not in self.ignore_labels + and entity.get("entity_group", None) not in self.ignore_labels ] - - for idx, label_idx in filtered_labels_idx: - if offset_mapping is not None: - start_ind, end_ind = offset_mapping[idx] - word_ref = sentence[start_ind:end_ind] - word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0] - is_subword = len(word_ref) != len(word) - - if int(input_ids[idx]) == self.tokenizer.unk_token_id: - word = word_ref - is_subword = False - else: - word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])) - - start_ind = None - end_ind = None - - entity = { - "word": word, - "score": score[idx][label_idx].item(), - "entity": self.model.config.id2label[label_idx], - "index": idx, - "start": start_ind, - "end": end_ind, - } - - if self.grouped_entities and self.ignore_subwords: - entity["is_subword"] = is_subword - - entities += [entity] - - if self.grouped_entities: - answers += [self.group_entities(entities)] - # Append ungrouped entities - else: - answers += [entities] + answers.append(entities) if len(answers) == 1: return answers[0] return answers + def gather_pre_entities( + self, + sentence: str, + input_ids: np.ndarray, + scores: np.ndarray, + offset_mapping: Optional[List[Tuple[int, int]]], + special_tokens_mask: np.ndarray, + ) -> List[dict]: + """Fuse various numpy arrays into dicts with all the information needed for aggregation""" + pre_entities = [] + for idx, token_scores in enumerate(scores): + # Filter special_tokens, they should only occur + # at the sentence boundaries since we're not encoding pairs of + # sentences so we don't have to keep track of those. + if special_tokens_mask[idx]: + continue + + word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])) + if offset_mapping is not None: + start_ind, end_ind = offset_mapping[idx] + word_ref = sentence[start_ind:end_ind] + is_subword = len(word_ref) != len(word) + + if int(input_ids[idx]) == self.tokenizer.unk_token_id: + word = word_ref + is_subword = False + else: + start_ind = None + end_ind = None + is_subword = False + + pre_entity = { + "word": word, + "scores": token_scores, + "start": start_ind, + "end": end_ind, + "index": idx, + "is_subword": is_subword, + } + pre_entities.append(pre_entity) + return pre_entities + + def aggregate(self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]: + if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}: + entities = [] + for pre_entity in pre_entities: + entity_idx = pre_entity["scores"].argmax() + score = pre_entity["scores"][entity_idx] + entity = { + "entity": self.model.config.id2label[entity_idx], + "score": score, + "index": pre_entity["index"], + "word": pre_entity["word"], + "start": pre_entity["start"], + "end": pre_entity["end"], + } + entities.append(entity) + else: + entities = self.aggregate_words(pre_entities, aggregation_strategy) + + if aggregation_strategy == AggregationStrategy.NONE: + return entities + return self.group_entities(entities) + + def aggregate_word(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> dict: + word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities]) + if aggregation_strategy == AggregationStrategy.FIRST: + scores = entities[0]["scores"] + idx = scores.argmax() + score = scores[idx] + entity = self.model.config.id2label[idx] + elif aggregation_strategy == AggregationStrategy.MAX: + max_entity = max(entities, key=lambda entity: entity["scores"].max()) + scores = max_entity["scores"] + idx = scores.argmax() + score = scores[idx] + entity = self.model.config.id2label[idx] + elif aggregation_strategy == AggregationStrategy.AVERAGE: + scores = np.stack([entity["scores"] for entity in entities]) + average_scores = np.nanmean(scores, axis=0) + entity_idx = average_scores.argmax() + entity = self.model.config.id2label[entity_idx] + score = average_scores[entity_idx] + else: + raise ValueError("Invalid aggregation_strategy") + new_entity = { + "entity": entity, + "score": score, + "word": word, + "start": entities[0]["start"], + "end": entities[-1]["end"], + } + return new_entity + + def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]: + """ + Override tokens from a given word that disagree to force agreement on word boundaries. + + Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with first strategy as microsoft| + company| B-ENT I-ENT + """ + assert aggregation_strategy not in { + AggregationStrategy.NONE, + AggregationStrategy.SIMPLE, + }, "NONE and SIMPLE strategies are invalid" + + word_entities = [] + word_group = None + for entity in entities: + if word_group is None: + word_group = [entity] + elif entity["is_subword"]: + word_group.append(entity) + else: + word_entities.append(self.aggregate_word(word_group, aggregation_strategy)) + word_group = [entity] + # Last item + word_entities.append(self.aggregate_word(word_group, aggregation_strategy)) + return word_entities + def group_sub_entities(self, entities: List[dict]) -> dict: """ Group together the adjacent tokens with the same entity predicted. @@ -249,6 +390,19 @@ class TokenClassificationPipeline(Pipeline): } return entity_group + def get_tag(self, entity_name: str) -> Tuple[str, str]: + if entity_name.startswith("B-"): + bi = "B" + tag = entity_name[2:] + elif entity_name.startswith("I-"): + bi = "I" + tag = entity_name[2:] + else: + # It's not in B-, I- format + bi = "B" + tag = entity_name + return bi, tag + def group_entities(self, entities: List[dict]) -> List[dict]: """ Find and group together the adjacent tokens with the same entity predicted. @@ -260,45 +414,29 @@ class TokenClassificationPipeline(Pipeline): entity_groups = [] entity_group_disagg = [] - if entities: - last_idx = entities[-1]["index"] - for entity in entities: - - is_last_idx = entity["index"] == last_idx - is_subword = self.ignore_subwords and entity["is_subword"] if not entity_group_disagg: - entity_group_disagg += [entity] - if is_last_idx: - entity_groups += [self.group_sub_entities(entity_group_disagg)] + entity_group_disagg.append(entity) continue - # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group - # The split is meant to account for the "B" and "I" suffixes + # If the current entity is similar and adjacent to the previous entity, + # append it to the disaggregated entity group + # The split is meant to account for the "B" and "I" prefixes # Shouldn't merge if both entities are B-type - if ( - ( - entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1] - and entity["entity"].split("-")[0] != "B" - ) - and entity["index"] == entity_group_disagg[-1]["index"] + 1 - ) or is_subword: - # Modify subword type to be previous_type - if is_subword: - entity["entity"] = entity_group_disagg[-1]["entity"].split("-")[-1] - entity["score"] = np.nan # set ignored scores to nan and use np.nanmean + bi, tag = self.get_tag(entity["entity"]) + last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"]) - entity_group_disagg += [entity] - # Group the entities at the last entity - if is_last_idx: - entity_groups += [self.group_sub_entities(entity_group_disagg)] - # If the current entity is different from the previous entity, aggregate the disaggregated entity group + if tag == last_tag and bi != "B": + # Modify subword type to be previous_type + entity_group_disagg.append(entity) else: - entity_groups += [self.group_sub_entities(entity_group_disagg)] + # If the current entity is different from the previous entity + # aggregate the disaggregated entity group + entity_groups.append(self.group_sub_entities(entity_group_disagg)) entity_group_disagg = [entity] - # If it's the last entity, add it to the entity groups - if is_last_idx: - entity_groups += [self.group_sub_entities(entity_group_disagg)] + if entity_group_disagg: + # it's the last entity, add it to the entity groups + entity_groups.append(self.group_sub_entities(entity_group_disagg)) return entity_groups diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 4144be2eb9..81d74a9a42 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1207,19 +1207,25 @@ def nested_simplify(obj, decimals=3): Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test within tests. """ + import numpy as np + from transformers.tokenization_utils import BatchEncoding if isinstance(obj, list): return [nested_simplify(item, decimals) for item in obj] + elif isinstance(obj, np.ndarray): + return nested_simplify(obj.tolist()) elif isinstance(obj, (dict, BatchEncoding)): return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()} - elif isinstance(obj, (str, int)): + elif isinstance(obj, (str, int, np.int64)): return obj elif is_torch_available() and isinstance(obj, torch.Tensor): - return nested_simplify(obj.tolist()) + return nested_simplify(obj.tolist(), decimals) elif is_tf_available() and tf.is_tensor(obj): return nested_simplify(obj.numpy().tolist()) elif isinstance(obj, float): return round(obj, decimals) + elif isinstance(obj, np.float32): + return nested_simplify(obj.item(), decimals) else: raise Exception(f"Not supported: {type(obj)}") diff --git a/tests/test_pipelines_token_classification.py b/tests/test_pipelines_token_classification.py index 756ccbf52d..d611509ce6 100644 --- a/tests/test_pipelines_token_classification.py +++ b/tests/test_pipelines_token_classification.py @@ -14,16 +14,15 @@ import unittest -from transformers import AutoTokenizer, is_torch_available, pipeline -from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler -from transformers.testing_utils import require_tf, require_torch, slow +import numpy as np + +from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline +from transformers.pipelines import AggregationStrategy, Pipeline, TokenClassificationArgumentHandler +from transformers.testing_utils import nested_simplify, require_tf, require_torch, slow from .test_pipelines_common import CustomInputPipelineCommonMixin -if is_torch_available(): - import numpy as np - VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]] @@ -35,210 +34,10 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest. large_models = [] # Models tested with the @slow decorator def _test_pipeline(self, nlp: Pipeline): - output_keys = {"entity", "word", "score", "start", "end"} - if nlp.grouped_entities: + output_keys = {"entity", "word", "score", "start", "end", "index"} + if nlp.aggregation_strategy != AggregationStrategy.NONE: output_keys = {"entity_group", "word", "score", "start", "end"} - ungrouped_ner_inputs = [ - [ - { - "entity": "B-PER", - "index": 1, - "score": 0.9994944930076599, - "is_subword": False, - "word": "Cons", - "start": 0, - "end": 4, - }, - { - "entity": "B-PER", - "index": 2, - "score": 0.8025449514389038, - "is_subword": True, - "word": "##uelo", - "start": 4, - "end": 8, - }, - { - "entity": "I-PER", - "index": 3, - "score": 0.9993102550506592, - "is_subword": False, - "word": "Ara", - "start": 9, - "end": 11, - }, - { - "entity": "I-PER", - "index": 4, - "score": 0.9993743896484375, - "is_subword": True, - "word": "##új", - "start": 11, - "end": 13, - }, - { - "entity": "I-PER", - "index": 5, - "score": 0.9992871880531311, - "is_subword": True, - "word": "##o", - "start": 13, - "end": 14, - }, - { - "entity": "I-PER", - "index": 6, - "score": 0.9993029236793518, - "is_subword": False, - "word": "No", - "start": 15, - "end": 17, - }, - { - "entity": "I-PER", - "index": 7, - "score": 0.9981776475906372, - "is_subword": True, - "word": "##guera", - "start": 17, - "end": 22, - }, - { - "entity": "B-PER", - "index": 15, - "score": 0.9998136162757874, - "is_subword": False, - "word": "Andrés", - "start": 23, - "end": 28, - }, - { - "entity": "I-PER", - "index": 16, - "score": 0.999740719795227, - "is_subword": False, - "word": "Pas", - "start": 29, - "end": 32, - }, - { - "entity": "I-PER", - "index": 17, - "score": 0.9997414350509644, - "is_subword": True, - "word": "##tran", - "start": 32, - "end": 36, - }, - { - "entity": "I-PER", - "index": 18, - "score": 0.9996136426925659, - "is_subword": True, - "word": "##a", - "start": 36, - "end": 37, - }, - { - "entity": "B-ORG", - "index": 28, - "score": 0.9989739060401917, - "is_subword": False, - "word": "Far", - "start": 39, - "end": 42, - }, - { - "entity": "I-ORG", - "index": 29, - "score": 0.7188422083854675, - "is_subword": True, - "word": "##c", - "start": 42, - "end": 43, - }, - ], - [ - { - "entity": "I-PER", - "index": 1, - "score": 0.9968166351318359, - "is_subword": False, - "word": "En", - "start": 0, - "end": 2, - }, - { - "entity": "I-PER", - "index": 2, - "score": 0.9957635998725891, - "is_subword": True, - "word": "##zo", - "start": 2, - "end": 4, - }, - { - "entity": "I-ORG", - "index": 7, - "score": 0.9986497163772583, - "is_subword": False, - "word": "UN", - "start": 11, - "end": 13, - }, - ], - ] - - expected_grouped_ner_results = [ - [ - { - "entity_group": "PER", - "score": 0.999369223912557, - "word": "Consuelo Araújo Noguera", - "start": 0, - "end": 22, - }, - { - "entity_group": "PER", - "score": 0.9997771680355072, - "word": "Andrés Pastrana", - "start": 23, - "end": 37, - }, - {"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc", "start": 39, "end": 43}, - ], - [ - {"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo", "start": 0, "end": 4}, - {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13}, - ], - ] - - expected_grouped_ner_results_w_subword = [ - [ - {"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons", "start": 0, "end": 4}, - { - "entity_group": "PER", - "score": 0.9663328925768534, - "word": "##uelo Araújo Noguera", - "start": 4, - "end": 22, - }, - { - "entity_group": "PER", - "score": 0.9997273534536362, - "word": "Andrés Pastrana", - "start": 23, - "end": 37, - }, - {"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc", "start": 39, "end": 43}, - ], - [ - {"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo", "start": 0, "end": 4}, - {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13}, - ], - ] - self.assertIsNotNone(nlp) mono_result = nlp(VALID_INPUTS[0]) @@ -262,15 +61,306 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest. for key in output_keys: self.assertIn(key, result) - if nlp.grouped_entities: - if nlp.ignore_subwords: - for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results): - self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result) - else: - for ungrouped_input, grouped_result in zip( - ungrouped_ner_inputs, expected_grouped_ner_results_w_subword - ): - self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result) + @require_torch + @slow + def test_spanish_bert(self): + # https://github.com/huggingface/transformers/pull/4987 + NER_MODEL = "mrm8488/bert-spanish-cased-finetuned-ner" + model = AutoModelForTokenClassification.from_pretrained(NER_MODEL) + tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True) + sentence = """Consuelo Araújo Noguera, ministra de cultura del presidente Andrés Pastrana (1998.2002) fue asesinada por las Farc luego de haber permanecido secuestrada por algunos meses.""" + + token_classifier = pipeline("ner", model=model, tokenizer=tokenizer) + output = token_classifier(sentence) + self.assertEqual( + nested_simplify(output[:3]), + [ + {"entity": "B-PER", "score": 0.999, "word": "Cons", "start": 0, "end": 4, "index": 1}, + {"entity": "B-PER", "score": 0.803, "word": "##uelo", "start": 4, "end": 8, "index": 2}, + {"entity": "I-PER", "score": 0.999, "word": "Ara", "start": 9, "end": 12, "index": 3}, + ], + ) + + token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple") + output = token_classifier(sentence) + self.assertEqual( + nested_simplify(output[:3]), + [ + {"entity_group": "PER", "score": 0.999, "word": "Cons", "start": 0, "end": 4}, + {"entity_group": "PER", "score": 0.966, "word": "##uelo Araújo Noguera", "start": 4, "end": 23}, + {"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75}, + ], + ) + + token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="first") + output = token_classifier(sentence) + self.assertEqual( + nested_simplify(output[:3]), + [ + {"entity_group": "PER", "score": 0.999, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23}, + {"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75}, + {"entity_group": "ORG", "score": 0.999, "word": "Farc", "start": 110, "end": 114}, + ], + ) + + token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="max") + output = token_classifier(sentence) + self.assertEqual( + nested_simplify(output[:3]), + [ + {"entity_group": "PER", "score": 0.999, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23}, + {"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75}, + {"entity_group": "ORG", "score": 0.999, "word": "Farc", "start": 110, "end": 114}, + ], + ) + + token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="average") + output = token_classifier(sentence) + self.assertEqual( + nested_simplify(output[:3]), + [ + {"entity_group": "PER", "score": 0.966, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23}, + {"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75}, + {"entity_group": "ORG", "score": 0.542, "word": "Farc", "start": 110, "end": 114}, + ], + ) + + @require_torch + @slow + def test_dbmdz_english(self): + # Other sentence + NER_MODEL = "dbmdz/bert-large-cased-finetuned-conll03-english" + model = AutoModelForTokenClassification.from_pretrained(NER_MODEL) + tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True) + sentence = """Enzo works at the the UN""" + token_classifier = pipeline("ner", model=model, tokenizer=tokenizer) + output = token_classifier(sentence) + self.assertEqual( + nested_simplify(output), + [ + {"entity": "I-PER", "score": 0.997, "word": "En", "start": 0, "end": 2, "index": 1}, + {"entity": "I-PER", "score": 0.996, "word": "##zo", "start": 2, "end": 4, "index": 2}, + {"entity": "I-ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24, "index": 7}, + ], + ) + + token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple") + output = token_classifier(sentence) + self.assertEqual( + nested_simplify(output), + [ + {"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4}, + {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24}, + ], + ) + + token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="first") + output = token_classifier(sentence) + self.assertEqual( + nested_simplify(output[:3]), + [ + {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4}, + {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24}, + ], + ) + + token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="max") + output = token_classifier(sentence) + self.assertEqual( + nested_simplify(output[:3]), + [ + {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4}, + {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24}, + ], + ) + + token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="average") + output = token_classifier(sentence) + self.assertEqual( + nested_simplify(output), + [ + {"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4}, + {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24}, + ], + ) + + @require_torch + def test_aggregation_strategy(self): + model_name = self.small_models[0] + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) + token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt") + # Just to understand scores indexes in this test + self.assertEqual( + token_classifier.model.config.id2label, + {0: "O", 1: "B-MISC", 2: "I-MISC", 3: "B-PER", 4: "I-PER", 5: "B-ORG", 6: "I-ORG", 7: "B-LOC", 8: "I-LOC"}, + ) + example = [ + { + # fmt : off + "scores": np.array([0, 0, 0, 0, 0.9968166351318359, 0, 0, 0]), + "index": 1, + "is_subword": False, + "word": "En", + "start": 0, + "end": 2, + }, + { + # fmt : off + "scores": np.array([0, 0, 0, 0, 0.9957635998725891, 0, 0, 0]), + "index": 2, + "is_subword": True, + "word": "##zo", + "start": 2, + "end": 4, + }, + { + # fmt: off + "scores": np.array([0, 0, 0, 0, 0, 0.9986497163772583, 0, 0, ]), + # fmt: on + "index": 7, + "word": "UN", + "is_subword": False, + "start": 11, + "end": 13, + }, + ] + self.assertEqual( + nested_simplify(token_classifier.aggregate(example, AggregationStrategy.NONE)), + [ + {"end": 2, "entity": "I-PER", "score": 0.997, "start": 0, "word": "En", "index": 1}, + {"end": 4, "entity": "I-PER", "score": 0.996, "start": 2, "word": "##zo", "index": 2}, + {"end": 13, "entity": "B-ORG", "score": 0.999, "start": 11, "word": "UN", "index": 7}, + ], + ) + self.assertEqual( + nested_simplify(token_classifier.aggregate(example, AggregationStrategy.SIMPLE)), + [ + {"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4}, + {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13}, + ], + ) + self.assertEqual( + nested_simplify(token_classifier.aggregate(example, AggregationStrategy.FIRST)), + [ + {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4}, + {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13}, + ], + ) + self.assertEqual( + nested_simplify(token_classifier.aggregate(example, AggregationStrategy.MAX)), + [ + {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4}, + {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13}, + ], + ) + self.assertEqual( + nested_simplify(token_classifier.aggregate(example, AggregationStrategy.AVERAGE)), + [ + {"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4}, + {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13}, + ], + ) + + @require_torch + def test_aggregation_strategy_example2(self): + model_name = self.small_models[0] + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) + token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt") + # Just to understand scores indexes in this test + self.assertEqual( + token_classifier.model.config.id2label, + {0: "O", 1: "B-MISC", 2: "I-MISC", 3: "B-PER", 4: "I-PER", 5: "B-ORG", 6: "I-ORG", 7: "B-LOC", 8: "I-LOC"}, + ) + example = [ + { + # Necessary for AVERAGE + "scores": np.array([0, 0.55, 0, 0.45, 0, 0, 0, 0, 0, 0]), + "is_subword": False, + "index": 1, + "word": "Ra", + "start": 0, + "end": 2, + }, + { + "scores": np.array([0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0]), + "is_subword": True, + "word": "##ma", + "start": 2, + "end": 4, + "index": 2, + }, + { + # 4th score will have the higher average + # 4th score is B-PER for this model + # It's does not correspond to any of the subtokens. + "scores": np.array([0, 0, 0, 0.4, 0, 0, 0.6, 0, 0, 0]), + "is_subword": True, + "word": "##zotti", + "start": 11, + "end": 13, + "index": 3, + }, + ] + self.assertEqual( + token_classifier.aggregate(example, AggregationStrategy.NONE), + [ + {"end": 2, "entity": "B-MISC", "score": 0.55, "start": 0, "word": "Ra", "index": 1}, + {"end": 4, "entity": "B-LOC", "score": 0.8, "start": 2, "word": "##ma", "index": 2}, + {"end": 13, "entity": "I-ORG", "score": 0.6, "start": 11, "word": "##zotti", "index": 3}, + ], + ) + + self.assertEqual( + token_classifier.aggregate(example, AggregationStrategy.FIRST), + [{"entity_group": "MISC", "score": 0.55, "word": "Ramazotti", "start": 0, "end": 13}], + ) + self.assertEqual( + token_classifier.aggregate(example, AggregationStrategy.MAX), + [{"entity_group": "LOC", "score": 0.8, "word": "Ramazotti", "start": 0, "end": 13}], + ) + self.assertEqual( + nested_simplify(token_classifier.aggregate(example, AggregationStrategy.AVERAGE)), + [{"entity_group": "PER", "score": 0.35, "word": "Ramazotti", "start": 0, "end": 13}], + ) + + @require_torch + def test_gather_pre_entities(self): + + model_name = self.small_models[0] + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) + nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt") + + sentence = "Hello there" + + tokens = tokenizer( + sentence, + return_attention_mask=False, + return_tensors="pt", + truncation=True, + return_special_tokens_mask=True, + return_offsets_mapping=True, + ) + offset_mapping = tokens.pop("offset_mapping").cpu().numpy()[0] + special_tokens_mask = tokens.pop("special_tokens_mask").cpu().numpy()[0] + input_ids = tokens["input_ids"].numpy()[0] + # First element in [CLS] + scores = np.array([[1, 0, 0], [0.1, 0.3, 0.6], [0.8, 0.1, 0.1]]) + + pre_entities = nlp.gather_pre_entities(sentence, input_ids, scores, offset_mapping, special_tokens_mask) + self.assertEqual( + nested_simplify(pre_entities), + [ + {"word": "Hello", "scores": [0.1, 0.3, 0.6], "start": 0, "end": 5, "is_subword": False, "index": 1}, + { + "word": "there", + "scores": [0.8, 0.1, 0.1], + "index": 2, + "start": 6, + "end": 11, + "is_subword": False, + }, + ], + ) @require_tf def test_tf_only(self): @@ -295,8 +385,7 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest. model=model_name, tokenizer=tokenizer, framework="tf", - grouped_entities=True, - ignore_subwords=True, + aggregation_strategy=AggregationStrategy.FIRST, ) self._test_pipeline(nlp) @@ -307,18 +396,23 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest. model=model_name, tokenizer=tokenizer, framework="tf", - grouped_entities=True, - ignore_subwords=False, + aggregation_strategy=AggregationStrategy.SIMPLE, ) self._test_pipeline(nlp) @require_torch def test_pt_ignore_subwords_slow_tokenizer_raises(self): - for model_name in self.small_models: - tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) + model_name = self.small_models[0] + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) - with self.assertRaises(ValueError): - pipeline(task="ner", model=model_name, tokenizer=tokenizer, ignore_subwords=True, use_fast=False) + with self.assertRaises(ValueError): + pipeline(task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.FIRST) + with self.assertRaises(ValueError): + pipeline( + task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.AVERAGE + ) + with self.assertRaises(ValueError): + pipeline(task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.MAX) @require_torch def test_pt_defaults_slow_tokenizer(self): @@ -333,27 +427,27 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest. nlp = pipeline(task="ner", model=model_name) self._test_pipeline(nlp) + @slow + @require_torch + def test_warnings(self): + with self.assertWarns(UserWarning): + token_classifier = pipeline(task="ner", model=self.small_models[0], grouped_entities=True) + self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE) + with self.assertWarns(UserWarning): + token_classifier = pipeline( + task="ner", model=self.small_models[0], grouped_entities=True, ignore_subwords=True + ) + self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST) + @slow @require_torch def test_simple(self): - nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True) + nlp = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy=AggregationStrategy.SIMPLE) sentence = "Hello Sarah Jessica Parker who Jessica lives in New York" sentence2 = "This is a simple test" output = nlp(sentence) - def simplify(output): - if isinstance(output, (list, tuple)): - return [simplify(item) for item in output] - elif isinstance(output, dict): - return {simplify(k): simplify(v) for k, v in output.items()} - elif isinstance(output, (str, int, np.int64)): - return output - elif isinstance(output, float): - return round(output, 3) - else: - raise Exception(f"Cannot handle {type(output)}") - - output_ = simplify(output) + output_ = nested_simplify(output) self.assertEqual( output_, @@ -371,7 +465,7 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest. ) output = nlp([sentence, sentence2]) - output_ = simplify(output) + output_ = nested_simplify(output) self.assertEqual( output_, @@ -390,14 +484,14 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest. for model_name in self.small_models: tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) nlp = pipeline( - task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True + task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.FIRST ) self._test_pipeline(nlp) for model_name in self.small_models: tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) nlp = pipeline( - task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=False + task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.SIMPLE ) self._test_pipeline(nlp)