Chunkable token classification pipeline (#21771)

* Chunkable classification pipeline 

The TokenClassificationPipeline is now able to process sequences longer than 512 tokens, regardless of the framework, model, or tokenizer. You only need to pass process_all=True and, optionally, a stride value; the behavior is unchanged if these parameters are omitted. For the overlapping regions produced when stride is greater than 0, only the maximum score is kept for each token across all chunks that contain it.
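
For reference, a minimal usage sketch of the feature as merged (later commits in this history drop process_all and keep only stride; the checkpoint name and input text below are placeholders, not part of this PR):

```python
from transformers import pipeline

# Sketch only: any public NER checkpoint with a fast tokenizer behaves the same way here.
ner = pipeline(
    "token-classification",
    model="dslim/bert-base-NER",    # placeholder checkpoint
    aggregation_strategy="simple",  # chunking requires a strategy other than "none"
    stride=64,                      # number of overlapping tokens between chunks
)

# Text longer than the model's 512-token limit is split into chunks, each chunk is
# classified separately, and the predictions for the overlapping tokens are merged.
long_text = "Hugging Face is based in New York City. " * 200
for entity in ner(long_text):
    print(entity["entity_group"], entity["word"], entity["start"], entity["end"])
```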

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* update with latest black format

* update black format

* Update token_classification.py

* Update token_classification.py

* format correction

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update comments

* Update src/transformers/pipelines/token_classification.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* Update token_classification.py

Correct spacing, remove process_all, and keep only stride. If stride is provided, the pipeline is applied to the whole text.
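
As a sketch of the constraints this introduces (placeholder checkpoint, not part of this commit): stride can also be passed at call time, and it requires a fast tokenizer and an aggregation_strategy other than "none":

```python
from transformers import pipeline

# Placeholder checkpoint; stride may be passed per call instead of at pipeline creation.
ner = pipeline("token-classification", model="dslim/bert-base-NER")
text = "Hugging Face is based in New York City. " * 200

ner(text, aggregation_strategy="first", stride=32)     # OK: the whole text is processed

# Both of the following raise a ValueError in the merged implementation:
# ner(text, aggregation_strategy="none", stride=32)    # stride needs token aggregation
# slow = pipeline("token-classification", model="dslim/bert-base-NER", use_fast=False)
# slow(text, aggregation_strategy="simple", stride=32) # stride needs a fast tokenizer
```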

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update chunk aggregation

Update the chunk aggregation strategy so that it builds on entity aggregation.
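
The rule for tokens that end up in more than one chunk: when two chunks produce overlapping entities, the longer span wins, and ties are broken by the higher score. A simplified stand-alone mirror of aggregate_overlapping_entities (illustration only, not the pipeline code itself):

```python
def resolve_overlaps(entities):
    """Keep one entity per overlapping region: prefer the longer span, then the higher score."""
    if not entities:
        return entities
    entities = sorted(entities, key=lambda e: e["start"])
    kept, previous = [], entities[0]
    for entity in entities[1:]:
        if previous["start"] <= entity["start"] < previous["end"]:  # overlaps the kept entity
            if (entity["end"] - entity["start"], entity["score"]) > (
                previous["end"] - previous["start"],
                previous["score"],
            ):
                previous = entity
        else:
            kept.append(previous)
            previous = entity
    kept.append(previous)
    return kept


# Two chunks disagree on the same span: the longer "hugging face, inc." entity is kept.
candidates = [
    {"entity_group": "ORG", "word": "hugging face", "score": 0.95, "start": 0, "end": 12},
    {"entity_group": "ORG", "word": "hugging face, inc.", "score": 0.93, "start": 0, "end": 18},
]
print(resolve_overlaps(candidates))
```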

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

Remove unnecessary pop from outputs dict

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update token_classification.py

* Update src/transformers/pipelines/token_classification.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add chunking tests

* correct formatting

* correct formatting

* correct model id for test chunking

* update scores with nested simplify

* Update test_pipelines_token_classification.py

* Update test_pipelines_token_classification.py

* update model to a tiny one

* Update test_pipelines_token_classification.py

* Adding smaller test for chunking.

* Fixup

* Update token_classification.py

* Update src/transformers/pipelines/token_classification.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/token_classification.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Luc CAILLIAU 2023-03-22 19:13:20 +01:00 committed by GitHub
parent f48d3314e4
commit d62e7d8842
2 changed files with 231 additions and 40 deletions

src/transformers/pipelines/token_classification.py

@@ -5,13 +5,19 @@ from typing import List, Optional, Tuple, Union
import numpy as np
from ..models.bert.tokenization_bert import BasicTokenizer
from ..utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Dataset, Pipeline
from ..utils import (
ExplicitEnum,
add_end_docstrings,
is_tf_available,
is_torch_available,
)
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline, Dataset
if is_tf_available():
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
import tensorflow as tf
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
@@ -60,6 +66,9 @@ class AggregationStrategy(ExplicitEnum):
grouped_entities (`bool`, *optional*, defaults to `False`):
DEPRECATED, use `aggregation_strategy` instead. Whether or not to group the tokens corresponding to the
same entity together in the predictions or not.
stride (`int`, *optional*):
If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`.
aggregation_strategy (`str`, *optional*, defaults to `"none"`):
The strategy to fuse (or not) tokens based on the model prediction.
@@ -82,7 +91,7 @@ class AggregationStrategy(ExplicitEnum):
end up with different tags. Word entity will simply be the token with the maximum score.
""",
)
class TokenClassificationPipeline(Pipeline):
class TokenClassificationPipeline(ChunkPipeline):
"""
Named Entity Recognition pipeline using any `ModelForTokenClassification`. See the [named entity recognition
examples](../task_summary#named-entity-recognition) for more information.
@@ -139,6 +148,7 @@ class TokenClassificationPipeline(Pipeline):
ignore_subwords: Optional[bool] = None,
aggregation_strategy: Optional[AggregationStrategy] = None,
offset_mapping: Optional[List[Tuple[int, int]]] = None,
stride: Optional[int] = None,
):
preprocess_params = {}
if offset_mapping is not None:
@@ -174,11 +184,30 @@ class TokenClassificationPipeline(Pipeline):
):
raise ValueError(
"Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option"
'to `"simple"` or use a fast tokenizer.'
' to `"simple"` or use a fast tokenizer.'
)
postprocess_params["aggregation_strategy"] = aggregation_strategy
if ignore_labels is not None:
postprocess_params["ignore_labels"] = ignore_labels
if stride is not None:
if aggregation_strategy == AggregationStrategy.NONE:
raise ValueError(
"`stride` was provided to process all the text but `aggregation_strategy="
f'"{aggregation_strategy}"`, please select another one instead.'
)
else:
if self.tokenizer.is_fast:
tokenizer_params = {
"return_overflowing_tokens": True,
"padding": True,
"stride": stride,
}
preprocess_params["tokenizer_params"] = tokenizer_params
else:
raise ValueError(
"`stride` was provided to process all the text but you're using a slow tokenizer."
" Please use a fast tokenizer."
)
return preprocess_params, {}, postprocess_params
def __call__(self, inputs: Union[str, List[str]], **kwargs):
@@ -213,29 +242,40 @@ class TokenClassificationPipeline(Pipeline):
return super().__call__(inputs, **kwargs)
def preprocess(self, sentence, offset_mapping=None):
def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
tokenizer_params = preprocess_params.pop("tokenizer_params", {})
truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
model_inputs = self.tokenizer(
inputs = self.tokenizer(
sentence,
return_tensors=self.framework,
truncation=truncation,
return_special_tokens_mask=True,
return_offsets_mapping=self.tokenizer.is_fast,
**tokenizer_params,
)
if offset_mapping:
model_inputs["offset_mapping"] = offset_mapping
inputs.pop("overflow_to_sample_mapping", None)
num_chunks = len(inputs["input_ids"])
model_inputs["sentence"] = sentence
for i in range(num_chunks):
if self.framework == "tf":
model_inputs = {k: tf.expand_dims(v[i], 0) for k, v in inputs.items()}
else:
model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
if offset_mapping is not None:
model_inputs["offset_mapping"] = offset_mapping
model_inputs["sentence"] = sentence if i == 0 else None
model_inputs["is_last"] = i == num_chunks - 1
return model_inputs
yield model_inputs
def _forward(self, model_inputs):
# Forward
special_tokens_mask = model_inputs.pop("special_tokens_mask")
offset_mapping = model_inputs.pop("offset_mapping", None)
sentence = model_inputs.pop("sentence")
is_last = model_inputs.pop("is_last")
if self.framework == "tf":
logits = self.model(model_inputs.data)[0]
logits = self.model(**model_inputs)[0]
else:
output = self.model(**model_inputs)
logits = output["logits"] if isinstance(output, dict) else output[0]
@@ -245,38 +285,67 @@ class TokenClassificationPipeline(Pipeline):
"special_tokens_mask": special_tokens_mask,
"offset_mapping": offset_mapping,
"sentence": sentence,
"is_last": is_last,
**model_inputs,
}
def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
if ignore_labels is None:
ignore_labels = ["O"]
logits = model_outputs["logits"][0].numpy()
sentence = model_outputs["sentence"]
input_ids = model_outputs["input_ids"][0]
offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()
all_entities = []
for model_outputs in all_outputs:
logits = model_outputs["logits"][0].numpy()
sentence = all_outputs[0]["sentence"]
input_ids = model_outputs["input_ids"][0]
offset_mapping = (
model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
)
special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()
maxes = np.max(logits, axis=-1, keepdims=True)
shifted_exp = np.exp(logits - maxes)
scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
maxes = np.max(logits, axis=-1, keepdims=True)
shifted_exp = np.exp(logits - maxes)
scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
if self.framework == "tf":
input_ids = input_ids.numpy()
offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None
if self.framework == "tf":
input_ids = input_ids.numpy()
offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None
pre_entities = self.gather_pre_entities(
sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
)
grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
# Filter anything that is in self.ignore_labels
entities = [
entity
for entity in grouped_entities
if entity.get("entity", None) not in ignore_labels
and entity.get("entity_group", None) not in ignore_labels
]
return entities
pre_entities = self.gather_pre_entities(
sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
)
grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
# Filter anything that is in self.ignore_labels
entities = [
entity
for entity in grouped_entities
if entity.get("entity", None) not in ignore_labels
and entity.get("entity_group", None) not in ignore_labels
]
all_entities.extend(entities)
num_chunks = len(all_outputs)
if num_chunks > 1:
all_entities = self.aggregate_overlapping_entities(all_entities)
return all_entities
def aggregate_overlapping_entities(self, entities):
if len(entities) == 0:
return entities
entities = sorted(entities, key=lambda x: x["start"])
aggregated_entities = []
previous_entity = entities[0]
for entity in entities:
if previous_entity["start"] <= entity["start"] < previous_entity["end"]:
current_length = entity["end"] - entity["start"]
previous_length = previous_entity["end"] - previous_entity["start"]
if current_length > previous_length:
previous_entity = entity
elif current_length == previous_length and entity["score"] > previous_entity["score"]:
previous_entity = entity
else:
aggregated_entities.append(previous_entity)
previous_entity = entity
aggregated_entities.append(previous_entity)
return aggregated_entities
def gather_pre_entities(
self,
@@ -290,9 +359,7 @@ class TokenClassificationPipeline(Pipeline):
"""Fuse various numpy arrays into dicts with all the information needed for aggregation"""
pre_entities = []
for idx, token_scores in enumerate(scores):
# Filter special_tokens, they should only occur
# at the sentence boundaries since we're not encoding pairs of
# sentences so we don't have to keep track of those.
# Filter special_tokens
if special_tokens_mask[idx]:
continue
@@ -317,7 +384,10 @@ class TokenClassificationPipeline(Pipeline):
AggregationStrategy.AVERAGE,
AggregationStrategy.MAX,
}:
warnings.warn("Tokenizer does not support real words, using fallback heuristic", UserWarning)
warnings.warn(
"Tokenizer does not support real words, using fallback heuristic",
UserWarning,
)
is_subword = start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]
if int(input_ids[idx]) == self.tokenizer.unk_token_id:

test_pipelines_token_classification.py

@@ -196,6 +196,127 @@ class TokenClassificationPipelineTests(unittest.TestCase):
)
self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.FIRST)
@slow
@require_torch
def test_chunking(self):
NER_MODEL = "elastic/distilbert-base-uncased-finetuned-conll03-english"
model = AutoModelForTokenClassification.from_pretrained(NER_MODEL)
tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
tokenizer.model_max_length = 10
stride = 5
sentence = (
"Hugging Face, Inc. is a French company that develops tools for building applications using machine learning. "
"The company, based in New York City was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf."
)
token_classifier = TokenClassificationPipeline(
model=model, tokenizer=tokenizer, aggregation_strategy="simple", stride=stride
)
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
{"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
{"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
{"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
{"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
],
)
token_classifier = TokenClassificationPipeline(
model=model, tokenizer=tokenizer, aggregation_strategy="first", stride=stride
)
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
{"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
{"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
{"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
{"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
],
)
token_classifier = TokenClassificationPipeline(
model=model, tokenizer=tokenizer, aggregation_strategy="max", stride=stride
)
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
{"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
{"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
{"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
{"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
],
)
token_classifier = TokenClassificationPipeline(
model=model, tokenizer=tokenizer, aggregation_strategy="average", stride=stride
)
output = token_classifier(sentence)
self.assertEqual(
nested_simplify(output),
[
{"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
{"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
{"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
{"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
{"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
{"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
],
)
@require_torch
def test_chunking_fast(self):
# Note: We cannot run the test on "conflicts" on the chunking.
# The problem is that the model is random, and thus the results do heavily
# depend on the chunking, so we cannot expect "abcd" and "bcd" to find
# the same entities. We defer to slow tests for this.
pipe = pipeline(model="hf-internal-testing/tiny-bert-for-token-classification")
sentence = "The company, based in New York City was founded in 2016 by French entrepreneurs"
results = pipe(sentence, aggregation_strategy="first")
# This is what this random model gives on the full sentence
self.assertEqual(
nested_simplify(results),
[
# This is 2 actual tokens
{"end": 39, "entity_group": "MISC", "score": 0.115, "start": 31, "word": "city was"},
{"end": 79, "entity_group": "MISC", "score": 0.115, "start": 66, "word": "entrepreneurs"},
],
)
# This will force the tokenizer to split after "city was".
pipe.tokenizer.model_max_length = 12
self.assertEqual(
pipe.tokenizer.decode(pipe.tokenizer.encode(sentence, truncation=True)),
"[CLS] the company, based in new york city was [SEP]",
)
stride = 4
results = pipe(sentence, aggregation_strategy="first", stride=stride)
self.assertEqual(
nested_simplify(results),
[
{"end": 39, "entity_group": "MISC", "score": 0.115, "start": 31, "word": "city was"},
# This is an extra entity found by this random model, but at least both original
# entities are there
{"end": 58, "entity_group": "MISC", "score": 0.115, "start": 56, "word": "by"},
{"end": 79, "entity_group": "MISC", "score": 0.115, "start": 66, "word": "entrepreneurs"},
],
)
@require_torch
@slow
def test_spanish_bert(self):