Chunkable token classification pipeline (#21771)
* Chunkable classification pipeline

The TokenClassificationPipeline can now process sequences longer than 512 tokens, regardless of the framework, model, or tokenizer. Pass a `stride` value and the pipeline is applied to the whole text, split into overlapping chunks of size `model_max_length`; the behavior is unchanged if `stride` is not passed. For tokens that fall in the overlapping part of several chunks, only the prediction with the maximum score is kept.

* Correct spaces; remove `process_all` and keep only `stride`. If `stride` is provided, the pipeline is applied to the whole text.

* Update chunk aggregation: base the chunk aggregation strategy on entity aggregation.

* Remove unnecessary pop from outputs dict.

* Add chunking tests, use a tiny model for the fast chunking test, and fix formatting.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent f48d3314e4
commit d62e7d8842
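Example usage of the new chunking behavior (a minimal sketch; the checkpoint is the one used in the slow test below, and the stride value here is illustrative):

    from transformers import pipeline

    # Checkpoint taken from the slow test below; any token-classification model
    # with a fast tokenizer works.
    token_classifier = pipeline(
        "token-classification",
        model="elastic/distilbert-base-uncased-finetuned-conll03-english",
    )

    long_text = "Hugging Face, Inc. is a French company ..."  # imagine a document longer than 512 tokens

    # Passing `stride` (with an aggregation strategy other than "none") makes the pipeline
    # process the whole text in overlapping chunks instead of truncating at model_max_length.
    entities = token_classifier(long_text, aggregation_strategy="first", stride=128)
    print(entities)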
src/transformers/pipelines/token_classification.py
@@ -5,13 +5,19 @@ from typing import List, Optional, Tuple, Union

 import numpy as np

 from ..models.bert.tokenization_bert import BasicTokenizer
-from ..utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
-from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Dataset, Pipeline
+from ..utils import (
+    ExplicitEnum,
+    add_end_docstrings,
+    is_tf_available,
+    is_torch_available,
+)
+from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline, Dataset


 if is_tf_available():
-    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+    import tensorflow as tf
+
+    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
 if is_torch_available():
     from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
@@ -60,6 +66,9 @@ class AggregationStrategy(ExplicitEnum):
         grouped_entities (`bool`, *optional*, defaults to `False`):
             DEPRECATED, use `aggregation_strategy` instead. Whether or not to group the tokens corresponding to the
             same entity together in the predictions or not.
+        stride (`int`, *optional*):
+            If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
+            model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`.
         aggregation_strategy (`str`, *optional*, defaults to `"none"`):
             The strategy to fuse (or not) tokens based on the model prediction.
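To make the documented stride semantics concrete, a small sketch (token counts are made up and special tokens are ignored for simplicity):

    # With model_max_length = 10 and stride = 5, a 20-token input is covered by
    # overlapping windows; consecutive windows share `stride` tokens.
    model_max_length, stride, n_tokens = 10, 5, 20

    windows = []
    start = 0
    while True:
        end = min(start + model_max_length, n_tokens)
        windows.append((start, end))
        if end == n_tokens:
            break
        start = end - stride

    print(windows)  # [(0, 10), (5, 15), (10, 20)]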
@@ -82,7 +91,7 @@ class AggregationStrategy(ExplicitEnum):
         end up with different tags. Word entity will simply be the token with the maximum score.
     """,
 )
-class TokenClassificationPipeline(Pipeline):
+class TokenClassificationPipeline(ChunkPipeline):
     """
     Named Entity Recognition pipeline using any `ModelForTokenClassification`. See the [named entity recognition
     examples](../task_summary#named-entity-recognition) for more information.
@@ -139,6 +148,7 @@ class TokenClassificationPipeline(Pipeline):
         ignore_subwords: Optional[bool] = None,
         aggregation_strategy: Optional[AggregationStrategy] = None,
         offset_mapping: Optional[List[Tuple[int, int]]] = None,
+        stride: Optional[int] = None,
     ):
         preprocess_params = {}
         if offset_mapping is not None:
@@ -174,11 +184,30 @@ class TokenClassificationPipeline(Pipeline):
             ):
                 raise ValueError(
                     "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option"
-                    'to `"simple"` or use a fast tokenizer.'
+                    ' to `"simple"` or use a fast tokenizer.'
                 )
             postprocess_params["aggregation_strategy"] = aggregation_strategy
         if ignore_labels is not None:
             postprocess_params["ignore_labels"] = ignore_labels
+
+        if stride is not None:
+            if aggregation_strategy == AggregationStrategy.NONE:
+                raise ValueError(
+                    "`stride` was provided to process all the text but `aggregation_strategy="
+                    f'"{aggregation_strategy}"`, please select another one instead.'
+                )
+            else:
+                if self.tokenizer.is_fast:
+                    tokenizer_params = {
+                        "return_overflowing_tokens": True,
+                        "padding": True,
+                        "stride": stride,
+                    }
+                    preprocess_params["tokenizer_params"] = tokenizer_params
+                else:
+                    raise ValueError(
+                        "`stride` was provided to process all the text but you're using a slow tokenizer."
+                        " Please use a fast tokenizer."
+                    )
         return preprocess_params, {}, postprocess_params

     def __call__(self, inputs: Union[str, List[str]], **kwargs):
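For reference, the `tokenizer_params` configured above rely on the standard fast-tokenizer overflow mechanism. A rough sketch of what such a call produces (checkpoint and lengths are illustrative, not taken from the PR):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)

    encoded = tokenizer(
        "a long sentence that does not fit in a single window " * 3,
        truncation=True,
        max_length=10,                  # stands in for tokenizer.model_max_length
        stride=5,                       # overlap between consecutive chunks
        return_overflowing_tokens=True,
        padding=True,
        return_special_tokens_mask=True,
        return_offsets_mapping=True,
    )

    # One row of input_ids per chunk; consecutive chunks overlap by `stride` tokens.
    print(len(encoded["input_ids"]))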
@@ -213,29 +242,40 @@ class TokenClassificationPipeline(Pipeline):

         return super().__call__(inputs, **kwargs)

-    def preprocess(self, sentence, offset_mapping=None):
+    def preprocess(self, sentence, offset_mapping=None, **preprocess_params):
+        tokenizer_params = preprocess_params.pop("tokenizer_params", {})
         truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
-        model_inputs = self.tokenizer(
+        inputs = self.tokenizer(
             sentence,
             return_tensors=self.framework,
             truncation=truncation,
             return_special_tokens_mask=True,
             return_offsets_mapping=self.tokenizer.is_fast,
+            **tokenizer_params,
         )
-        if offset_mapping:
-            model_inputs["offset_mapping"] = offset_mapping
+        inputs.pop("overflow_to_sample_mapping", None)
+        num_chunks = len(inputs["input_ids"])

-        model_inputs["sentence"] = sentence
+        for i in range(num_chunks):
+            if self.framework == "tf":
+                model_inputs = {k: tf.expand_dims(v[i], 0) for k, v in inputs.items()}
+            else:
+                model_inputs = {k: v[i].unsqueeze(0) for k, v in inputs.items()}
+            if offset_mapping is not None:
+                model_inputs["offset_mapping"] = offset_mapping
+            model_inputs["sentence"] = sentence if i == 0 else None
+            model_inputs["is_last"] = i == num_chunks - 1

-        return model_inputs
+            yield model_inputs

     def _forward(self, model_inputs):
         # Forward
         special_tokens_mask = model_inputs.pop("special_tokens_mask")
         offset_mapping = model_inputs.pop("offset_mapping", None)
         sentence = model_inputs.pop("sentence")
+        is_last = model_inputs.pop("is_last")
         if self.framework == "tf":
-            logits = self.model(model_inputs.data)[0]
+            logits = self.model(**model_inputs)[0]
         else:
             output = self.model(**model_inputs)
             logits = output["logits"] if isinstance(output, dict) else output[0]
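Since `preprocess` is now a generator, the surrounding ChunkPipeline machinery runs the model once per chunk and hands the accumulated outputs to `postprocess`. Roughly (a simplified sketch of the data flow, not the actual ChunkPipeline internals):

    def run_chunked(pipe, text, preprocess_params, forward_params, postprocess_params):
        # One forward pass per chunk yielded by preprocess(); postprocess() sees them all at once.
        all_outputs = [
            pipe.forward(model_inputs, **forward_params)
            for model_inputs in pipe.preprocess(text, **preprocess_params)
        ]
        return pipe.postprocess(all_outputs, **postprocess_params)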
@@ -245,38 +285,67 @@ class TokenClassificationPipeline(Pipeline):
             "special_tokens_mask": special_tokens_mask,
             "offset_mapping": offset_mapping,
             "sentence": sentence,
+            "is_last": is_last,
             **model_inputs,
         }

-    def postprocess(self, model_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
+    def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE, ignore_labels=None):
         if ignore_labels is None:
             ignore_labels = ["O"]
-        logits = model_outputs["logits"][0].numpy()
-        sentence = model_outputs["sentence"]
-        input_ids = model_outputs["input_ids"][0]
-        offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
-        special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()
+        all_entities = []
+        for model_outputs in all_outputs:
+            logits = model_outputs["logits"][0].numpy()
+            sentence = all_outputs[0]["sentence"]
+            input_ids = model_outputs["input_ids"][0]
+            offset_mapping = (
+                model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
+            )
+            special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()

-        maxes = np.max(logits, axis=-1, keepdims=True)
-        shifted_exp = np.exp(logits - maxes)
-        scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+            maxes = np.max(logits, axis=-1, keepdims=True)
+            shifted_exp = np.exp(logits - maxes)
+            scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)

-        if self.framework == "tf":
-            input_ids = input_ids.numpy()
-            offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None
+            if self.framework == "tf":
+                input_ids = input_ids.numpy()
+                offset_mapping = offset_mapping.numpy() if offset_mapping is not None else None

-        pre_entities = self.gather_pre_entities(
-            sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
-        )
-        grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
-        # Filter anything that is in self.ignore_labels
-        entities = [
-            entity
-            for entity in grouped_entities
-            if entity.get("entity", None) not in ignore_labels
-            and entity.get("entity_group", None) not in ignore_labels
-        ]
-        return entities
+            pre_entities = self.gather_pre_entities(
+                sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
+            )
+            grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
+            # Filter anything that is in self.ignore_labels
+            entities = [
+                entity
+                for entity in grouped_entities
+                if entity.get("entity", None) not in ignore_labels
+                and entity.get("entity_group", None) not in ignore_labels
+            ]
+            all_entities.extend(entities)
+        num_chunks = len(all_outputs)
+        if num_chunks > 1:
+            all_entities = self.aggregate_overlapping_entities(all_entities)
+        return all_entities
+
+    def aggregate_overlapping_entities(self, entities):
+        if len(entities) == 0:
+            return entities
+        entities = sorted(entities, key=lambda x: x["start"])
+        aggregated_entities = []
+        previous_entity = entities[0]
+        for entity in entities:
+            if previous_entity["start"] <= entity["start"] < previous_entity["end"]:
+                current_length = entity["end"] - entity["start"]
+                previous_length = previous_entity["end"] - previous_entity["start"]
+                if current_length > previous_length:
+                    previous_entity = entity
+                elif current_length == previous_length and entity["score"] > previous_entity["score"]:
+                    previous_entity = entity
+            else:
+                aggregated_entities.append(previous_entity)
+                previous_entity = entity
+        aggregated_entities.append(previous_entity)
+        return aggregated_entities

     def gather_pre_entities(
         self,
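The overlap-resolution rule added above keeps, among entities whose character spans overlap, the longest one, breaking ties by score. A self-contained illustration with made-up entity dicts:

    def resolve_overlaps(entities):
        """Mirror of aggregate_overlapping_entities: longest overlapping span wins, ties broken by score."""
        if not entities:
            return entities
        entities = sorted(entities, key=lambda x: x["start"])
        resolved = []
        previous = entities[0]
        for entity in entities:
            if previous["start"] <= entity["start"] < previous["end"]:
                length, prev_length = entity["end"] - entity["start"], previous["end"] - previous["start"]
                if length > prev_length or (length == prev_length and entity["score"] > previous["score"]):
                    previous = entity
            else:
                resolved.append(previous)
                previous = entity
        resolved.append(previous)
        return resolved

    # Two chunks both saw the span starting at 22; the longer prediction is kept.
    entities = [
        {"entity_group": "LOC", "score": 0.95, "word": "new york", "start": 22, "end": 30},
        {"entity_group": "LOC", "score": 0.90, "word": "new york city", "start": 22, "end": 35},
        {"entity_group": "MISC", "score": 0.98, "word": "french", "start": 60, "end": 66},
    ]
    print(resolve_overlaps(entities))  # keeps "new york city" and "french"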
@@ -290,9 +359,7 @@ class TokenClassificationPipeline(Pipeline):
         """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
         pre_entities = []
         for idx, token_scores in enumerate(scores):
-            # Filter special_tokens, they should only occur
-            # at the sentence boundaries since we're not encoding pairs of
-            # sentences so we don't have to keep track of those.
+            # Filter special_tokens
             if special_tokens_mask[idx]:
                 continue

@@ -317,7 +384,10 @@ class TokenClassificationPipeline(Pipeline):
                 AggregationStrategy.AVERAGE,
                 AggregationStrategy.MAX,
             }:
-                warnings.warn("Tokenizer does not support real words, using fallback heuristic", UserWarning)
+                warnings.warn(
+                    "Tokenizer does not support real words, using fallback heuristic",
+                    UserWarning,
+                )
             is_subword = start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]

             if int(input_ids[idx]) == self.tokenizer.unk_token_id:
test_pipelines_token_classification.py
@@ -196,6 +196,127 @@ class TokenClassificationPipelineTests(unittest.TestCase):
         )
         self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.FIRST)

+    @slow
+    @require_torch
+    def test_chunking(self):
+        NER_MODEL = "elastic/distilbert-base-uncased-finetuned-conll03-english"
+        model = AutoModelForTokenClassification.from_pretrained(NER_MODEL)
+        tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
+        tokenizer.model_max_length = 10
+        stride = 5
+        sentence = (
+            "Hugging Face, Inc. is a French company that develops tools for building applications using machine learning. "
+            "The company, based in New York City was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf."
+        )
+
+        token_classifier = TokenClassificationPipeline(
+            model=model, tokenizer=tokenizer, aggregation_strategy="simple", stride=stride
+        )
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output),
+            [
+                {"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
+                {"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
+                {"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
+                {"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
+                {"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
+                {"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
+                {"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
+            ],
+        )
+
+        token_classifier = TokenClassificationPipeline(
+            model=model, tokenizer=tokenizer, aggregation_strategy="first", stride=stride
+        )
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output),
+            [
+                {"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
+                {"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
+                {"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
+                {"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
+                {"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
+                {"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
+                {"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
+            ],
+        )
+
+        token_classifier = TokenClassificationPipeline(
+            model=model, tokenizer=tokenizer, aggregation_strategy="max", stride=stride
+        )
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output),
+            [
+                {"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
+                {"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
+                {"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
+                {"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
+                {"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
+                {"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
+                {"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
+            ],
+        )
+
+        token_classifier = TokenClassificationPipeline(
+            model=model, tokenizer=tokenizer, aggregation_strategy="average", stride=stride
+        )
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output),
+            [
+                {"entity_group": "ORG", "score": 0.978, "word": "hugging face, inc.", "start": 0, "end": 18},
+                {"entity_group": "MISC", "score": 0.999, "word": "french", "start": 24, "end": 30},
+                {"entity_group": "LOC", "score": 0.997, "word": "new york city", "start": 131, "end": 144},
+                {"entity_group": "MISC", "score": 0.999, "word": "french", "start": 168, "end": 174},
+                {"entity_group": "PER", "score": 0.999, "word": "clement delangue", "start": 189, "end": 205},
+                {"entity_group": "PER", "score": 0.999, "word": "julien chaumond", "start": 207, "end": 222},
+                {"entity_group": "PER", "score": 0.999, "word": "thomas wolf", "start": 228, "end": 239},
+            ],
+        )
+
+    @require_torch
+    def test_chunking_fast(self):
+        # Note: We cannot run the test on "conflicts" on the chunking.
+        # The problem is that the model is random, and thus the results do heavily
+        # depend on the chunking, so we cannot expect "abcd" and "bcd" to find
+        # the same entities. We defer to slow tests for this.
+        pipe = pipeline(model="hf-internal-testing/tiny-bert-for-token-classification")
+        sentence = "The company, based in New York City was founded in 2016 by French entrepreneurs"
+
+        results = pipe(sentence, aggregation_strategy="first")
+        # This is what this random model gives on the full sentence
+        self.assertEqual(
+            nested_simplify(results),
+            [
+                # This is 2 actual tokens
+                {"end": 39, "entity_group": "MISC", "score": 0.115, "start": 31, "word": "city was"},
+                {"end": 79, "entity_group": "MISC", "score": 0.115, "start": 66, "word": "entrepreneurs"},
+            ],
+        )
+
+        # This will force the tokenizer to split after "city was".
+        pipe.tokenizer.model_max_length = 12
+        self.assertEqual(
+            pipe.tokenizer.decode(pipe.tokenizer.encode(sentence, truncation=True)),
+            "[CLS] the company, based in new york city was [SEP]",
+        )
+
+        stride = 4
+        results = pipe(sentence, aggregation_strategy="first", stride=stride)
+        self.assertEqual(
+            nested_simplify(results),
+            [
+                {"end": 39, "entity_group": "MISC", "score": 0.115, "start": 31, "word": "city was"},
+                # This is an extra entity found by this random model, but at least both original
+                # entities are there
+                {"end": 58, "entity_group": "MISC", "score": 0.115, "start": 56, "word": "by"},
+                {"end": 79, "entity_group": "MISC", "score": 0.115, "start": 66, "word": "entrepreneurs"},
+            ],
+        )
+
     @require_torch
     @slow
     def test_spanish_bert(self):